Skip to content

Commit f3b6c56

Browse files
Epctx node base path fix and lint fix (#569)
* Use ep.context_file_path to get base path when creating session from memory * Fixed lint issues --------- Co-authored-by: Javier E. Martinez <[email protected]>
1 parent 8c482a9 commit f3b6c56

File tree

5 files changed

+28
-16
lines changed

5 files changed

+28
-16
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

+14-6
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ BackendManager::BackendManager(SessionContext& session_context,
7676
ptr_stream_t model_stream;
7777
std::unique_ptr<onnx::ModelProto> model_proto;
7878
if (subgraph_context_.is_ep_ctx_graph) {
79-
model_stream = ep_ctx_handle_.GetModelBlobStream(subgraph);
79+
model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph);
8080
} else {
8181
model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
8282
}
@@ -214,21 +214,29 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie
214214
// If not embed_mode, dump the blob here and only pass on the path to the blob
215215
std::string model_blob_str;
216216
auto compiled_model = concrete_backend_->GetOVCompiledModel();
217-
if (session_context_.so_context_embed_mode) {
218-
// Internal blob
217+
if (session_context_.so_context_embed_mode) { // Internal blob
219218
std::ostringstream model_blob_stream;
220219
compiled_model.export_model(model_blob_stream);
221220
model_blob_str = std::move(model_blob_stream).str();
222221
if (model_blob_str.empty()) {
223222
ORT_THROW("Model blob stream is empty after exporting the compiled model.");
224223
}
225-
} else {
226-
// External blob
224+
} else { // External blob
225+
// Build name by combining EpCtx model name (if available) and subgraph name. Model
226+
// name is not available when creating a session from memory
227+
auto name = session_context_.so_context_file_path.stem().string();
228+
if (!name.empty() && !graph_body_viewer.ModelPath().empty()) {
229+
name = graph_body_viewer.ModelPath().stem().string();
230+
}
231+
if (!name.empty()) {
232+
name += "_";
233+
}
234+
name += subgraph_context_.subgraph_name;
235+
227236
std::filesystem::path blob_filename = session_context_.so_context_file_path;
228237
if (blob_filename.empty()) {
229238
blob_filename = session_context_.onnx_model_path_name;
230239
}
231-
const auto name = graph_body_viewer.ModelPath().stem().string() + "_" + subgraph_context_.subgraph_name;
232240
blob_filename = blob_filename.parent_path() / name;
233241
blob_filename.replace_extension("blob");
234242
std::ofstream blob_file(blob_filename,

onnxruntime/core/providers/openvino/backend_utils.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Met
9292

9393
size_t safe_num_dimensions = num_dimensions;
9494

95-
if(num_dimensions == 0 || safe_num_dimensions > MAX_SAFE_DIMENSIONS) {
96-
ORT_THROW("Invalid number of dimensions provided.");
95+
if (num_dimensions == 0 || safe_num_dimensions > MAX_SAFE_DIMENSIONS) {
96+
ORT_THROW("Invalid number of dimensions provided.");
9797
}
9898
try {
99-
value.dimensions.resize(safe_num_dimensions);
99+
value.dimensions.resize(safe_num_dimensions);
100100
} catch (const std::bad_alloc&) {
101-
ORT_THROW("Error: Memory allocation failed while resizing dimensions.");
101+
ORT_THROW("Error: Memory allocation failed while resizing dimensions.");
102102
}
103103

104104
for (auto& dim : value.dimensions) {

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc

+6-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer,
9999
return Status::OK();
100100
}
101101

102-
std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const GraphViewer& graph_viewer) const {
102+
std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const {
103103
auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin();
104104
auto node = graph_viewer.GetNode(first_index);
105105
ORT_ENFORCE(node != nullptr);
@@ -115,7 +115,11 @@ std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const GraphViewer
115115
if (embed_mode) {
116116
result.reset((std::istream*)new std::istringstream(ep_cache_context));
117117
} else {
118-
const auto& blob_filepath = graph_viewer.ModelPath().parent_path() / ep_cache_context;
118+
auto blob_filepath = so_context_file_path;
119+
if (blob_filepath.empty() && !graph_viewer.ModelPath().empty()) {
120+
blob_filepath = graph_viewer.ModelPath();
121+
}
122+
blob_filepath = blob_filepath.parent_path() / ep_cache_context;
119123
ORT_ENFORCE(std::filesystem::exists(blob_filepath), "Blob file not found: ", blob_filepath.string());
120124
result.reset((std::istream*)new std::ifstream(blob_filepath, std::ios_base::binary | std::ios_base::in));
121125
}

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class EPCtxHandler {
3131
const std::string& graph_name,
3232
const bool embed_mode,
3333
std::string&& model_blob_str) const;
34-
std::unique_ptr<std::istream> GetModelBlobStream(const GraphViewer& graph_viewer) const;
34+
std::unique_ptr<std::istream> GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const;
3535
InlinedVector<const Node*> GetEPCtxNodes() const;
3636

3737
private:

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ void ParseConfigOptions(ProviderInfo& pi, const ConfigOptions& config_options) {
2222
pi.so_context_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
2323
}
2424

25-
void* ParseUint64(const ProviderOptions& provider_options, [[maybe_unused]] std::string option_name) {
26-
if (provider_options.contains("context")) {
27-
uint64_t number = std::strtoull(provider_options.at("context").data(), nullptr, 16);
25+
void* ParseUint64(const ProviderOptions& provider_options, std::string option_name) {
26+
if (provider_options.contains(option_name)) {
27+
uint64_t number = std::strtoull(provider_options.at(option_name).data(), nullptr, 16);
2828
return reinterpret_cast<void*>(number);
2929
} else {
3030
return nullptr;

0 commit comments

Comments
 (0)