Skip to content

Commit 479a487

Browse files
committed
[IREE][EP] Add support for rocm backend
This commit adds support for the ROCm backend in iree-ep. Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 1d4576f commit 479a487

File tree

5 files changed

+59
-33
lines changed

5 files changed

+59
-33
lines changed

onnxruntime/core/providers/iree/compiler/jit_compiler.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,16 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
208208
return common::Status::OK();
209209
}
210210

211-
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output) {
211+
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to) {
212212
// Main compilation.
213213
if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) {
214214
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "IREE compilation error.", ConsumeDiagnostics());
215215
}
216216

217+
// Attach the compiled output to a file.
218+
save_to.append("compiled_model.vmfb");
219+
ireeCompilerOutputOpenFile(save_to.c_str(), &output);
220+
217221
// Output.
218222
if (auto* err = ireeCompilerInvocationOutputVMBytecode(inv, output)) {
219223
return ErrorToStatus(err, "Failure emitting VM bytecode: ");

onnxruntime/core/providers/iree/compiler/jit_compiler.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ struct CompilerOutput {
4545
// Releases ownership of the output, returning a callback that can be used to
4646
// destroy it at a later date.
4747
std::function<void()> Release() {
48-
iree_compiler_output_t* local_output = output;
48+
iree_compiler_output_t* local_output = this->output;
4949
this->output = nullptr;
5050
return [local_output]() {
5151
if (local_output) {
@@ -84,7 +84,7 @@ struct CompilerInvocation {
8484
common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name);
8585

8686
// Compile and output a VMFB.
87-
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output);
87+
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to);
8888

8989
// If there are any diagnostics, clears them and returns a loggable string.
9090
std::string ConsumeDiagnostics();

onnxruntime/core/providers/iree/iree_ep_runtime.cc

+21-18
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) {
1313
return common::Status::OK();
1414
}
1515

16-
std::string buffer;
17-
iree_host_size_t actual_len;
18-
buffer.resize(1024);
19-
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
20-
&actual_len)) {
21-
buffer.resize(actual_len);
22-
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
23-
&actual_len)) {
24-
actual_len = 0;
25-
}
26-
}
27-
buffer.resize(actual_len);
16+
std::string buffer = iree::Status::ToString(iree_status);
2817

2918
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "IREE Runtime Error: ", std::move(buffer));
3019
}
@@ -43,13 +32,13 @@ Instance::~Instance() {
4332
}
4433
}
4534

46-
iree_status_t Instance::Initialize() {
35+
iree_status_t Instance::Initialize(std::string device_str) {
4736
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
4837
&options, iree_allocator_system(), &instance));
4938

5039
// TODO: Need real device selection.
5140
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
52-
instance, iree_make_cstring_view("local-task"), &device));
41+
instance, iree_make_cstring_view(device_str.c_str()), &device));
5342

5443
return iree_ok_status();
5544
}
@@ -74,11 +63,15 @@ iree_status_t Session::Initialize() {
7463
&session);
7564
}
7665

77-
iree_status_t Session::AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback) {
66+
iree_status_t Session::AppendBytecodeModule(std::string file_loc, std::function<void()> dispose_callback) {
7867
dispose_callbacks.push_back(std::move(dispose_callback));
79-
return iree_runtime_session_append_bytecode_module_from_memory(
80-
session, iree_make_const_byte_span(contents, size),
81-
iree_allocator_null());
68+
// TODO(Shukla-Gaurav): load from memory instead of file.
69+
// return iree_runtime_session_append_bytecode_module_from_memory(
70+
// session, iree_make_const_byte_span(contents, size),
71+
// iree_allocator_null());
72+
file_loc.append("compiled_model.vmfb");
73+
return iree_runtime_session_append_bytecode_module_from_file(
74+
session, file_loc.c_str());
8275
}
8376

8477
namespace {
@@ -245,6 +238,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
245238
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
246239
// TODO: Synchronous mapping read, like everything in this function, is not a
247240
// great idea. It isn't supported on all device types and will need a scrub.
241+
iree_string_view_t device_val = iree_hal_device_id(device);
242+
auto device_str = std::string(device_val.data, device_val.size);
243+
if (device_str == "hip") {
244+
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
245+
iree_runtime_session_device(session),
246+
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
247+
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
248+
iree_infinite_timeout())));
249+
return common::Status::OK();
250+
}
248251
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
249252
output_tensor.GetTensorMutableRawData(),
250253
iree_hal_buffer_view_byte_length(ret.bv))));

onnxruntime/core/providers/iree/iree_ep_runtime.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct Instance {
2727

2828
// Initializes the instance.
2929
// TODO: We should probably pass the options in here and use it to set up.
30-
iree_status_t Initialize();
30+
iree_status_t Initialize(std::string device_str);
3131

3232
// Instance globals.
3333
iree_runtime_instance_options_t options;
@@ -48,7 +48,7 @@ struct Session {
4848
// Append a user-compiled bytecode module buffer to the session, along with a dispose callback.
4949
// The dispose callback will be invoked when Session is destroyed regardless of success/failure
5050
// of this call.
51-
iree_status_t AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback);
51+
iree_status_t AppendBytecodeModule(std::string file_loc, std::function<void()> dispose_callback);
5252

5353
// Calls the entrypoint. This returns an ORT Status and normalizes any IREE statuses to that
5454
// because that can arise from ORT interactions.

onnxruntime/core/providers/iree/iree_execution_provider.cc

+29-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ IREEExecutionProvider::~IREEExecutionProvider() {
3333
}
3434

3535
common::Status IREEExecutionProvider::Initialize() {
36-
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize()));
36+
if (info_.find("device") == info_.end())
37+
info_["device"] = "local-task";
38+
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize(info_["device"])));
3739
return common::Status::OK();
3840
}
3941

@@ -98,15 +100,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
98100
// TODO: The target needs to be synchronized with the runtime based on EP options.
99101
// TODO: We should just be adding the target to the module instead of specifying via
100102
// flags.
101-
std::string device_flag = "--iree-hal-target-backends=";
103+
std::string device_flag = "--iree-hal-target-device=";
102104
if (info_.find("hal_target_device") == info_.end()) {
103-
// In case device info is absent, set `llvm-cpu` as default hal-target-backend.
105+
// In case device info is absent, set `llvm-cpu` as default hal-target-device.
104106
device_flag.append("llvm-cpu");
105107
} else {
106108
device_flag.append(info_["hal_target_device"]);
107109
}
108-
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting device flag as " << device_flag;
110+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << device_flag;
109111
ORT_RETURN_IF_ERROR(compiler.SetFlag(device_flag.c_str()));
112+
113+
// Set all the compile-time flags.
114+
// TODO(Shukla-Gaurav): Use ireeCompilerSessionSetFlags API to set all the flags at once.
115+
// TODO(Shukla-Gaurav): support more than one extra flag by parsing the input string.
116+
if (info_.find("compile_time_flags") != info_.end()) {
117+
std::string extra_flag = info_["compile_time_flags"];
118+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
119+
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
120+
}
121+
110122
ORT_RETURN_IF_ERROR(compiler.Initialize());
111123
std::string module_name = "ort";
112124
iree_ep_jit::CompilerInvocation inv(compiler, module_name.c_str());
@@ -133,20 +145,27 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
133145
if (auto* err = ireeCompilerOutputOpenMembuffer(&vmfb_output.output)) {
134146
return iree_ep_jit::ErrorToStatus(err, "Failure opening compiler output buffer: ");
135147
}
136-
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output));
148+
149+
// Save the compiled module as `<save_to>compiled_model.vmfb`; when `save_to` is not set it defaults to "", i.e. the current working directory.
150+
if (info_.find("save_to") == info_.end())
151+
info_["save_to"] = "";
152+
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output, info_["save_to"]));
153+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compiled vmfb saved at this location " << info_["save_to"];
137154

138155
// Map raw memory.
139-
void* vmfb_contents;
140-
uint64_t vmfb_size;
141-
ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
156+
// void* vmfb_contents = nullptr;
157+
// uint64_t vmfb_size = 0;
158+
// TODO(Shukla-Gaurav): Map memory instead of storing the compiled module as a file
159+
// ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
142160

143161
// Create a new runtime session.
144162
auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
163+
// In case device info is absent, set `local-task` as default device.
145164
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize()));
146165

147166
// Load the compiled module, releasing our ownership of the CompilerOutput.
148-
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(
149-
vmfb_contents, vmfb_size, vmfb_output.Release())));
167+
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(info_["save_to"],
168+
vmfb_output.Release())));
150169

151170
for (auto& entrypoint_name : entrypoint_names) {
152171
node_compute_funcs.push_back(CreateNodeComputeFunc(entrypoint_name, rt_session));

0 commit comments

Comments
 (0)