4
4
#include " core/providers/iree/iree_ep_runtime.h"
5
5
6
6
#include " core/session/onnxruntime_cxx_api.h"
7
+ #include < iostream>
7
8
8
9
namespace onnxruntime ::iree_ep_rt {
9
10
@@ -57,10 +58,18 @@ Session::~Session() {
57
58
}
58
59
59
60
iree_status_t Session::Initialize() {
  // Create the runtime session bound to this instance's device. Return
  // immediately on failure so we never append modules to (or later use)
  // a half-initialized session. Previously the creation status was only
  // captured and execution continued regardless of its value.
  IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
      instance->instance, &session_options, instance->device,
      iree_runtime_instance_host_allocator(instance->instance), &session));

  // Register the custom async module so async entry points resolve at call
  // time. Propagate failures to the caller instead of IREE_CHECK_OK, which
  // would abort the whole process even though this function's contract is
  // to report errors via its iree_status_t return value.
  iree_vm_module_t* custom_module = NULL;
  iree_allocator_t host_allocator = iree_allocator_system();
  IREE_RETURN_IF_ERROR(iree_custom_module_async_create(
      iree_runtime_instance_vm_instance(instance->instance), instance->device,
      host_allocator, &custom_module));

  // The session retains the module on success; drop our reference on both
  // the success and failure paths so the module is never leaked.
  iree_status_t append_status =
      iree_runtime_session_append_module(session, custom_module);
  iree_vm_module_release(custom_module);
  return append_status;
}
65
74
66
75
iree_status_t Session::AppendBytecodeModule (fs::path vmfb_path, std::function<void ()> dispose_callback) {
@@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
147
156
common::Status Session::Call (const char * entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
148
157
// TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
149
158
// better but this gets points for simplicity and lets us bootstrap the tests.
159
+ iree_vm_list_t * inputs = NULL ;
160
+ iree_allocator_t host_allocator = iree_allocator_system ();
161
+ IREE_CHECK_OK (iree_vm_list_create (iree_vm_make_undefined_type_def (), 1 ,
162
+ host_allocator, &inputs));
163
+ iree_vm_list_t * outputs = NULL ;
164
+ IREE_CHECK_OK (iree_vm_list_create (iree_vm_make_undefined_type_def (), 1 ,
165
+ host_allocator, &outputs));
150
166
Ort::KernelContext context (ort_context_c);
151
167
SynchronousCall call (session);
152
168
ORT_RETURN_IF_ERROR (HandleIREEStatus (call.InitializeByName (entrypoint_name)));
@@ -161,8 +177,10 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
161
177
162
178
// Process inputs. We could be smarter about this in a lot of ways, including carrying
163
179
// more state from compilation so we are doing less munging here.
164
- for (size_t i = 0 ; i < context.GetInputCount (); ++i) {
165
- auto input_tensor = context.GetInput (i);
180
+
181
+ std::cout<<" input count: " <<context.GetInputCount ()<<" \n " ;
182
+ // for (size_t i = 0; i < context.GetInputCount(); ++i) {
183
+ auto input_tensor = context.GetInput (0 );
166
184
ORT_ENFORCE (input_tensor.IsTensor ());
167
185
168
186
// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
@@ -207,13 +225,45 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
207
225
// Buffer view + storage are returned and owned by the caller:
208
226
&arg.bv )));
209
227
228
+ iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref (arg.bv );
229
+ IREE_CHECK_OK (iree_vm_list_push_ref_move (inputs, &input_view_ref));
230
+
231
+ iree_hal_semaphore_t * semaphore = NULL ;
232
+ IREE_CHECK_OK (iree_hal_semaphore_create (
233
+ device, 0ull , IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
234
+ iree_hal_fence_t * fence_t1 = NULL ;
235
+ IREE_CHECK_OK (
236
+ iree_hal_fence_create_at (semaphore, 1ull , host_allocator, &fence_t1));
237
+ iree_hal_fence_t * fence_t2 = NULL ;
238
+ IREE_CHECK_OK (
239
+ iree_hal_fence_create_at (semaphore, 2ull , host_allocator, &fence_t2));
240
+ iree_hal_semaphore_release (semaphore);
241
+ std::cout<<" \n semaphore released" ;
242
+ iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref (fence_t1);
243
+ std::cout<<" \n semaphore released1" ;
244
+ IREE_CHECK_OK (iree_vm_list_push_ref_move (inputs, &fence_t1_ref));
245
+ std::cout<<" \n semaphore released2" ;
246
+ iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref (fence_t2);
247
+ std::cout<<" \n semaphore released3" ;
248
+ IREE_CHECK_OK (iree_vm_list_push_ref_move (inputs, &fence_t2_ref));
249
+ std::cout<<" \n semaphore released4" ;
250
+ IREE_CHECK_OK (iree_hal_fence_signal (fence_t1));
251
+ std::cout<<" \n T=1 reached" ;
210
252
// Add it to the call.
211
- iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view (&call.call , arg.bv );
212
- ORT_RETURN_IF_ERROR (HandleIREEStatus (status));
213
- }
253
+ iree_string_view_t entry_point = iree_make_cstring_view (entrypoint_name);
254
+ IREE_CHECK_OK (
255
+ iree_runtime_session_call_by_name (session, entry_point, inputs, outputs));
256
+ // We could go do other things now while the async work progresses. Here we
257
+ // just immediately wait.
258
+ IREE_CHECK_OK (iree_hal_fence_wait (fence_t2, iree_infinite_timeout ()));
259
+ std::cout<<" \n T=2 reached" ;
260
+ // iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
261
+ // ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
262
+ // }
263
+ // Read back the tensor<?xi32> result:
214
264
215
265
// Invoke.
216
- ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_runtime_call_invoke (&call.call , /* flags=*/ 0 )));
266
+ // ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [> flags=<] 0)));
217
267
218
268
// Marshal the outputs.
219
269
// TODO: Accessing the ORT output requires the shape and then we could get zero copy
@@ -222,16 +272,19 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
222
272
// convention, which allows passing in slabs of result buffers. Further, that would
223
273
// run the host-side computation (which would compute output metadata) inline.
224
274
// For static cases, we could also side-load the shape from the compile time.
225
- std::vector<int64_t > shape;
226
- for (size_t i = 0 ; i < context.GetOutputCount (); ++i) {
275
+ // std::vector<int64_t> shape;
276
+ std::cout<<" output count: " <<context.GetOutputCount ()<<" \n " ;
277
+ // for (size_t i = 0; i < context.GetOutputCount(); ++i) {
227
278
HalBufferView ret;
228
- ORT_RETURN_IF_ERROR (HandleIREEStatus (
229
- iree_runtime_call_outputs_pop_front_buffer_view (&call.call , &ret.bv )));
279
+ ret.bv = iree_vm_list_get_buffer_view_assign (outputs, 0 );
280
+ // ORT_RETURN_IF_ERROR(HandleIREEStatus(
281
+ // iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
230
282
size_t ret_rank = iree_hal_buffer_view_shape_rank (ret.bv );
231
283
const iree_hal_dim_t * ret_dims = iree_hal_buffer_view_shape_dims (ret.bv );
284
+ shape.clear ();
232
285
shape.resize (ret_rank);
233
286
std::copy (ret_dims, ret_dims + ret_rank, shape.begin ());
234
- auto output_tensor = context.GetOutput (i , shape.data (), shape.size ());
287
+ auto output_tensor = context.GetOutput (0 , shape.data (), shape.size ());
235
288
ORT_ENFORCE (output_tensor.IsTensor ());
236
289
237
290
iree_hal_buffer_t * ret_buffer = iree_hal_buffer_view_buffer (ret.bv );
@@ -250,8 +303,12 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
250
303
ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_hal_buffer_map_read (ret_buffer, /* source_offset=*/ 0 ,
251
304
output_tensor.GetTensorMutableRawData (),
252
305
iree_hal_buffer_view_byte_length (ret.bv ))));
253
- }
306
+ // }
254
307
308
+ iree_vm_list_release (inputs);
309
+ iree_vm_list_release (outputs);
310
+ iree_hal_fence_release (fence_t1);
311
+ iree_hal_fence_release (fence_t2);
255
312
return common::Status::OK ();
256
313
}
257
314
0 commit comments