|
2 | 2 | // Licensed under the MIT License
|
3 | 3 |
|
4 | 4 | #include <map>
|
| 5 | +#include <unordered_set> |
| 6 | + |
5 | 7 | #include <string>
|
6 | 8 | #include <memory>
|
7 | 9 | #include <sstream>
|
@@ -222,6 +224,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
|
222 | 224 | }
|
223 | 225 | }
|
224 | 226 | }
|
| 227 | + auto find_device_type_mode = [&](const std::string& device_type) -> std::string { |
| 228 | + std::string device_mode = ""; |
| 229 | + auto delimiter_pos = device_type.find(':'); |
| 230 | + if (delimiter_pos != std::string::npos) { |
| 231 | + std::stringstream str_stream(device_type.substr(0, delimiter_pos)); |
| 232 | + std::getline(str_stream, device_mode, ','); |
| 233 | + } |
| 234 | + return device_mode; |
| 235 | + }; |
225 | 236 |
|
226 | 237 | // Parse device types like "AUTO:CPU,GPU" and extract individual devices
|
227 | 238 | auto parse_individual_devices = [&](const std::string& device_type) -> std::vector<std::string> {
|
@@ -270,8 +281,14 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
|
270 | 281 | if (session_context_.device_type.find("AUTO") == 0 ||
|
271 | 282 | session_context_.device_type.find("HETERO") == 0 ||
|
272 | 283 | session_context_.device_type.find("MULTI") == 0) {
|
| 284 | + //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO") |
| 285 | + std::unordered_set<std::string> supported_mode = {"AUTO", "HETERO", "MULTI"}; |
| 286 | + auto device_mode = find_device_type_mode(session_context_.device_type); |
| 287 | + ORT_ENFORCE(supported_mode.find(device_mode)!=supported_mode.end(), " Invalid device mode is passed : " , session_context_.device_type); |
273 | 288 | // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"])
|
274 | 289 | auto individual_devices = parse_individual_devices(session_context_.device_type);
|
| 290 | + if (!device_mode.empty()) individual_devices.emplace_back(device_mode); |
| 291 | + |
275 | 292 | // Set properties only for individual devices (e.g., "CPU", "GPU")
|
276 | 293 | for (const std::string& device : individual_devices) {
|
277 | 294 | if (target_config.count(device)) {
|
|
0 commit comments