@@ -33,7 +33,9 @@ IREEExecutionProvider::~IREEExecutionProvider() {
3333}
3434
3535common::Status IREEExecutionProvider::Initialize () {
36- ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_instance_->Initialize ()));
36+ if (info_.find (" device" ) == info_.end ())
37+ info_[" device" ] = " local-task" ;
38+ ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_instance_->Initialize (info_[" device" ])));
3739 return common::Status::OK ();
3840}
3941
@@ -98,15 +100,25 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
98100 // TODO: The target needs to be synchronized with the runtime based on EP options.
99101 // TODO: We should just be adding the target to the module instead of specifying via
100102 // flags.
101- std::string device_flag = " --iree-hal-target-backends =" ;
103+ std::string device_flag = " --iree-hal-target-device =" ;
102104 if (info_.find (" hal_target_device" ) == info_.end ()) {
103- // In case device info is absent, set `llvm-cpu` as default hal-target-backend .
105+ // In case device info is absent, set `llvm-cpu` as default hal-target-device .
104106 device_flag.append (" llvm-cpu" );
105107 } else {
106108 device_flag.append (info_[" hal_target_device" ]);
107109 }
108- LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compile: setting device flag as " << device_flag;
110+ LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compile: setting flag " << device_flag;
109111 ORT_RETURN_IF_ERROR (compiler.SetFlag (device_flag.c_str ()));
112+
113+ // Set all the compile-time flags.
114+ // TODO(Shukla-Gaurav): Use ireeCompilerSessionSetFlags API to set all the flags at once.
115+ // TODO(Shukla-Gaurav): Support more than one extra flag by parsing the input string.
116+ if (info_.find (" compile_time_flags" ) != info_.end ()) {
117+ std::string extra_flag = info_[" compile_time_flags" ];
118+ LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compile: setting flag " << extra_flag;
119+ ORT_RETURN_IF_ERROR (compiler.SetFlag (extra_flag.c_str ()));
120+ }
121+
110122 ORT_RETURN_IF_ERROR (compiler.Initialize ());
111123 std::string module_name = " ort" ;
112124 iree_ep_jit::CompilerInvocation inv (compiler, module_name.c_str ());
@@ -133,20 +145,32 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
133145 if (auto * err = ireeCompilerOutputOpenMembuffer (&vmfb_output.output )) {
134146 return iree_ep_jit::ErrorToStatus (err, " Failure opening compiler output buffer: " );
135147 }
136- ORT_RETURN_IF_ERROR (inv.CompileAndOutputVMFB (vmfb_output.output ));
148+
149+ // Save the compiled module to the system temporary directory, or to `save_to` when that option names an existing directory.
150+ fs::path save_to = fs::temp_directory_path ();
151+ if (info_.find (" save_to" ) != info_.end () && fs::is_directory (info_[" save_to" ]))
152+ save_to = fs::path (info_[" save_to" ]);
153+
154+ fs::path file_name (" compiled_model.vmfb" );
155+ fs::path vmfb_path = save_to / file_name;
156+
157+ ORT_RETURN_IF_ERROR (inv.CompileAndOutputVMFB (vmfb_output.output , vmfb_path));
158+ LOGS (*GetLogger (), INFO) << " IREEExecutionProvider compiled vmfb saved at this location " << vmfb_path;
137159
138160 // Map raw memory.
139- void * vmfb_contents;
140- uint64_t vmfb_size;
141- ORT_RETURN_IF_ERROR (vmfb_output.MapMemory (&vmfb_contents, &vmfb_size));
161+ // void* vmfb_contents = nullptr;
162+ // uint64_t vmfb_size = 0;
163+ // TODO(Shukla-Gaurav): Map memory instead of storing the compiled module as a file
164+ // ORT_RETURN_IF_ERROR(vmfb_output.MapMemory(&vmfb_contents, &vmfb_size));
142165
143166 // Create a new runtime session.
144167 auto rt_session = std::make_shared<iree_ep_rt::Session>(rt_instance_);
168+ // The device (defaulting to `local-task`) is chosen when the runtime instance is initialized in Initialize (), not here.
145169 ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_session->Initialize ()));
146170
147171 // Load the compiled module, releasing our ownership of the CompilerOutput.
148- ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_session->AppendBytecodeModule (
149- vmfb_contents, vmfb_size, vmfb_output.Release ())));
172+ ORT_RETURN_IF_ERROR (iree_ep_rt::HandleIREEStatus (rt_session->AppendBytecodeModule (vmfb_path,
173+ vmfb_output.Release (vmfb_path ))));
150174
151175 for (auto & entrypoint_name : entrypoint_names) {
152176 node_compute_funcs.push_back (CreateNodeComputeFunc (entrypoint_name, rt_session));
0 commit comments