@@ -13,18 +13,7 @@ common::Status HandleFailingIREEStatus(iree_status_t iree_status) {
13
13
return common::Status::OK ();
14
14
}
15
15
16
- std::string buffer;
17
- iree_host_size_t actual_len;
18
- buffer.resize (1024 );
19
- if (!iree_status_format (iree_status, buffer.size (), buffer.data (),
20
- &actual_len)) {
21
- buffer.resize (actual_len);
22
- if (!iree_status_format (iree_status, buffer.size (), buffer.data (),
23
- &actual_len)) {
24
- actual_len = 0 ;
25
- }
26
- }
27
- buffer.resize (actual_len);
16
+ std::string buffer = iree::Status::ToString (iree_status);
28
17
29
18
return ORT_MAKE_STATUS (ONNXRUNTIME, RUNTIME_EXCEPTION, " IREE Runtime Error: " , std::move (buffer));
30
19
}
@@ -43,13 +32,13 @@ Instance::~Instance() {
43
32
}
44
33
}
45
34
46
- iree_status_t Instance::Initialize () {
35
+ iree_status_t Instance::Initialize (std::string device_str ) {
47
36
IREE_RETURN_IF_ERROR (iree_runtime_instance_create (
48
37
&options, iree_allocator_system (), &instance));
49
38
50
39
// TODO: Need real device selection.
51
40
IREE_RETURN_IF_ERROR (iree_runtime_instance_try_create_default_device (
52
- instance, iree_make_cstring_view (" local-task " ), &device));
41
+ instance, iree_make_cstring_view (device_str. c_str () ), &device));
53
42
54
43
return iree_ok_status ();
55
44
}
@@ -74,11 +63,15 @@ iree_status_t Session::Initialize() {
74
63
&session);
75
64
}
76
65
77
- iree_status_t Session::AppendBytecodeModule (void * contents, uint64_t size , std::function<void ()> dispose_callback) {
66
+ iree_status_t Session::AppendBytecodeModule (std::string file_loc , std::function<void ()> dispose_callback) {
78
67
dispose_callbacks.push_back (std::move (dispose_callback));
79
- return iree_runtime_session_append_bytecode_module_from_memory (
80
- session, iree_make_const_byte_span (contents, size),
81
- iree_allocator_null ());
68
+ // TODO: load from memory instead of file.
69
+ // return iree_runtime_session_append_bytecode_module_from_memory(
70
+ // session, iree_make_const_byte_span(contents, size),
71
+ // iree_allocator_null());
72
+ file_loc.append (" compiled_model.vmfb" );
73
+ return iree_runtime_session_append_bytecode_module_from_file (
74
+ session, file_loc.c_str ());
82
75
}
83
76
84
77
namespace {
@@ -245,6 +238,16 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
245
238
iree_hal_buffer_t * ret_buffer = iree_hal_buffer_view_buffer (ret.bv );
246
239
// TODO: Synchronous mapping read, like everything in this function, is not a
247
240
// great idea. It isn't supported on all device types and will need a scrub.
241
+ iree_string_view_t device_val = iree_hal_device_id (device);
242
+ auto device_str = std::string (device_val.data , device_val.size );
243
+ if (device_str == " hip" ) {
244
+ ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_hal_device_transfer_d2h (
245
+ iree_runtime_session_device (session),
246
+ ret_buffer, 0 , output_tensor.GetTensorMutableRawData (),
247
+ iree_hal_buffer_view_byte_length (ret.bv ), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
248
+ iree_infinite_timeout ())));
249
+ return common::Status::OK ();
250
+ }
248
251
ORT_RETURN_IF_ERROR (HandleIREEStatus (iree_hal_buffer_map_read (ret_buffer, /* source_offset=*/ 0 ,
249
252
output_tensor.GetTensorMutableRawData (),
250
253
iree_hal_buffer_view_byte_length (ret.bv ))));
0 commit comments