Skip to content

Commit 16f08f8

Browse files
committed
[IREE-EP] Integrate iree async module in the IREE-EP
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 7b2046f commit 16f08f8

File tree

4 files changed

+145
-79
lines changed

4 files changed

+145
-79
lines changed

onnxruntime/core/providers/iree/iree_ep_runtime.cc

+136-79
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "core/providers/iree/iree_ep_runtime.h"
55

66
#include "core/session/onnxruntime_cxx_api.h"
7+
#include <iostream>
78

89
namespace onnxruntime::iree_ep_rt {
910

@@ -57,10 +58,18 @@ Session::~Session() {
5758
}
5859

5960
iree_status_t Session::Initialize() {
60-
return iree_runtime_session_create_with_device(
61+
iree_status_t res_status = iree_runtime_session_create_with_device(
6162
instance->instance, &session_options, instance->device,
6263
iree_runtime_instance_host_allocator(instance->instance),
6364
&session);
65+
iree_vm_module_t* custom_module = NULL;
66+
iree_allocator_t host_allocator = iree_allocator_system();
67+
IREE_CHECK_OK(iree_custom_module_async_create(
68+
iree_runtime_instance_vm_instance(instance->instance), instance->device,
69+
host_allocator, &custom_module));
70+
IREE_CHECK_OK(iree_runtime_session_append_module(session, custom_module));
71+
iree_vm_module_release(custom_module);
72+
return res_status;
6473
}
6574

6675
iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
@@ -147,6 +156,13 @@ iree_hal_element_type_t ConvertOrtElementType(ONNXTensorElementDataType et) {
147156
common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, OrtKernelContext* ort_context_c) {
148157
// TODO: This is far from the most efficient way to make a call. Synchronous and copying. We can do
149158
// better but this gets points for simplicity and lets us bootstrap the tests.
159+
iree_vm_list_t* inputs = NULL;
160+
iree_allocator_t host_allocator = iree_allocator_system();
161+
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
162+
host_allocator, &inputs));
163+
iree_vm_list_t* outputs = NULL;
164+
IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1,
165+
host_allocator, &outputs));
150166
Ort::KernelContext context(ort_context_c);
151167
SynchronousCall call(session);
152168
ORT_RETURN_IF_ERROR(HandleIREEStatus(call.InitializeByName(entrypoint_name)));
@@ -161,59 +177,93 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
161177

162178
// Process inputs. We could be smarter about this in a lot of ways, including carrying
163179
// more state from compilation so we are doing less munging here.
164-
for (size_t i = 0; i < context.GetInputCount(); ++i) {
165-
auto input_tensor = context.GetInput(i);
166-
ORT_ENFORCE(input_tensor.IsTensor());
167-
168-
// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
169-
// is useful for anything.
170-
auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
171-
ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);
172-
173-
const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
174-
auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
175-
ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
176-
static_cast<int>(tensor_type.GetElementType()));
177-
ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
178-
size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);
179-
180-
// Yes, that's right, returned as an std::vector by value :(
181-
// And of a different type than we expect.
182-
std::vector<int64_t> shape = tensor_type.GetShape();
183-
dims.resize(shape.size());
184-
std::copy(shape.begin(), shape.end(), dims.begin());
185-
186-
// No convenient way to get the byte size of the raw data.
187-
size_t element_count = tensor_type.GetElementCount();
188-
const void* raw_data = input_tensor.GetTensorRawData();
189-
190-
HalBufferView arg;
191-
iree_hal_buffer_params_t buffer_params;
192-
memset(&buffer_params, 0, sizeof(buffer_params));
193-
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
194-
buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
195-
buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
196-
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
197-
device, device_allocator,
198-
// Shape rank and dimensions:
199-
dims.size(), dims.data(),
200-
// Element type:
201-
element_type,
202-
// Encoding type:
203-
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
204-
buffer_params,
205-
// The actual heap buffer to wrap or clone and its allocator:
206-
iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
207-
// Buffer view + storage are returned and owned by the caller:
208-
&arg.bv)));
209-
210-
// Add it to the call.
211-
iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
212-
ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
213-
}
180+
181+
std::cout << "input count: " << context.GetInputCount() << "\n";
182+
// for (size_t i = 0; i < context.GetInputCount(); ++i) {
183+
auto input_tensor = context.GetInput(0);
184+
ORT_ENFORCE(input_tensor.IsTensor());
185+
186+
// The device type is rather... sparse... CPU, GPU and FPGA. Not sure how that
187+
// is useful for anything.
188+
auto ort_device_type = input_tensor.GetTensorMemoryInfo().GetDeviceType();
189+
ORT_ENFORCE(ort_device_type == OrtMemoryInfoDeviceType_CPU);
190+
191+
const auto& tensor_type = input_tensor.GetTensorTypeAndShapeInfo();
192+
auto element_type = ConvertOrtElementType(tensor_type.GetElementType());
193+
ORT_ENFORCE(element_type != IREE_HAL_ELEMENT_TYPE_NONE, "Unsupported element type ",
194+
static_cast<int>(tensor_type.GetElementType()));
195+
ORT_ENFORCE(iree_hal_element_is_byte_aligned(element_type));
196+
size_t element_size_bytes = iree_hal_element_dense_byte_count(element_type);
197+
198+
// Yes, that's right, returned as an std::vector by value :(
199+
// And of a different type than we expect.
200+
std::vector<int64_t> shape = tensor_type.GetShape();
201+
dims.resize(shape.size());
202+
std::copy(shape.begin(), shape.end(), dims.begin());
203+
204+
// No convenient way to get the byte size of the raw data.
205+
size_t element_count = tensor_type.GetElementCount();
206+
const void* raw_data = input_tensor.GetTensorRawData();
207+
208+
HalBufferView arg;
209+
iree_hal_buffer_params_t buffer_params;
210+
memset(&buffer_params, 0, sizeof(buffer_params));
211+
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
212+
buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
213+
buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
214+
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_view_allocate_buffer_copy(
215+
device, device_allocator,
216+
// Shape rank and dimensions:
217+
dims.size(), dims.data(),
218+
// Element type:
219+
element_type,
220+
// Encoding type:
221+
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
222+
buffer_params,
223+
// The actual heap buffer to wrap or clone and its allocator:
224+
iree_make_const_byte_span(raw_data, element_count * element_size_bytes),
225+
// Buffer view + storage are returned and owned by the caller:
226+
&arg.bv)));
227+
228+
iree_vm_ref_t input_view_ref = iree_hal_buffer_view_move_ref(arg.bv);
229+
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &input_view_ref));
230+
231+
iree_hal_semaphore_t* semaphore = NULL;
232+
IREE_CHECK_OK(iree_hal_semaphore_create(
233+
device, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore));
234+
iree_hal_fence_t* fence_t1 = NULL;
235+
IREE_CHECK_OK(
236+
iree_hal_fence_create_at(semaphore, 1ull, host_allocator, &fence_t1));
237+
iree_hal_fence_t* fence_t2 = NULL;
238+
IREE_CHECK_OK(
239+
iree_hal_fence_create_at(semaphore, 2ull, host_allocator, &fence_t2));
240+
iree_hal_semaphore_release(semaphore);
241+
std::cout << "\n semaphore released";
242+
iree_vm_ref_t fence_t1_ref = iree_hal_fence_retain_ref(fence_t1);
243+
std::cout << "\n semaphore released1";
244+
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t1_ref));
245+
std::cout << "\n semaphore released2";
246+
iree_vm_ref_t fence_t2_ref = iree_hal_fence_retain_ref(fence_t2);
247+
std::cout << "\n semaphore released3";
248+
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs, &fence_t2_ref));
249+
std::cout << "\n semaphore released4";
250+
IREE_CHECK_OK(iree_hal_fence_signal(fence_t1));
251+
std::cout << "\n T=1 reached";
252+
// Add it to the call.
253+
iree_string_view_t entry_point = iree_make_cstring_view(entrypoint_name);
254+
IREE_CHECK_OK(
255+
iree_runtime_session_call_by_name(session, entry_point, inputs, outputs));
256+
// We could go do other things now while the async work progresses. Here we
257+
// just immediately wait.
258+
IREE_CHECK_OK(iree_hal_fence_wait(fence_t2, iree_infinite_timeout()));
259+
std::cout << "\n T=2 reached";
260+
// iree_status_t status = iree_runtime_call_inputs_push_back_buffer_view(&call.call, arg.bv);
261+
// ORT_RETURN_IF_ERROR(HandleIREEStatus(status));
262+
// }
263+
// Read back the tensor<?xi32> result:
214264

215265
// Invoke.
216-
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, /*flags=*/0)));
266+
// ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_runtime_call_invoke(&call.call, [>flags=<]0)));
217267

218268
// Marshal the outputs.
219269
// TODO: Accessing the ORT output requires the shape and then we could get zero copy
@@ -222,37 +272,44 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
222272
// convention, which allows passing in slabs of result buffers. Further, that would
223273
// run the host-side computation (which would compute output metadata) inline.
224274
// For static cases, we could also side-load the shape from the compile time.
225-
std::vector<int64_t> shape;
226-
for (size_t i = 0; i < context.GetOutputCount(); ++i) {
227-
HalBufferView ret;
228-
ORT_RETURN_IF_ERROR(HandleIREEStatus(
229-
iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
230-
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
231-
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
232-
shape.resize(ret_rank);
233-
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
234-
auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
235-
ORT_ENFORCE(output_tensor.IsTensor());
236-
237-
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
238-
// TODO: Synchronous mapping read, like everything in this function, is not a
239-
// great idea. It isn't supported on all device types and will need a scrub.
240-
iree_string_view_t device_val = iree_hal_device_id(device);
241-
auto device_str = std::string(device_val.data, device_val.size);
242-
if (device_str == "hip") {
243-
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
244-
iree_runtime_session_device(session),
245-
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
246-
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
247-
iree_infinite_timeout())));
248-
return common::Status::OK();
249-
}
275+
// std::vector<int64_t> shape;
276+
std::cout << "output count: " << context.GetOutputCount() << "\n";
277+
// for (size_t i = 0; i < context.GetOutputCount(); ++i) {
278+
HalBufferView ret;
279+
ret.bv = iree_vm_list_get_buffer_view_assign(outputs, 0);
280+
// ORT_RETURN_IF_ERROR(HandleIREEStatus(
281+
// iree_runtime_call_outputs_pop_front_buffer_view(&call.call, &ret.bv)));
282+
size_t ret_rank = iree_hal_buffer_view_shape_rank(ret.bv);
283+
const iree_hal_dim_t* ret_dims = iree_hal_buffer_view_shape_dims(ret.bv);
284+
shape.clear();
285+
shape.resize(ret_rank);
286+
std::copy(ret_dims, ret_dims + ret_rank, shape.begin());
287+
auto output_tensor = context.GetOutput(0, shape.data(), shape.size());
288+
ORT_ENFORCE(output_tensor.IsTensor());
289+
290+
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
291+
// TODO: Synchronous mapping read, like everything in this function, is not a
292+
// great idea. It isn't supported on all device types and will need a scrub.
293+
iree_string_view_t device_val = iree_hal_device_id(device);
294+
auto device_str = std::string(device_val.data, device_val.size);
295+
if (device_str == "hip") {
296+
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
297+
iree_runtime_session_device(session),
298+
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
299+
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
300+
iree_infinite_timeout())));
301+
return common::Status::OK();
302+
}
250303
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
251304
output_tensor.GetTensorMutableRawData(),
252305
iree_hal_buffer_view_byte_length(ret.bv))));
253-
}
306+
// }
254307

255-
return common::Status::OK();
308+
iree_vm_list_release(inputs);
309+
iree_vm_list_release(outputs);
310+
iree_hal_fence_release(fence_t1);
311+
iree_hal_fence_release(fence_t2);
312+
return common::Status::OK();
256313
}
257314

258315
} // namespace onnxruntime::iree_ep_rt

onnxruntime/core/providers/iree/iree_ep_runtime.h

+3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
#include "core/common/common.h"
77
#include "core/session/onnxruntime_c_api.h"
8+
#include "iree/modules/hal/types.h"
89
#include "iree/runtime/api.h"
910

11+
#include "module.h"
12+
1013
#include <filesystem>
1114

1215
namespace fs = std::filesystem;

onnxruntime/core/providers/iree/iree_execution_provider.cc

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
118118
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
119119
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
120120
}
121+
std::string extra_flag_2 = "--iree-execution-model=async-external";
122+
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag_2.c_str()));
121123

122124
ORT_RETURN_IF_ERROR(compiler.Initialize());
123125
std::string module_name = "ort";

onnxruntime/test/perftest/ort_test_session.cc

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
#include "core/providers/dml/dml_session_options_config_keys.h"
2323
#endif
2424

25+
#ifdef USE_IREE
26+
#include "core/providers/iree/iree_provider_factory.h"
27+
#endif
28+
2529
#ifdef _WIN32
2630
#define strdup _strdup
2731
#endif

0 commit comments

Comments
 (0)