Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions crates/cubecl-cpu/src/compiler/memref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,18 @@ impl LineMemRef {
stride: [1],
}
}

/// Builds a [`LineMemRef`] over a region of length `len` starting at `pointer`.
///
/// The same address is stored as both the `allocated` and the `aligned`
/// pointer. The descriptor has a zero offset, a single dimension of extent
/// `len`, and a unit stride.
///
/// # Safety
///
/// The pointer must be valid and point to at least `len` bytes of writable memory.
pub unsafe fn from_raw_parts(pointer: *mut u8, len: usize) -> Self {
    let base = pointer.cast::<c_void>();
    let extent = len as c_longlong;
    Self {
        allocated: base,
        aligned: base,
        offset: 0,
        shape: [extent],
        stride: [1],
    }
}
}
40 changes: 15 additions & 25 deletions crates/cubecl-cpu/src/compiler/mlir_data.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::passes::shared_memories::SharedMemories;
use crate::{
compiler::{builtin::BuiltinArray, memref::LineMemRef, passes::shared_memories::SharedMemory},
compiler::{builtin::BuiltinArray, memref::LineMemRef},
compute::schedule::BindingsResource,
};
use cubecl_common::stream_id::StreamId;
Expand Down Expand Up @@ -86,33 +86,23 @@ impl MlirData {
}

let stream_id = StreamId::current();
let mut smem_handles = Vec::with_capacity(shared_memories.0.len());
for shared_memory in shared_memories.0.iter() {
let (handle, length) = match shared_memory {
SharedMemory::Array { ty, length, .. } => {
let length = (ty.size() * *length as usize) as u64;
let handle = memory_management_shared_memory.reserve(length).unwrap();
(handle, length)
}
SharedMemory::Value { ty, .. } => {
let length = ty.size() as u64;
let handle = memory_management_shared_memory.reserve(length).unwrap();
(handle, length)
}
};

smem_handles.push(handle.clone());

let b = Handle::new(handle, None, None, stream_id, 0, length).binding();
let mut handle = memory_management_shared_memory
if let Some(smem_size) = shared_memories.size() {
let handle = memory_management_shared_memory.reserve(smem_size).unwrap();
let b = Handle::new(handle.clone(), None, None, stream_id, 0, smem_size).binding();
let mut resource = memory_management_shared_memory
.get_resource(b.memory, b.offset_start, b.offset_end)
.expect("Failed to find resource");
let ptr = handle.write();
let line_memref = LineMemRef::new(ptr);
push_undirected(line_memref);

let smem_pool_ptr = resource.write().as_mut_ptr();
for shared_memory in shared_memories.0.iter() {
// Compute pointer into the pool at the appropriate offset
let offset = shared_memory.offset() as usize;
let size = shared_memory.size() as usize;
let ptr = unsafe { smem_pool_ptr.add(offset) };
let line_memref = unsafe { LineMemRef::from_raw_parts(ptr, size) };
push_undirected(line_memref);
}
}
// It is important to make sure multiple shared memories don't share the same handle.
core::mem::drop(smem_handles);

let ptr = shared_mlir_data.metadata.as_mut();
let line_memref = LineMemRef::new(ptr);
Expand Down
7 changes: 3 additions & 4 deletions crates/cubecl-cpu/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use cubecl_core::{
prelude::KernelDefinition,
server::ExecutionMode,
};
use cubecl_opt::OptimizerBuilder;
use cubecl_opt::{OptimizerBuilder, SharedLiveness};
use mlir_engine::MlirEngine;

use crate::compiler::passes::{
Expand Down Expand Up @@ -63,7 +63,7 @@ impl Compiler for MlirCompiler {

#[cfg(feature = "mlir-dump")]
dump_scope(&kernel.body, &kernel.options.kernel_name);
let opt = OptimizerBuilder::default()
let mut opt = OptimizerBuilder::default()
.with_transformer(ErfTransform)
.with_transformer(HypotTransform)
.with_transformer(RhypotTransform)
Expand All @@ -72,8 +72,7 @@ impl Compiler for MlirCompiler {
.with_processor(PredicateProcessor)
.optimize(kernel.body.clone(), kernel.cube_dim);

let mut shared_memories = SharedMemories::default();
shared_memories.visit(&opt);
let shared_memories = SharedMemories::from_liveness(&opt.analysis::<SharedLiveness>());

#[cfg(feature = "mlir-dump")]
dump_opt(&opt, &kernel.options.kernel_name);
Expand Down
92 changes: 45 additions & 47 deletions crates/cubecl-cpu/src/compiler/passes/shared_memories.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
use cubecl_core::ir::{OperationReflect, StorageType, Variable, VariableKind};
use cubecl_opt::Optimizer;
use cubecl_core::ir::Type;
use cubecl_opt::SharedLiveness;

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SharedMemory {
Array {
id: u32,
ty: StorageType,
// Length include the vectorization factor
ty: Type,
// Length includes unroll_factor; vectorization is in ty.size()
length: u32,
offset: u32,
},
Value {
id: u32,
ty: StorageType,
ty: Type,
offset: u32,
},
}

Expand All @@ -22,55 +24,51 @@ impl SharedMemory {
SharedMemory::Value { id, .. } => *id,
}
}

/// Returns the `offset` field stored in this shared memory entry,
/// regardless of whether it is an array or a single value.
pub fn offset(&self) -> u32 {
    match self {
        SharedMemory::Array { offset, .. } | SharedMemory::Value { offset, .. } => *offset,
    }
}

/// Returns the size of this shared memory entry.
///
/// For an array this is `length * ty.size()`; for a single value it is just
/// `ty.size()`. The result is expressed as a `u32`.
pub fn size(&self) -> u32 {
    // A single value is handled as a one-element array.
    let (ty, count) = match self {
        SharedMemory::Array { ty, length, .. } => (ty, *length),
        SharedMemory::Value { ty, .. } => (ty, 1),
    };
    count * ty.size() as u32
}
}

#[derive(Default)]
pub struct SharedMemories(pub Vec<SharedMemory>);

impl SharedMemories {
pub fn visit_variable(&mut self, variable: Variable) {
// Alignment is ignored for the moment it is taken from the type
match variable.kind {
VariableKind::SharedArray { id, length, .. } => {
if self.0.iter().all(|shared_memory| shared_memory.id() != id) {
let elem = variable.storage_type();
let vectorization = variable.line_size();
let length = length * vectorization;
self.0.push(SharedMemory::Array {
id,
ty: elem,
length,
});
}
}
VariableKind::Shared { id } => {
if self.0.iter().all(|shared_memory| shared_memory.id() != id) {
let elem = variable.storage_type();
self.0.push(SharedMemory::Value { id, ty: elem });
}
}
_ => {}
}
/// Builds the list of shared memories from the [SharedLiveness] analysis.
///
/// Each allocation reported by the analysis is converted into the local
/// [SharedMemory] representation and keeps the offset the analysis assigned
/// to it, so that non-overlapping lifetimes can reuse memory. The resulting
/// list is sorted by id.
pub fn from_liveness(shared_liveness: &SharedLiveness) -> Self {
    let mut memories: Vec<SharedMemory> = Vec::new();

    for alloc in shared_liveness.allocations.values() {
        let offset = alloc.offset;
        let memory = match alloc.smem {
            cubecl_opt::SharedMemory::Array { id, length, ty, .. } => SharedMemory::Array {
                id,
                ty,
                length,
                offset,
            },
            cubecl_opt::SharedMemory::Value { id, ty, .. } => {
                SharedMemory::Value { id, ty, offset }
            }
        };
        memories.push(memory);
    }

    // The allocations are re-ordered by id after collection.
    memories.sort_by_key(|memory| memory.id());
    Self(memories)
}
pub fn visit(&mut self, opt: &Optimizer) {
for node in opt.program.node_indices().collect::<Vec<_>>() {
let phi = opt.program[node].phi_nodes.clone();
let ops = opt.program[node].ops.clone();

for phi in phi.borrow_mut().iter_mut() {
self.visit_variable(phi.out);
}
for op in ops.borrow_mut().values_mut() {
if let Some(out) = op.out {
self.visit_variable(out);
}
if let Some(args) = op.operation.args() {
for arg in args {
self.visit_variable(arg);
}
}
}
}
self.0.sort_by_key(|a| a.id());
/// Returns the extent of the region needed to hold every shared memory.
///
/// This is the largest `offset + size` over all entries, i.e. the end of the
/// entry that reaches furthest. Returns `None` when there are no shared
/// memories at all.
pub fn size(&self) -> Option<u64> {
    self.0
        .iter()
        // Widen before adding: `offset + size` computed in `u32` could
        // overflow (panic in debug, wrap in release) and under-report the
        // required extent.
        .map(|m| u64::from(m.offset()) + u64::from(m.size()))
        .max()
}
}
Loading