-
Notifications
You must be signed in to change notification settings - Fork 178
Expand file tree
/
Copy pathmod.rs
More file actions
115 lines (98 loc) · 3.54 KB
/
mod.rs
File metadata and controls
115 lines (98 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
pub mod builtin;
pub(super) mod external_function;
pub(super) mod memref;
pub mod mlir_data;
pub mod mlir_engine;
pub mod module;
pub mod passes;
pub(super) mod visitor;
use cubecl_common::backtrace::BackTrace;
use cubecl_runtime::compiler::CompilationError;
use passes::shared_memories::SharedMemories;
pub use visitor::elem::register_supported_types;
use cubecl_core::{
Compiler,
ir::{self},
post_processing::{
checked_io::CheckedIoProcessor, predicate::PredicateProcessor,
saturating::SaturatingArithmeticProcessor,
},
prelude::KernelDefinition,
server::ExecutionMode,
};
use cubecl_opt::{OptimizerBuilder, SharedLiveness};
use mlir_engine::MlirEngine;
use crate::compiler::passes::{
erf_transform::ErfTransform,
trigonometries_transform::{HypotTransform, RhypotTransform},
};
/// CubeCL compiler backend whose compiled representation is an `MlirEngine`.
///
/// The struct has no fields, so every instance is interchangeable.
#[derive(Default, Clone, Debug)]
pub struct MlirCompiler {}
/// Compilation options accepted by `MlirCompiler::compile`.
///
/// Currently empty: `compile` takes it as `_compilation_options` and does not read it.
#[derive(Debug, Default)]
pub struct MlirCompilerOptions {}
impl Compiler for MlirCompiler {
    type Representation = MlirEngine;
    type CompilationOptions = MlirCompilerOptions;

    /// Compiles a CubeCL kernel definition into an [`MlirEngine`].
    ///
    /// Steps:
    /// 1. Any errors already recorded on `kernel.body` are drained and reported.
    /// 2. Otherwise the body is run through the CubeCL optimizer with the erf / hypot / rhypot
    ///    transformers and the checked-IO, saturating-arithmetic and predicate processors.
    /// 3. Shared memories are derived from the optimizer's `SharedLiveness` analysis.
    /// 4. Everything is handed to [`MlirEngine::from_cubecl_ir`].
    ///
    /// With the `mlir-dump` feature enabled, the IR before and after optimization is also
    /// dumped via `dump_scope` / `dump_opt`.
    ///
    /// # Errors
    ///
    /// Returns [`CompilationError::Validation`] when `kernel.body` has recorded errors. The
    /// reason starts with a header line followed by one line per recorded error.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        _compilation_options: &Self::CompilationOptions, // TODO pass this through the visitor, though it doesn't need anything for the moment
        mode: ExecutionMode, // forwarded to `CheckedIoProcessor` below
    ) -> Result<Self::Representation, CompilationError> {
        let errors = kernel.body.pop_errors();
        if !errors.is_empty() {
            // Terminate the header with ":\n" so it is not glued to the first error message.
            let mut reason = String::from("Can't compile mlir kernel:\n");
            for error in errors {
                reason.push_str(error.as_str());
                reason.push('\n');
            }
            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        #[cfg(feature = "mlir-dump")]
        dump_scope(&kernel.body, &kernel.options.kernel_name);

        // The optimizer consumes its scope, but `kernel` itself is still moved into
        // `MlirEngine::from_cubecl_ir` below, so the optimizer gets a copy of the body.
        let mut opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(HypotTransform)
            .with_transformer(RhypotTransform)
            .with_processor(CheckedIoProcessor::new(mode))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .with_processor(PredicateProcessor)
            .optimize(kernel.body.clone(), kernel.cube_dim);

        let shared_memories = SharedMemories::from_liveness(&opt.analysis::<SharedLiveness>());

        #[cfg(feature = "mlir-dump")]
        dump_opt(&opt, &kernel.options.kernel_name);

        Ok(MlirEngine::from_cubecl_ir(kernel, &opt, shared_memories))
    }

    /// Size in bytes of `elem`, as reported by `ElemType::size`.
    fn elem_size(&self, elem: ir::ElemType) -> usize {
        elem.size()
    }

    /// File extension used for this backend's compiled output.
    fn extension(&self) -> &'static str {
        "mlir"
    }
}
/// Writes the unoptimized CubeCL IR of `scope` to `$CUBECL_DEBUG_MLIR/<name>/cubecl.ir.txt`.
///
/// Does nothing when the `CUBECL_DEBUG_MLIR` environment variable cannot be read. This is a
/// best-effort debug aid: failing to create the directory or write the file is reported on
/// stderr instead of panicking in the middle of compilation.
#[cfg(feature = "mlir-dump")]
fn dump_scope(scope: &cubecl_core::prelude::Scope, name: &str) {
    use std::{fs, path::Path};

    if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
        let path = Path::new(&dir).join(name);
        // `create_dir_all` also creates a missing dump root and succeeds if the directory
        // already exists. Plain `create_dir` failed silently in the first case.
        if let Err(err) = fs::create_dir_all(&path) {
            eprintln!("mlir-dump: failed to create {}: {err}", path.display());
            return;
        }
        let file = path.join("cubecl.ir.txt");
        if let Err(err) = fs::write(&file, scope.to_string()) {
            eprintln!("mlir-dump: failed to write {}: {err}", file.display());
        }
    }
}
/// Writes the optimized CubeCL IR and its dot visualization to
/// `$CUBECL_DEBUG_MLIR/<name>/cubecl-opt.ir.txt` and `.../cubecl-opt.ir.dot`.
///
/// Does nothing when the `CUBECL_DEBUG_MLIR` environment variable cannot be read. This is a
/// best-effort debug aid: I/O failures are reported on stderr instead of panicking in the
/// middle of compilation. A failure on one file does not prevent writing the other.
#[cfg(feature = "mlir-dump")]
fn dump_opt(opt: &cubecl_opt::Optimizer, name: &str) {
    use std::{fs, path::Path};

    if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
        let path = Path::new(&dir).join(name);
        // `create_dir_all` also creates a missing dump root and succeeds if the directory
        // already exists. Plain `create_dir` failed silently in the first case.
        if let Err(err) = fs::create_dir_all(&path) {
            eprintln!("mlir-dump: failed to create {}: {err}", path.display());
            return;
        }
        let outputs = [
            ("cubecl-opt.ir.txt", opt.to_string()),
            ("cubecl-opt.ir.dot", opt.dot_viz().to_string()),
        ];
        for (file_name, contents) in outputs {
            let file = path.join(file_name);
            if let Err(err) = fs::write(&file, contents) {
                eprintln!("mlir-dump: failed to write {}: {err}", file.display());
            }
        }
    }
}