@@ -33,7 +33,9 @@ IREEExecutionProvider::~IREEExecutionProvider() {
33
33
}
34
34
35
35
common::Status IREEExecutionProvider::Initialize () {
36
- ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_instance_->Initialize ()));
36
+ if (info_.find (" device" ) == info_.end ())
37
+ info_[" device" ] = " local-task" ;
38
+ ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_instance_->Initialize (info_[" device" ])));
37
39
return common::Status::OK ();
38
40
}
39
41
@@ -98,15 +100,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
98
100
// TODO: The target needs to be synchronized with the runtime based on EP options.
99
101
// TODO: We should just be adding the target to the module instead of specifying via
100
102
// flags.
101
- std::string device_flag = " --iree-hal-target-backends =" ;
103
+ std::string device_flag = " --iree-hal-target-device =" ;
102
104
if (info_.find (" hal_target_device" ) == info_.end ()) {
103
- // In case device info is absent, set `llvm-cpu` as default hal-target-backend .
105
+ // In case device info is absent, set `llvm-cpu` as default hal-target-device .
104
106
device_flag.append (" llvm-cpu" );
105
107
} else {
106
108
device_flag.append (info_[" hal_target_device" ]);
107
109
}
108
- LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compile: setting device flag as " << device_flag;
110
+ LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compile: setting flag " << device_flag;
109
111
ORT_RETURN_IF_ERROR (compiler.SetFlag (device_flag.c_str ()));
112
+
113
+ // Set all the compile-time flags.
114
+ // TODO(Shukla-Gaurav): Use ireeCompilerSessionSetFlags API to set all the flags at once.
115
+ // TODO(Shukla-Gaurav): support more than one extra flags by parsing the input string.
116
+ if (info_.find (" compile_time_flags" ) != info_.end ()) {
117
+ std::string extra_flag = info_[" compile_time_flags" ];
118
+ LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compile: setting flag " << extra_flag;
119
+ ORT_RETURN_IF_ERROR (compiler.SetFlag (extra_flag.c_str ()));
120
+ }
121
+
110
122
ORT_RETURN_IF_ERROR (compiler.Initialize ());
111
123
std::string module_name = " ort" ;
112
124
iree_ep_jit::CompilerInvocation inv (compiler, module_name.c_str ());
@@ -133,20 +145,32 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
133
145
if (auto * err = ireeCompilerOutputOpenMembuffer (&vmfb_output.output )) {
134
146
return iree_ep_jit::ErrorToStatus (err, " Failure opening compiler output buffer: " );
135
147
}
136
- ORT_RETURN_IF_ERROR (inv.CompileAndOutputVMFB (vmfb_output.output ));
148
+
149
+ // This will save the compiled module to temporary directory.
150
+ fs::path save_to = fs::temp_directory_path ();
151
+ if (info_.find (" save_to" ) != info_.end () && fs::is_directory (info_[" save_to" ]))
152
+ save_to = fs::path (info_[" save_to" ]);
153
+
154
+ fs::path file_name (" compiled_model.vmfb" );
155
+ fs::path vmfb_path = save_to / file_name;
156
+
157
+ ORT_RETURN_IF_ERROR (inv.CompileAndOutputVMFB (vmfb_output.output , vmfb_path));
158
+ LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compiled vmfb saved at this location " << vmfb_path;
137
159
138
160
// Map raw memory.
139
- void * vmfb_contents;
140
- uint64_t vmfb_size;
141
- ORT_RETURN_IF_ERROR (vmfb_output.MapMemory (&vmfb_contents, &vmfb_size));
161
+ // void* vmfb_contents = nullptr;
162
+ // uint64_t vmfb_size = 0;
163
+ // TODO(Shukla-Gaurav): Map memory instead of storing the compiled module as a file
164
+ // ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
142
165
143
166
// Create a new runtime session.
144
167
auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
168
+ // In case device info is absent, set `local-task` as default device.
145
169
ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_session->Initialize ()));
146
170
147
171
// Load the compiled module, releasing our ownership of the CompilerOutput.
148
- ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_session->AppendBytecodeModule (
149
- vmfb_contents, vmfb_size, vmfb_output.Release ())));
172
+ ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_session->AppendBytecodeModule (vmfb_path,
173
+ vmfb_output.Release (vmfb_path ))));
150
174
151
175
for (auto & entrypoint_name : entrypoint_names) {
152
176
node_compute_funcs.push_back (CreateNodeComputeFunc (entrypoint_name, rt_session));
0 commit comments