-
Notifications
You must be signed in to change notification settings - Fork 178
Expand file tree
/
Copy pathlauncher.rs
More file actions
317 lines (277 loc) · 9.87 KB
/
launcher.rs
File metadata and controls
317 lines (277 loc) · 9.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
use std::{collections::BTreeMap, marker::PhantomData};
use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
use crate::{CubeScalar, KernelSettings};
use crate::{MetadataBuilder, Runtime};
#[cfg(feature = "std")]
use core::cell::RefCell;
#[cfg(not(feature = "std"))]
use cubecl_common::stub::{Arc, Lazy, Mutex};
use cubecl_ir::{AddressType, StorageType};
use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
use cubecl_runtime::{
client::ComputeClient,
kernel::{CubeKernel, KernelTask},
server::Bindings,
};
/// Prepare a kernel for [launch](KernelLauncher::launch).
///
/// Arguments are accumulated through the `register_*` methods and converted
/// into bindings when the kernel is launched.
pub struct KernelLauncher<R: Runtime> {
/// Tensor, array and tensor-map arguments registered so far.
tensors: TensorState<R>,
/// Scalar arguments registered so far.
scalars: ScalarState,
/// Kernel settings passed to [`KernelLauncher::new`].
pub settings: KernelSettings,
/// Zero-sized marker tying the launcher to the runtime `R`.
runtime: PhantomData<R>,
}
impl<R: Runtime> KernelLauncher<R> {
    /// Add a tensor argument to the launch.
    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
        self.tensors.push_tensor(tensor)
    }

    /// Add a tensor-map argument to the launch.
    pub fn register_tensor_map<K: TensorMapKind>(&mut self, tensor: &TensorMapArg<'_, R, K>) {
        self.tensors.push_tensor_map(tensor)
    }

    /// Add an array argument to the launch.
    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
        self.tensors.push_array(array)
    }

    /// Add a typed scalar argument to the launch.
    pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
        self.scalars.push(scalar)
    }

    /// Add a scalar argument given as raw bytes together with its storage type.
    pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
        self.scalars.push_raw(bytes, dtype)
    }

    /// Launch the kernel.
    ///
    /// Consumes the launcher, turns every registered argument into bindings and
    /// hands them to `client` together with the kernel and the cube count.
    ///
    /// # Errors
    ///
    /// Returns whatever [`LaunchError`] the client reports.
    #[track_caller]
    pub fn launch<K: CubeKernel>(
        self,
        cube_count: CubeCount,
        kernel: K,
        client: &ComputeClient<R>,
    ) -> Result<(), LaunchError> {
        let bindings = self.into_bindings();
        let task = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
        client.launch(task, cube_count, bindings)
    }

    /// Launch the kernel without bounds checks.
    ///
    /// # Safety
    ///
    /// The kernel must not:
    ///  - Contain any out of bounds reads or writes. Doing so is immediate UB.
    ///  - Contain any loops that never terminate. These may be optimized away entirely or cause
    ///    other unpredictable behaviour.
    #[track_caller]
    pub unsafe fn launch_unchecked<K: CubeKernel>(
        self,
        cube_count: CubeCount,
        kernel: K,
        client: &ComputeClient<R>,
    ) -> Result<(), LaunchError> {
        // Building the bindings and the task is safe; only the final call needs `unsafe`.
        let bindings = self.into_bindings();
        let task = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
        // SAFETY: the caller upholds the contract documented in the `# Safety` section above.
        unsafe { client.launch_unchecked(task, cube_count, bindings) }
    }

    /// Convert the registered arguments into the [`Bindings`] sent to the client.
    ///
    /// The bindings have to be created in the same order they are defined in the
    /// compilation step: [`crate::KernelIntegrator::integrate`] starts by registering the
    /// tensors, then the tensor metadata, and the scalars at the end. Here the tensor
    /// state is registered first and the scalar state second.
    ///
    /// Scalars are grouped per storage type and, within one type, keep the order in
    /// which they were registered.
    fn into_bindings(self) -> Bindings {
        let Self {
            tensors, scalars, ..
        } = self;

        let mut bindings = Bindings::new();
        tensors.register(&mut bindings);
        scalars.register(&mut bindings);
        bindings
    }
}
// With `std`, each thread owns one `MetadataBuilder`. Tensor and array
// registration write into it through `with_metadata`, and
// `TensorState::register` calls `finish` on it when the bindings are built.
#[cfg(feature = "std")]
thread_local! {
static METADATA: RefCell<MetadataBuilder> = RefCell::new(MetadataBuilder::default());
}
/// Run `fun` with mutable access to the current thread's [`MetadataBuilder`]
/// and return its result.
///
/// # Panics
///
/// Panics if the builder is already borrowed, i.e. if this is called again
/// from inside `fun`.
#[cfg(feature = "std")]
fn with_metadata<R>(mut fun: impl FnMut(&mut MetadataBuilder) -> R) -> R {
    METADATA.with(|cell| {
        let mut builder = cell.borrow_mut();
        fun(&mut *builder)
    })
}
// Without `std` there is no thread-local storage here, so a single global
// `MetadataBuilder` guarded by a mutex is used instead.
// NOTE(review): `Lazy` presumably defers construction to first access — confirm
// against `cubecl_common::stub`.
#[cfg(not(feature = "std"))]
static METADATA: Lazy<Arc<Mutex<MetadataBuilder>>> =
Lazy::new(|| Arc::new(Mutex::new(MetadataBuilder::default())));
/// Run `fun` with mutable access to the global [`MetadataBuilder`] and return
/// its result. The lock is held for the duration of the call.
///
/// # Panics
///
/// Panics if locking the global builder fails.
#[cfg(not(feature = "std"))]
fn with_metadata<R>(mut fun: impl FnMut(&mut MetadataBuilder) -> R) -> R {
    let mut guard = METADATA.lock().unwrap();
    // Reborrow through the guard so the closure sees a plain `&mut MetadataBuilder`.
    let builder: &mut MetadataBuilder = &mut guard;
    fun(builder)
}
/// Handles the tensor state.
///
/// The state starts as [`TensorState::Empty`] and is switched to
/// [`TensorState::Some`] the first time a binding has to be stored.
pub enum TensorState<R: Runtime> {
/// No tensor is registered yet. Only the address type is remembered so it can
/// be carried over once the state is initialized.
Empty { addr_type: AddressType },
/// The registered tensors.
Some {
/// Bindings of the registered tensors and arrays, in registration order.
buffers: Vec<Binding>,
/// Bindings of the registered tensor maps, in registration order.
tensor_maps: Vec<TensorMapBinding>,
/// Address type passed along whenever metadata is registered or finished.
addr_type: AddressType,
/// Zero-sized marker tying the state to the runtime `R`.
runtime: PhantomData<R>,
},
}
/// Handles the scalar state of an element type
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
#[derive(Default, Clone)]
pub struct ScalarState {
/// Raw bytes of every registered scalar, keyed by storage type. Scalars of
/// the same type are appended to the same byte buffer in registration order.
data: BTreeMap<StorageType, ScalarValues>,
}
/// Raw bytes of all scalar arguments sharing one storage type.
///
/// The storage type itself is not stored here; it is the key under which the
/// bytes are kept in [`ScalarState`].
pub type ScalarValues = Vec<u8>;
impl<R: Runtime> TensorState<R> {
    /// Switch from `Empty` to `Some` (with empty binding lists), keeping the
    /// address type. Does nothing when the state is already `Some`.
    fn maybe_init(&mut self) {
        let addr_type = match self {
            Self::Empty { addr_type } => *addr_type,
            Self::Some { .. } => return,
        };

        *self = Self::Some {
            buffers: Vec::new(),
            tensor_maps: Vec::new(),
            addr_type,
            runtime: PhantomData,
        };
    }

    /// Mutable access to the buffer bindings, initializing the state if needed.
    fn buffers(&mut self) -> &mut Vec<Binding> {
        self.maybe_init();
        match self {
            Self::Some { buffers, .. } => buffers,
            Self::Empty { .. } => panic!("Should be init"),
        }
    }

    /// Mutable access to the tensor-map bindings, initializing the state if needed.
    fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
        self.maybe_init();
        match self {
            Self::Some { tensor_maps, .. } => tensor_maps,
            Self::Empty { .. } => panic!("Should be init"),
        }
    }

    /// Address type of this state, whichever variant it currently is.
    fn address_type(&self) -> AddressType {
        match self {
            Self::Empty { addr_type } | Self::Some { addr_type, .. } => *addr_type,
        }
    }

    /// Push a new input tensor to the state.
    ///
    /// Aliased tensors produce no binding and leave the state untouched.
    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
        let Some(binding) = self.process_tensor(tensor) else {
            return;
        };
        self.buffers().push(binding);
    }

    /// Record the tensor's metadata and return the binding of its handle.
    ///
    /// Returns `None` for [`TensorArg::Alias`]; nothing is recorded in that case.
    fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
        let TensorArg::Handle {
            handle: handle_ref,
            line_size,
            ..
        } = tensor
        else {
            return None;
        };

        let line = *line_size;
        // Size of one line: element size times the number of elements per line.
        let line_bytes = handle_ref.elem_size * line;
        // Buffer length and logical length are both expressed in lines.
        let buffer_len = handle_ref.handle.size() / line_bytes as u64;
        let len = handle_ref.shape.iter().product::<usize>() / line;
        let rank = handle_ref.strides.len() as u64;
        let addr_type = self.address_type();

        with_metadata(|meta| {
            meta.register_tensor(
                rank,
                buffer_len,
                len as u64,
                handle_ref.shape,
                handle_ref.strides,
                addr_type,
            )
        });

        Some(handle_ref.handle.clone().binding())
    }

    /// Push a new input array to the state.
    ///
    /// Aliased arrays produce no binding and leave the state untouched.
    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
        let Some(binding) = self.process_array(array) else {
            return;
        };
        self.buffers().push(binding);
    }

    /// Record the array's metadata and return the binding of its handle.
    ///
    /// Returns `None` for [`ArrayArg::Alias`]; nothing is recorded in that case.
    fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
        let ArrayArg::Handle {
            handle: handle_ref,
            line_size,
            ..
        } = array
        else {
            return None;
        };

        let line = *line_size;
        let line_bytes = handle_ref.elem_size * line;
        let buffer_len = handle_ref.handle.size() / line_bytes as u64;
        let len = handle_ref.length[0] as u64 / line as u64;
        let addr_type = self.address_type();

        with_metadata(|meta| meta.register_array(buffer_len, len, addr_type));

        Some(handle_ref.handle.clone().binding())
    }

    /// Push a new tensor map to the state.
    ///
    /// The underlying tensor's metadata is recorded like for a regular tensor,
    /// but its binding is stored with the tensor maps instead of the buffers.
    ///
    /// # Panics
    ///
    /// Panics if the tensor behind the map is an alias.
    pub fn push_tensor_map<K: TensorMapKind>(&mut self, map: &TensorMapArg<'_, R, K>) {
        let binding = self
            .process_tensor(&map.tensor)
            .expect("Can't use alias for TensorMap");
        let metadata = map.metadata.clone();
        self.tensor_maps().push(TensorMapBinding {
            binding,
            map: metadata,
        });
    }

    /// Move the collected bindings and the finished metadata into `bindings_global`.
    ///
    /// An `Empty` state leaves `bindings_global` untouched.
    fn register(self, bindings_global: &mut Bindings) {
        let Self::Some {
            buffers,
            tensor_maps,
            addr_type,
            ..
        } = self
        else {
            return;
        };

        let metadata = with_metadata(|meta| meta.finish(addr_type));

        bindings_global.buffers = buffers;
        bindings_global.tensor_maps = tensor_maps;
        bindings_global.metadata = metadata;
    }
}
impl ScalarState {
    /// Add a new scalar value to the state.
    ///
    /// The value's bytes (as given by `T::as_bytes`) are appended to the buffer
    /// kept for `T::cube_type()`.
    pub fn push<T: CubeScalar>(&mut self, val: T) {
        let single = [val];
        let bytes = T::as_bytes(&single);
        self.data
            .entry(T::cube_type())
            .or_default()
            .extend(bytes.iter().copied());
    }

    /// Add a new raw value to the state.
    ///
    /// `bytes` are appended as-is to the buffer kept for `dtype`.
    pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
        self.data.entry(dtype).or_default().extend_from_slice(bytes);
    }

    /// Insert one [`ScalarBinding`] per storage type into `bindings.scalars`.
    ///
    /// The bytes of each type are copied into a zero-initialized `u64` buffer,
    /// so the tail of the last word is zero padding when the byte count is not
    /// a multiple of eight.
    fn register(&self, bindings: &mut Bindings) {
        for (ty, values) in &self.data {
            // Number of scalars of this type.
            let len = values.len() / ty.size();
            // Size the word buffer from the byte count itself. Dividing the
            // element count by `size_of::<u64>() / ty.size()` would divide by
            // zero for types wider than eight bytes and could under-allocate
            // when the raw bytes are not a whole number of elements.
            let word_count = values.len().div_ceil(size_of::<u64>());
            let mut data = vec![0u64; word_count];
            bytemuck::cast_slice_mut::<u64, u8>(&mut data)[..values.len()]
                .copy_from_slice(values);
            bindings
                .scalars
                .insert(*ty, ScalarBinding::new(*ty, len, data));
        }
    }
}
impl<R: Runtime> KernelLauncher<R> {
pub fn new(settings: KernelSettings) -> Self {
Self {
tensors: TensorState::Empty {
addr_type: settings.address_type,
},
scalars: Default::default(),
settings,
runtime: PhantomData,
}
}
}