Skip to content

Commit b159e1c

Browse files
committed
[IREE][EP] Add support for rocm backend
This commit adds support for the ROCm backend in iree-ep. Signed-off-by: Gaurav Shukla <[email protected]> (cherry picked from commit 00b1696)
1 parent 08acace commit b159e1c

File tree

7 files changed

+62
-36
lines changed

7 files changed

+62
-36
lines changed

onnxruntime/core/providers/iree/compiler/jit_compiler.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,16 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
222222
return common::Status::OK();
223223
}
224224

225-
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output) {
225+
common::Status CompilerInvocation::CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to) {
226226
// Main compilation.
227227
if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) {
228228
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "IREE compilation error.", ConsumeDiagnostics());
229229
}
230230

231+
// Attach the compiled output to a file.
232+
save_to.append("compiled_model.vmfb");
233+
ireeCompilerOutputOpenFile(save_to.c_str(), &output);
234+
231235
// Output.
232236
if (auto* err = ireeCompilerInvocationOutputVMBytecode(inv, output)) {
233237
return ErrorToStatus(err, "Failure emitting VM bytecode: ");

onnxruntime/core/providers/iree/compiler/jit_compiler.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ struct CompilerOutput {
4545
// Releases ownership of the output, returning a callback that can be used to
4646
// destroy it at a later date.
4747
std::function<void()> Release() {
48-
iree_compiler_output_t* local_output = output;
48+
iree_compiler_output_t* local_output = this->output;
4949
this->output = nullptr;
5050
return [local_output]() {
5151
if (local_output) {
@@ -84,7 +84,7 @@ struct CompilerInvocation {
8484
common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name);
8585

8686
// Compile and output a VMFB.
87-
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output);
87+
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output, std::string save_to);
8888

8989
// If there are any diagnostics, clears them and returns a loggable string.
9090
std::string ConsumeDiagnostics();

onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) {
392392
int8_conversion.reserve(tp.int32_data_size());
393393
for (int32_t v : tp.int32_data())
394394
int8_conversion.push_back(v);
395-
return mlirDenseElementsAttrInt8Get(
396-
tensor_type, int8_conversion.size(), int8_conversion.data());
395+
return mlirDenseElementsAttrInt8Get(tensor_type, int8_conversion.size(),
396+
int8_conversion.data());
397397
}
398398
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
399399
return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(),

onnxruntime/core/providers/iree/iree_ep_runtime.cc

+21-18
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) {
1313
return common::Status::OK();
1414
}
1515

16-
std::string buffer;
17-
iree_host_size_t actual_len;
18-
buffer.resize(1024);
19-
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
20-
&actual_len)) {
21-
buffer.resize(actual_len);
22-
if (!iree_status_format(iree_status, buffer.size(), buffer.data(),
23-
&actual_len)) {
24-
actual_len = 0;
25-
}
26-
}
27-
buffer.resize(actual_len);
16+
std::string buffer = iree::Status::ToString(iree_status);
2817

2918
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "IREE Runtime Error: ", std::move(buffer));
3019
}
@@ -43,13 +32,13 @@ Instance::~Instance() {
4332
}
4433
}
4534

46-
iree_status_t Instance::Initialize() {
35+
iree_status_t Instance::Initialize(std::string device_str) {
4736
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
4837
&options, iree_allocator_system(), &instance));
4938

5039
// TODO: Need real device selection.
5140
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
52-
instance, iree_make_cstring_view("local-task"), &device));
41+
instance, iree_make_cstring_view(device_str.c_str()), &device));
5342

5443
return iree_ok_status();
5544
}
@@ -74,11 +63,15 @@ iree_status_t Session::Initialize() {
7463
&session);
7564
}
7665

77-
iree_status_t Session::AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback) {
66+
iree_status_t Session::AppendBytecodeModule(std::string file_loc, std::function<void()> dispose_callback) {
7867
dispose_callbacks.push_back(std::move(dispose_callback));
79-
return iree_runtime_session_append_bytecode_module_from_memory(
80-
session, iree_make_const_byte_span(contents, size),
81-
iree_allocator_null());
68+
// TODO: load from memory instead of file.
69+
// return iree_runtime_session_append_bytecode_module_from_memory(
70+
// session, iree_make_const_byte_span(contents, size),
71+
// iree_allocator_null());
72+
file_loc.append("compiled_model.vmfb");
73+
return iree_runtime_session_append_bytecode_module_from_file(
74+
session, file_loc.c_str());
8275
}
8376

8477
namespace {
@@ -245,6 +238,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
245238
iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
246239
// TODO: Synchronous mapping read, like everything in this function, is not a
247240
// great idea. It isn't supported on all device types and will need a scrub.
241+
iree_string_view_t device_val = iree_hal_device_id(device);
242+
auto device_str = std::string(device_val.data, device_val.size);
243+
if (device_str == "hip") {
244+
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
245+
iree_runtime_session_device(session),
246+
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
247+
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
248+
iree_infinite_timeout())));
249+
return common::Status::OK();
250+
}
248251
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
249252
output_tensor.GetTensorMutableRawData(),
250253
iree_hal_buffer_view_byte_length(ret.bv))));

onnxruntime/core/providers/iree/iree_ep_runtime.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct Instance {
2727

2828
// Initializes the instance.
2929
// TODO: We should probably pass the options in here and use it to set up.
30-
iree_status_t Initialize();
30+
iree_status_t Initialize(std::string device_str);
3131

3232
// Instance globals.
3333
iree_runtime_instance_options_t options;
@@ -48,7 +48,7 @@ struct Session {
4848
// Append a user-compiled bytecode module buffer to the session, along with a dispose callback.
4949
// The dispose callback will be invoked when Session is destroyed regardless of success/failure
5050
// of this call.
51-
iree_status_t AppendBytecodeModule(void* contents, uint64_t size, std::function<void()> dispose_callback);
51+
iree_status_t AppendBytecodeModule(std::string file_loc, std::function<void()> dispose_callback);
5252

5353
// Calls the entrypoint. This returns an ORT Status and normalizes any IREE statuses to that
5454
// because that can arise from ORT interactions.

onnxruntime/core/providers/iree/iree_execution_provider.cc

+29-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ IREEExecutionProvider::~IREEExecutionProvider() {
3333
}
3434

3535
common::Status IREEExecutionProvider::Initialize() {
36-
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize()));
36+
if (info_.find("device") == info_.end())
37+
info_["device"] = "local-task";
38+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider runtime device set as " << info_["device"];
39+
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_instance_->Initialize(info_["device"])));
3740
return common::Status::OK();
3841
}
3942

@@ -102,15 +105,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
102105
// TODO: The target needs to be synchronized with the runtime based on EP options.
103106
// TODO: We should just be adding the target to the module instead of specifying via
104107
// flags.
105-
std::string device_flag = "--iree-hal-target-backends=";
108+
std::string device_flag = "--iree-hal-target-device=";
106109
if (info_.find("hal_target_device") == info_.end()) {
107-
// In case device info is absent, set `llvm-cpu` as default hal-target-backend.
110+
// In case device info is absent, set `llvm-cpu` as default hal-target-device.
108111
device_flag.append("llvm-cpu");
109112
} else {
110113
device_flag.append(info_["hal_target_device"]);
111114
}
112-
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting device flag as " << device_flag;
115+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << device_flag;
113116
ORT_RETURN_IF_ERROR(compiler.SetFlag(device_flag.c_str()));
117+
118+
// Set all the compile-time flags.
119+
// TODO: Use ireeCompilerSessionSetFlags API to set all the flags at once.
120+
// TODO: Support more than one extra flag by parsing the input string.
121+
if (info_.find("compile_time_flags") != info_.end()) {
122+
std::string extra_flag = info_["compile_time_flags"];
123+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compile: setting flag " << extra_flag;
124+
ORT_RETURN_IF_ERROR(compiler.SetFlag(extra_flag.c_str()));
125+
}
126+
114127
ORT_RETURN_IF_ERROR(compiler.Initialize());
115128
std::string module_name = "ort";
116129
iree_ep_jit::CompilerInvocation inv(compiler, module_name.c_str());
@@ -137,20 +150,26 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
137150
if (auto* err = ireeCompilerOutputOpenMembuffer(&vmfb_output.output)) {
138151
return iree_ep_jit::ErrorToStatus(err, "Failure opening compiler output buffer: ");
139152
}
140-
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output));
153+
154+
// Save the compiled module under the `save_to` location; when `save_to` is not provided, it defaults to the current working directory.
155+
if (info_.find("save_to") == info_.end())
156+
info_["save_to"] = "";
157+
ORT_RETURN_IF_ERROR(inv.CompileAndOutputVMFB(vmfb_output.output, info_["save_to"]));
158+
LOGS(*GetLogger(), INFO) << "IREEExecutionProvider compiled vmfb saved at this location " << info_["save_to"];
141159

142160
// Map raw memory.
143-
void* vmfb_contents;
144-
uint64_t vmfb_size;
145-
ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
161+
// void* vmfb_contents = nullptr;
162+
// uint64_t vmfb_size = 0;
163+
// TODO: Map memory instead of storing the compiled module as a file
164+
// ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
146165

147166
// Create a new runtime session.
148167
auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
168+
// The runtime device defaults to `local-task` when no device info is given; that default is applied in IREEExecutionProvider::Initialize().
149169
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize()));
150170

151171
// Load the compiled module, releasing our ownership of the CompilerOutput.
152-
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(
153-
vmfb_contents, vmfb_size, vmfb_output.Release())));
172+
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(info_["save_to"], vmfb_output.Release())));
154173

155174
for (auto& entrypoint_name : entrypoint_names) {
156175
node_compute_funcs.push_back(CreateNodeComputeFunc(entrypoint_name, rt_session));

onnxruntime/python/onnxruntime_pybind_state.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
11411141
#endif
11421142
} else if (type == kIreeExecutionProvider) {
11431143
#if USE_IREE
1144-
const auto &it = provider_options_map.find(type);
1144+
const auto& it = provider_options_map.find(type);
11451145
ProviderOptions iree_option_map = ProviderOptions{};
11461146
if (it != provider_options_map.end()) {
11471147
iree_option_map = it->second;

0 commit comments

Comments
 (0)