Skip to content

Commit 6417ad4

Browse files
committed
[IREE-EP] Integrate iree async module in the IREE-EP
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 7b2046f commit 6417ad4

File tree

4 files changed

+79
-13
lines changed

4 files changed

+79
-13
lines changed

onnxruntime/core/providers/iree/iree_ep_runtime.cc

+70-13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "core/providers/iree/iree_ep_runtime.h"
55

66
#include "core/session/onnxruntime_cxx_api.h"
7+
#include <iostream>
78

89
namespace onnxruntime::iree_ep_rt {
910

@@ -57,10 +58,18 @@ Session::~Session() {
5758
}
5859

5960
iree_status_t Session::Initialize() {
60-
return iree_runtime_session_create_with_device(
61+
iree_status_t res_status = iree_runtime_session_create_with_device(
6162
instance->instance, &session_options, instance->device,
6263
iree_runtime_instance_host_allocator(instance->instance),
6364
&session);
65+
iree_vm_module_t* custom_module = NULL;
66+
iree_allocator_t host_allocator = iree_allocator_system();
67+
IREE_CHECK_OK(iree_custom_module_async_create(
68+
iree_runtime_instance_vm_instance(instance->instance), instance->device,
69+
host_allocator, &custom_module));
70+
IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
71+
iree_vm_module_release(custom_module);
72+
return res_status;
6473
}
6574

6675
iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
@@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
147156
common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
148157
// TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
149158
// better but this gets points for simplicity and lets us bootstrap the tests.
159+
iree_vm_list_t* inputs = NULL;
160+
iree_allocator_t host_allocator = iree_allocator_system();
161+
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
162+
host_allocator, &inputs));
163+
iree_vm_list_t* outputs = NULL;
164+
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
165+
host_allocator, &outputs));
150166
Ort::KernelContext context(ort_context_c);
151167
SynchronousCall call(session);
152168
ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
@@ -161,8 +177,10 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
161177

162178
// Process inputs. We could be smarter about this in a lot of ways, including carrying
163179
// more state from compilation so we are doing less munging here.
164-
for (size_t i = 0; i < context.GetInputCount(); ++i) {
165-
auto input_tensor = context.GetInput(i);
180+
181+
std::cout<<"input count: "<<context.GetInputCount()<<"\n";
182+
// for (size_t i = 0; i < context.GetInputCount(); ++i) {
183+
auto input_tensor = context.GetInput(0);
166184
ORT_ENFORCE(input_tensor.IsTensor());
167185

168186
// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
@@ -207,13 +225,45 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
207225
// Buffer view + storage are returned and owned by the caller:
208226
&arg.bv)));
209227

228+
iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
229+
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));
230+
231+
iree_hal_semaphore_t* semaphore = NULL;
232+
IREE_CHECK_OK(iree_hal_semaphore_create(
233+
device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
234+
iree_hal_fence_t* fence_t1 = NULL;
235+
IREE_CHECK_OK(
236+
iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
237+
iree_hal_fence_t* fence_t2 = NULL;
238+
IREE_CHECK_OK(
239+
iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
240+
iree_hal_semaphore_release(semaphore);
241+
std::cout<<"\n semaphore released";
242+
iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
243+
std::cout<<"\n semaphore released1";
244+
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
245+
std::cout<<"\n semaphore released2";
246+
iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
247+
std::cout<<"\n semaphore released3";
248+
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
249+
std::cout<<"\n semaphore released4";
250+
IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
251+
std::cout<<"\n T=1 reached";
210252
// Add it to the call.
211-
iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
212-
ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
213-
}
253+
iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
254+
IREE_CHECK_OK(
255+
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
256+
// We could go do other things now while the async work progresses. Here we
257+
// just immediately wait.
258+
IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
259+
std::cout<<"\n T=2 reached";
260+
// iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
261+
// ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
262+
// }
263+
// Read back the tensor<?xi32> result:
214264

215265
// Invoke.
216-
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
266+
// ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [>flags=<]0)));
217267

218268
// Marshal the outputs.
219269
// TODO: Accessing the ORT output requires the shape and then we could get zero copy
@@ -222,16 +272,19 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
222272
// convention, which allows passing in slabs of result buffers. Further, that would
223273
// run the host-side computation (which would compute output metadata) inline.
224274
// For static cases, we could also side-load the shape from the compile time.
225-
std::vector<int64_t> shape;
226-
for (size_t i = 0; i < context.GetOutputCount(); ++i) {
275+
// std::vector<int64_t> shape;
276+
std::cout<<"output count: "<<context.GetOutputCount()<<"\n";
277+
// for (size_t i = 0; i < context.GetOutputCount(); ++i) {
227278
HalBufferView ret;
228-
ORT_RETURN_IF_ERROR(HandleIREEStatus(
229-
iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
279+
ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
280+
// ORT_RETURN_IF_ERROR(HandleIREEStatus(
281+
// iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
230282
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
231283
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
284+
shape.clear();
232285
shape.resize(ret_rank);
233286
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
234-
auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
287+
auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
235288
ORT_ENFORCE(output_tensor.IsTensor());
236289

237290
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
@@ -250,8 +303,12 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
250303
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
251304
output_tensor.GetTensorMutableRawData(),
252305
iree_hal_buffer_view_byte_length(ret.bv))));
253-
}
306+
// }
254307

308+
iree_vm_list_release(inputs);
309+
iree_vm_list_release(outputs);
310+
iree_hal_fence_release(fence_t1);
311+
iree_hal_fence_release(fence_t2);
255312
return common::Status::OK();
256313
}
257314

onnxruntime/core/providers/iree/iree_ep_runtime.h

+3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
#include "core/common/common.h"
77
#include "core/session/onnxruntime_c_api.h"
8+
#include "iree/modules/hal/types.h"
89
#include "iree/runtime/api.h"
910

11+
#include "module.h"
12+
1013
#include <filesystem>
1114

1215
namespace fs = std::filesystem;

onnxruntime/core/providers/iree/iree_execution_provider.cc

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
118118
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
119119
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
120120
}
121+
std::string extra_flag_2 = "--iree-execution-model=async-external";
122+
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag_2.c_str()));
121123

122124
ORT_RETURN_IF_ERROR(compiler.Initialize());
123125
std::string module_name = "ort";

onnxruntime/test/perftest/ort_test_session.cc

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
#include "core/providers/dml/dml_session_options_config_keys.h"
2323
#endif
2424

25+
#ifdef USE_IREE
26+
#include "core/providers/iree/iree_provider_factory.h"
27+
#endif
28+
2529
#ifdef _WIN32
2630
#define strdup _strdup
2731
#endif

0 commit comments

Comments
 (0)