Skip to content

Commit 1c81f11

Browse files
committed
[IREE][EP] Add support for rocm backend
This commit adds support for rocm backend in iree-ep. Signed-Off-by: Gaurav Shukla<[email protected]>
1 parent 1d4576f commit 1c81f11

File tree

5 files changed

+73
-36
lines changed

5 files changed

+73
-36
lines changed

onnxruntime/core/providers/iree/compiler/jit_compiler.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "mlir-c/BuiltinAttributes.h"
1111

1212
#include <cstring>
13-
#include <filesystem>
1413

1514
namespace onnxruntime::iree_ep_jit {
1615

@@ -208,12 +207,15 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
208207
return common::Status::OK();
209208
}
210209

211-
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output) {
210+
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output, fs::path vmfb_path) {
212211
// Main compilation.
213212
if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) {
214213
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "IREE compilation error.", ConsumeDiagnostics());
215214
}
216215

216+
// Attach the compiled output to a file.
217+
ireeCompilerOutputOpenFile(vmfb_path.c_str(), &output);
218+
217219
// Output.
218220
if (auto* err = ireeCompilerInvocationOutputVMBytecode(inv, output)) {
219221
return ErrorToStatus(err, "Failure emitting VM bytecode: ");

onnxruntime/core/providers/iree/compiler/jit_compiler.h

+8-4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
#include "iree/compiler/embedding_api.h"
1313
#include "iree/compiler/mlir_interop.h"
1414

15+
#include <filesystem>
1516
#include <string>
1617
#include <string_view>
1718

19+
namespace fs = std::filesystem;
20+
1821
namespace onnxruntime::iree_ep_jit {
1922

2023
common::Status ErrorToStatus(iree_compiler_error_t* err, std::string message_prefix);
@@ -44,13 +47,14 @@ struct CompilerOutput {
4447

4548
// Releases ownership of the output, returning a callback that can be used to
4649
// destroy it at a later date.
47-
std::function<void()> Release() {
48-
iree_compiler_output_t* local_output = output;
50+
std::function<void()> Release(fs::path vmfb_path) {
51+
iree_compiler_output_t* local_output = this->output;
4952
this->output = nullptr;
50-
return [local_output]() {
53+
return [local_output, &vmfb_path]() {
5154
if (local_output) {
5255
ireeCompilerOutputDestroy(local_output);
5356
}
57+
fs::remove(vmfb_path);
5458
};
5559
}
5660

@@ -84,7 +88,7 @@ struct CompilerInvocation {
8488
common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name);
8589

8690
// Compile and output a VMFB.
87-
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output);
91+
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output, fs::path vmfb_path);
8892

8993
// If there are any diagnostics, clears them and returns a loggable string.
9094
std::string ConsumeDiagnostics();

onnxruntime/core/providers/iree/iree_ep_runtime.cc

+20-18
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) {
1313
return common::Status::OK();
1414
}
1515

16-
std::string buffer;
17-
iree_host_size_t actual_len;
18-
buffer.resize(1024);
19-
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
20-
&actual_len)) {
21-
buffer.resize(actual_len);
22-
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
23-
&actual_len)) {
24-
actual_len = 0;
25-
}
26-
}
27-
buffer.resize(actual_len);
16+
std::string buffer = iree::Status::ToString(iree_status);
2817

2918
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "IREE Runtime Error: ", std::move(buffer));
3019
}
@@ -43,13 +32,13 @@ Instance::~Instance() {
4332
}
4433
}
4534

46-
iree_status_t Instance::Initialize() {
35+
iree_status_t Instance::Initialize(std::string device_str) {
4736
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
4837
&options, iree_allocator_system(), &instance));
4938

5039
// TODO: Need real device selection.
5140
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
52-
instance, iree_make_cstring_view("local-task"), &device));
41+
instance, iree_make_cstring_view(device_str.c_str()), &device));
5342

5443
return iree_ok_status();
5544
}
@@ -74,11 +63,14 @@ iree_status_t Session::Initialize() {
7463
&session);
7564
}
7665

77-
iree_status_t Session::AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback) {
66+
iree_status_t Session::AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback) {
7867
dispose_callbacks.push_back(std::move(dispose_callback));
79-
return iree_runtime_session_append_bytecode_module_from_memory(
80-
session, iree_make_const_byte_span(contents, size),
81-
iree_allocator_null());
68+
// TODO(Shukla-Gaurav): load from memory instead of file.
69+
// return iree_runtime_session_append_bytecode_module_from_memory(
70+
// session, iree_make_const_byte_span(contents, size),
71+
// iree_allocator_null());
72+
return iree_runtime_session_append_bytecode_module_from_file(
73+
session, file_loc.c_str());
8274
}
8375

8476
namespace {
@@ -245,6 +237,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
245237
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
246238
// TODO: Synchronous mapping read, like everything in this function, is not a
247239
// great idea. It isn't supported on all device types and will need a scrub.
240+
iree_string_view_t device_val = iree_hal_device_id(device);
241+
auto device_str = std::string(device_val.data, device_val.size);
242+
if (device_str == "hip") {
243+
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
244+
iree_runtime_session_device(session),
245+
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
246+
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
247+
iree_infinite_timeout())));
248+
return common::Status::OK();
249+
}
248250
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
249251
output_tensor.GetTensorMutableRawData(),
250252
iree_hal_buffer_view_byte_length(ret.bv))));

onnxruntime/core/providers/iree/iree_ep_runtime.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
#include "core/session/onnxruntime_c_api.h"
88
#include "iree/runtime/api.h"
99

10+
#include <filesystem>
11+
12+
namespace fs = std::filesystem;
13+
1014
namespace onnxruntime::iree_ep_rt {
1115

1216
// Handles a failing IREE status.
@@ -27,7 +31,7 @@ struct Instance {
2731

2832
// Initializes the instance.
2933
// TODO: We should probably pass the options in here and use it to set up.
30-
iree_status_t Initialize();
34+
iree_status_t Initialize(std::string device_str);
3135

3236
// Instance globals.
3337
iree_runtime_instance_options_t options;
@@ -48,7 +52,7 @@ struct Session {
4852
// Append a user-compiled bytecode module buffer to the session, along with a dispose callback.
4953
// The dispose callback will be invoked when Session is destroyed regardless of success/failure
5054
// of this call.
51-
iree_status_t AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback);
55+
iree_status_t AppendBytecodeModule(fs::path vmfb_path, std::function<void()> dispose_callback);
5256

5357
// Calls the entrypoint. This returns an ORT Status and normalizes any IREE statuses to that
5458
// because that can arise from ORT interactions.

onnxruntime/core/providers/iree/iree_execution_provider.cc

+35-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ IREEExecutionProvider::~IREEExecutionProvider() {
3333
}
3434

3535
common::Status IREEExecutionProvider::Initialize() {
36-
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize()));
36+
if (info_.find("device") == info_.end())
37+
info_["device"] = "local-task";
38+
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize(info_["device"])));
3739
return common::Status::OK();
3840
}
3941

@@ -98,15 +100,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
98100
// TODO: The target needs to be synchronized with the runtime based on EP options.
99101
// TODO: We should just be adding the target to the module instead of specifying via
100102
// flags.
101-
std::string device_flag = "--iree-hal-target-backends=";
103+
std::string device_flag = "--iree-hal-target-device=";
102104
if (info_.find("hal_target_device") == info_.end()) {
103-
// In case device info is absent, set `llvm-cpu` as default hal-target-backend.
105+
// In case device info is absent, set `llvm-cpu` as default hal-target-device.
104106
device_flag.append("llvm-cpu");
105107
} else {
106108
device_flag.append(info_["hal_target_device"]);
107109
}
108-
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting device flag as " << device_flag;
110+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << device_flag;
109111
ORT_RETURN_IF_ERROR(compiler.SetFlag(device_flag.c_str()));
112+
113+
// Set all the compile-time flags.
114+
// TODO(Shukla-Gaurav): Use ireeCompilerSessionSetFlags API to set all the flags at once.
115+
// TODO(Shukla-Gaurav): support more than one extra flags by parsing the input string.
116+
if (info_.find("compile_time_flags") != info_.end()) {
117+
std::string extra_flag = info_["compile_time_flags"];
118+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
119+
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
120+
}
121+
110122
ORT_RETURN_IF_ERROR(compiler.Initialize());
111123
std::string module_name = "ort";
112124
iree_ep_jit::CompilerInvocation inv(compiler, module_name.c_str());
@@ -133,20 +145,33 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
133145
if (auto* err = ireeCompilerOutputOpenMembuffer(&vmfb_output.output)) {
134146
return iree_ep_jit::ErrorToStatus(err, "Failure opening compiler output buffer: ");
135147
}
136-
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output));
148+
149+
// This will save the compiled module to temporary directory.
150+
fs::path save_to = fs::temp_directory_path();
151+
if (info_.find("save_to") != info_.end() && fs::is_directory(info_["save_to"])
152+
save_to = fs::path(info_["save_to"]);
153+
154+
fs::path file_name("compiled_model.vmfb");
155+
fs::path vmfb_path = save_to / file_name;
156+
157+
158+
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output, vmfb_path));
159+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compiled vmfb saved at this location " << vmfb_path;
137160

138161
// Map raw memory.
139-
void* vmfb_contents;
140-
uint64_t vmfb_size;
141-
ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
162+
// void* vmfb_contents = nullptr;
163+
// uint64_t vmfb_size = 0;
164+
// TODO(Shukla-Gaurav): Map memory instead of storing the compiled module as a file
165+
// ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
142166

143167
// Create a new runtime session.
144168
auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
169+
// In case device info is absent, set `local-task` as default device.
145170
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize()));
146171

147172
// Load the compiled module, releasing our ownership of the CompilerOutput.
148-
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(
149-
vmfb_contents, vmfb_size, vmfb_output.Release())));
173+
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(vmfb_path,
174+
vmfb_output.Release(vmfb_path))));
150175

151176
for (auto& entrypoint_name : entrypoint_names) {
152177
node_compute_funcs.push_back(CreateNodeComputeFunc(entrypoint_name, rt_session));

0 commit comments

Comments
 (0)