@@ -4,6 +4,7 @@
 #include "core/providers/iree/iree_ep_runtime.h"
 
 #include "core/session/onnxruntime_cxx_api.h"
+#include <iostream>
 
 namespace onnxruntime::iree_ep_rt {
 
@@ -57,10 +58,18 @@ Session::~Session() {
 }
 
 iree_status_t Session::Initialize() {
-  return iree_runtime_session_create_with_device(
+  iree_status_t res_status = iree_runtime_session_create_with_device(
       instance->instance, &session_options, instance->device,
       iree_runtime_instance_host_allocator(instance->instance),
       &session);
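+  // Register the async custom module so its exports resolve at call time.
+  // NOTE: IREE_CHECK_OK aborts the process on failure, unlike the
+  // status-returning pattern used elsewhere in this file.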
+  iree_vm_module_t* custom_module = NULL;
+  iree_allocator_t host_allocator = iree_allocator_system();
+  IREE_CHECK_OK(iree_custom_module_async_create(
+      iree_runtime_instance_vm_instance(instance->instance), instance->device,
+      host_allocator, &custom_module));
+  IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
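+  // The session retains its own reference to the module, so ours can be
+  // released immediately.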
+  iree_vm_module_release(custom_module);
+  return res_status;
 }
 
 iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
@@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
 common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
   // TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
   // better but this gets points for simplicity and lets us bootstrap the tests.
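+  // Build VM lists for the lower-level call-by-name API used below. The
+  // undefined type def lets one list hold mixed refs (buffer views and
+  // fences); the initial capacity is only a hint and the lists grow.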
+  iree_vm_list_t* inputs = NULL;
+  iree_allocator_t host_allocator = iree_allocator_system();
+  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
+                                    host_allocator, &inputs));
+  iree_vm_list_t* outputs = NULL;
+  IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
+                                    host_allocator, &outputs));
   Ort::KernelContext context(ort_context_c);
   SynchronousCall call(session);
   ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
@@ -161,59 +177,93 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
 
   // Process inputs. We could be smarter about this in a lot of ways, including carrying
   // more state from compilation so we are doing less munging here.
164
- for (size_t i = 0 ; i < context.GetInputCount (); ++i) {
165
- auto input_tensor = context.GetInput (i);
166
- ORT_ENFORCE (input_tensor.IsTensor ());
167
-
168
- // The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
169
- // is useful for anything.
170
- auto ort_device_type = input_tensor.GetTensorMemoryInfo ().GetDeviceType ();
171
- ORT_ENFORCE (ort_device_type == OrtMemoryInfoDeviceType_CPU);
172
-
173
- const auto & tensor_type = input_tensor.GetTensorTypeAndShapeInfo ();
174
- auto element_type = ConvertOrtElementType (tensor_type.GetElementType ());
175
- ORT_ENFORCE (element_type != IREE_HAL_ELEMENT_TYPE_NONE, " Unsupported element type " ,
176
- static_cast <int >(tensor_type.GetElementType ()));
177
- ORT_ENFORCE (iree_hal_element_is_byte_aligned (element_type));
178
- size_t element_size_bytes = iree_hal_element_dense_byte_count (element_type);
179
-
180
- // Yes, that's right, returned as an std::vector by value :(
181
- // And of a different type than we expect.
182
- std::vector<int64_t > shape = tensor_type.GetShape ();
183
- dims.resize (shape.size ());
184
- std::copy (shape.begin (), shape.end (), dims.begin ());
185
-
186
- // No convenient way to get the byte size of the raw data.
187
- size_t element_count = tensor_type.GetElementCount ();
188
- const void * raw_data = input_tensor.GetTensorRawData ();
189
-
190
- HalBufferView arg;
191
- iree_hal_buffer_params_t buffer_params;
192
- memset (&buffer_params, 0 , sizeof (buffer_params));
193
- buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
194
- buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
195
- buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
196
- ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_hal_buffer_view_allocate_buffer_copy (
197
- device, device_allocator,
198
- // Shape rank and dimensions:
199
- dims.size (), dims.data (),
200
- // Element type:
201
- element_type,
202
- // Encoding type:
203
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
204
- buffer_params,
205
- // The actual heap buffer to wrap or clone and its allocator:
206
- iree_make_const_byte_span (raw_data, element_count * element_size_bytes),
207
- // Buffer view + storage are returned and owned by the caller:
208
- &arg.bv )));
209
-
210
- // Add it to the call.
211
- iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view (&call.call , arg.bv );
212
- ORT_RETURN_IF_ERROR (HandleIREEStatus (status));
213
- }
+
+  std::cout << "input count: " << context.GetInputCount() << "\n";
+  // for (size_t i = 0; i < context.GetInputCount(); ++i) {
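+  // NOTE: the loop is disabled while the async path is bootstrapped; only
+  // the first input is marshaled for now.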
+  auto input_tensor = context.GetInput(0);
+  ORT_ENFORCE(input_tensor.IsTensor());
+
+  // The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
+  // is useful for anything.
+  auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
+  ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);
+
+  const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
+  auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
+  ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
+              static_cast<int>(tensor_type.GetElementType()));
+  ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
+  size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);
+
+  // Yes, that's right, returned as an std::vector by value :(
+  // And of a different type than we expect.
+  std::vector<int64_t> shape = tensor_type.GetShape();
+  dims.resize(shape.size());
+  std::copy(shape.begin(), shape.end(), dims.begin());
+
+  // No convenient way to get the byte size of the raw data.
+  size_t element_count = tensor_type.GetElementCount();
+  const void* raw_data = input_tensor.GetTensorRawData();
+
+  HalBufferView arg;
+  iree_hal_buffer_params_t buffer_params;
+  memset(&buffer_params, 0, sizeof(buffer_params));
+  buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+  buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+  buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
+  ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
+      device, device_allocator,
+      // Shape rank and dimensions:
+      dims.size(), dims.data(),
+      // Element type:
+      element_type,
+      // Encoding type:
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
+      buffer_params,
+      // The actual heap buffer to wrap or clone and its allocator:
+      iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
+      // Buffer view + storage are returned and owned by the caller:
+      &arg.bv)));
+
+  iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
+  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));
+  arg.bv = NULL;  // Ownership moved into the list; avoid a double release.
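+  // Assuming the module was compiled for IREE's async ABI, the entry point
+  // takes a wait fence and a signal fence after its tensor arguments. Both
+  // fences below share one timeline semaphore: T=1 gates input readiness
+  // and T=2 marks completion.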
+
+  iree_hal_semaphore_t* semaphore = NULL;
+  IREE_CHECK_OK(iree_hal_semaphore_create(
+      device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
+  iree_hal_fence_t* fence_t1 = NULL;
+  IREE_CHECK_OK(
+      iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
+  iree_hal_fence_t* fence_t2 = NULL;
+  IREE_CHECK_OK(
+      iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
+  iree_hal_semaphore_release(semaphore);
+  std::cout << "\nsemaphore released";
+  iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
+  std::cout << "\nretained wait fence (T=1)";
+  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
+  std::cout << "\npushed wait fence";
+  iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
+  std::cout << "\nretained signal fence (T=2)";
+  IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
+  std::cout << "\npushed signal fence";
+  IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
+  std::cout << "\nT=1 reached";
+  // Invoke the entry point by name with the packed input list.
+  iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
+  IREE_CHECK_OK(
+      iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
+  // We could go do other things now while the async work progresses. Here we
+  // just immediately wait.
+  IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
+  std::cout << "\nT=2 reached";
+  // iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
+  // ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
+  // }
+  // Read back the tensor<?xi32> result:
 
   // Invoke.
-  ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
+  // ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
 
   // Marshal the outputs.
   // TODO: Accessing the ORT output requires the shape and then we could get zero copy
@@ -222,37 +272,44 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
   // convention, which allows passing in slabs of result buffers. Further, that would
   // run the host-side computation (which would compute output metadata) inline.
   // For static cases, we could also side-load the shape from the compile time.
-  std::vector<int64_t> shape;
-  for (size_t i = 0; i < context.GetOutputCount(); ++i) {
-    HalBufferView ret;
-    ORT_RETURN_IF_ERROR(HandleIREEStatus(
-        iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
-    size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
-    const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
-    shape.resize(ret_rank);
-    std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
-    auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
-    ORT_ENFORCE(output_tensor.IsTensor());
-
-    iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
-    // TODO: Synchronous mapping read, like everything in this function, is not a
-    // great idea. It isn't supported on all device types and will need a scrub.
-    iree_string_view_t device_val = iree_hal_device_id(device);
-    auto device_str = std::string(device_val.data, device_val.size);
-    if (device_str == "hip") {
-      ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
-          iree_runtime_session_device(session),
-          ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
-          iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
-          iree_infinite_timeout())));
-      return common::Status::OK();
-    }
+  // std::vector<int64_t> shape;
+  std::cout << "output count: " << context.GetOutputCount() << "\n";
+  // for (size_t i = 0; i < context.GetOutputCount(); ++i) {
+  HalBufferView ret;
+  ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
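+  // NOTE: the _assign accessor borrows the view from the list without
+  // retaining it, so `outputs` must stay alive until the read-back is done.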
+  // ORT_RETURN_IF_ERROR(HandleIREEStatus(
+  //     iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
+  size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
+  const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
+  shape.clear();
+  shape.resize(ret_rank);
+  std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
+  auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
+  ORT_ENFORCE(output_tensor.IsTensor());
+
+  iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
+  // TODO: Synchronous mapping read, like everything in this function, is not a
+  // great idea. It isn't supported on all device types and will need a scrub.
+  iree_string_view_t device_val = iree_hal_device_id(device);
+  auto device_str = std::string(device_val.data, device_val.size);
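+  // Device-local HIP allocations are generally not host-mappable, so use an
+  // explicit device-to-host transfer there instead of a mapped read.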
+  if (device_str == "hip") {
+    ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
+        iree_runtime_session_device(session),
+        ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
+        iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
+        iree_infinite_timeout())));
+    // Release lists and fences before the early return to avoid leaks.
+    iree_vm_list_release(inputs);
+    iree_vm_list_release(outputs);
+    iree_hal_fence_release(fence_t1);
+    iree_hal_fence_release(fence_t2);
+    return common::Status::OK();
+  }
   ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
                                                                 output_tensor.GetTensorMutableRawData(),
                                                                 iree_hal_buffer_view_byte_length(ret.bv))));
-  }
+  // }
 
-  return common::Status::OK();
+  iree_vm_list_release(inputs);
+  iree_vm_list_release(outputs);
+  iree_hal_fence_release(fence_t1);
+  iree_hal_fence_release(fence_t2);
+  return common::Status::OK();
 }
 
 }  // namespace onnxruntime::iree_ep_rt