Skip to content

Commit ed9e425

Browse files
mklimenk and vthaniel authored
Add self-detecting on-the-fly bfloat16->float16 conversion pass (#741)
* Add on-the-fly bfloat16->float16 conversion pass
* Fix undetected bfloat16 initializers
* Remove the option and make the logic implicit
* Add tests
* Rename detection function
* Fix CI for strict aliasing rules

---------

Co-authored-by: Vishnudas Thaniel S <[email protected]>
1 parent f4da9f1 commit ed9e425

File tree

5 files changed

+197
-3
lines changed

5 files changed

+197
-3
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,18 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
375375
return false;
376376
}
377377

378+
// Returns true when any node in the graph declares a bfloat16-typed output,
// i.e. the model needs the bfloat16->float16 conversion pass before it can be
// handed to OpenVINO.
static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) {
  for (const auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_index));
    for (const auto& output : node->OutputDefs()) {
      const auto elem_type = output->ToProto().type().tensor_type().elem_type();
      if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
        return true;
      }
    }
  }
  return false;
}
389+
378390
static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
379391
[[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
380392
[[maybe_unused]] const onnxruntime::Node& fused_node) {
@@ -456,6 +468,16 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
456468
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
457469
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
458470
return model_proto;
471+
} else if (IsModelBF16(subgraph)) {
472+
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled";
473+
std::unique_ptr<onnxruntime::Model> model;
474+
Status status = bfloat16_fix::Transform(subgraph, logger, model);
475+
auto model_proto = model->ToProto();
476+
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
477+
print_model_proto_duration();
478+
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
479+
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
480+
return model_proto;
459481
} else {
460482
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled";
461483
auto model = subgraph.CreateModel(logger);

onnxruntime/core/providers/openvino/ov_versions/data_ops.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,11 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
555555
return false;
556556
}
557557

558+
auto dtype = type_proto->tensor_type().elem_type();
559+
// Enable bfloat16 -> float16 on-the-fly conversion
560+
if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16)
561+
return true;
558562
if (is_initializer) {
559-
auto dtype = type_proto->tensor_type().elem_type();
560563
for (auto const& var : supported_types_initializer_) {
561564
if ((var.first <= version_id_) &&
562565
(var.second == dtype)) {
@@ -571,8 +574,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
571574
#endif
572575
return false;
573576
} else {
574-
auto dtype = type_proto->tensor_type().elem_type();
575-
576577
if (device_id_.find("HETERO") != std::string::npos ||
577578
device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) {
578579
for (auto const& var : supported_types_npu_) {

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
#include "qdq_scales_fix.h"
#include "core/providers/openvino/ov_protobuf_utils.h"
#include "core/framework/float16.h"

#include <cstdint>
#include <cstring>
#include <fstream>
#include <list>
@@ -940,5 +941,54 @@ Status Transform(const GraphViewer& src_graph_viewer,
940941
return status;
941942
}
942943
} // namespace qdq_scales_fix
944+
945+
namespace bfloat16_fix {
946+
// Rewrites the (privately owned) graph copy in place so it contains no
// bfloat16 tensors:
//   * "Cast" nodes whose `to` attribute targets bfloat16 are retargeted to float16,
//   * node outputs typed bfloat16 are retyped to float16,
//   * bfloat16 initializers have their raw payload converted element-wise to float16.
// NOTE(review): the const_casts below assume this pass owns the model copy it
// runs on (see bfloat16_fix::Transform) — confirm no shared graph is passed in.
void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) {
  for (auto& const_node : gen_graph.original_graph.Nodes()) {
    auto node = const_cast<ONNX_NAMESPACE::Node*>(const_node);
    if (node->OpType() == "Cast") {
      // Retarget Cast-to-bfloat16 nodes to produce float16 instead.
      for (auto& [name, const_attribute] : node->GetAttributes()) {
        auto& attribute = const_cast<ONNX_NAMESPACE::AttributeProto&>(const_attribute);
        if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
          if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
            attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
      }
    }
    // Retype every bfloat16 output so downstream consumers see float16.
    for (auto& output : node->OutputDefs()) {
      auto& output_proto = const_cast<ONNX_NAMESPACE::TypeProto&>(output->ToProto().type());
      if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
        output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
    }
  }

  // Convert bfloat16 initializer payloads in place. Both formats are 16 bits
  // wide, so the buffer can be rewritten without resizing.
  const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors();
  for (auto& [key, const_tensor_proto] : init_set) {
    auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(const_tensor_proto);
    if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16 &&
        tensor_proto->has_raw_data()) {
      tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
      std::int64_t size = 1;
      for (int i = 0; i < tensor_proto->dims_size(); ++i)
        size *= tensor_proto->dims()[i];
      // Go through memcpy instead of writing through a
      // reinterpret_cast<std::uint16_t*> of the char buffer: the cast violates
      // strict aliasing and assumes 2-byte alignment the protobuf string does
      // not guarantee. memcpy preserves the original native-endian semantics.
      char* raw = tensor_proto->mutable_raw_data()->data();
      for (std::int64_t i = 0; i < size; ++i) {
        std::uint16_t bits;
        std::memcpy(&bits, raw + i * sizeof(bits), sizeof(bits));
        bits = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(bits)).val;
        std::memcpy(raw + i * sizeof(bits), &bits, sizeof(bits));
      }
    }
  }
}
982+
983+
// Copies src_graph_viewer into a fresh Model (returned via `model`) and
// rewrites every bfloat16 tensor, Cast target and initializer in the copy to
// float16.
// Params:
//   src_graph_viewer - source graph; never modified.
//   logger           - logger used while cloning the model.
//   model            - out: receives the converted model copy.
// Returns the status of the model copy step; the conversion itself is infallible.
Status Transform(const GraphViewer& src_graph_viewer,
                 const logging::Logger& logger,
                 /*out*/ std::unique_ptr<onnxruntime::Model>& model) {
  auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model);
  // Bail out before touching `model`: if copy_model failed, `model` may be
  // empty and the unconditional dereference below would crash.
  if (!status.IsOK())
    return status;

  auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph());
  replace_bf16_with_fp16(g);
  return status;
}
992+
} // namespace bfloat16_fix
943993
} // namespace openvino_ep
944994
} // namespace onnxruntime

onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,10 @@ Status Transform(const GraphViewer& src_graph,
1515
const logging::Logger& logger,
1616
/*out*/ std::unique_ptr<onnxruntime::Model>& model);
1717
}
18+
namespace bfloat16_fix {
// Copies src_graph into `model` and rewrites all bfloat16 tensors, Cast
// targets and initializers in the copy to float16 (implemented in
// qdq_scales_fix.cpp). Returns the status of the underlying model copy.
Status Transform(const GraphViewer& src_graph,
                 const logging::Logger& logger,
                 /*out*/ std::unique_ptr<onnxruntime::Model>& model);
}
1823
} // namespace openvino_ep
1924
} // namespace onnxruntime
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <filesystem>
5+
#include <map>
6+
#include <string>
7+
8+
#include "core/session/onnxruntime_cxx_api.h"
9+
#include "core/framework/float16.h"
10+
11+
#include "test/util/include/test/test_environment.h"
12+
#include "test/optimizer/qdq_test_utils.h"
13+
14+
#include "gtest/gtest.h"
15+
#include "gmock/gmock.h"
16+
17+
using namespace ONNX_NAMESPACE;
18+
using namespace onnxruntime::logging;
19+
20+
extern std::unique_ptr<Ort::Env> ort_env;
21+
22+
class OVEP_BF16_Tests : public ::testing::TestWithParam<std::string> {};
23+
24+
namespace detail {
25+
auto ConstructModel() {
26+
using namespace onnxruntime;
27+
using namespace test;
28+
29+
std::unordered_map<std::string, int> domain_to_version;
30+
domain_to_version[kOnnxDomain] = 19;
31+
Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
32+
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
33+
34+
Graph& graph = model.MainGraph();
35+
ModelTestBuilder builder(graph);
36+
auto dim = 4;
37+
std::vector<float> input_data(dim, 1.0f);
38+
auto* input = builder.MakeInput<float>({dim}, input_data);
39+
builder.graph_.SetInputs({input});
40+
41+
auto* cast_to_bf16 = builder.MakeIntermediate();
42+
Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, "");
43+
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16));
44+
45+
std::vector<onnxruntime::BFloat16> weight_data(dim * dim);
46+
for (std::size_t i = 0; i < weight_data.size(); ++i)
47+
weight_data[i] = onnxruntime::BFloat16(static_cast<float>(i % dim) / dim);
48+
auto* weights = builder.MakeInitializer<onnxruntime::BFloat16>({dim, dim}, weight_data);
49+
50+
auto* matmul_out = builder.MakeIntermediate();
51+
builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out});
52+
53+
std::vector<onnxruntime::BFloat16> weight_data_2(dim * dim);
54+
for (std::size_t i = 0; i < weight_data_2.size(); ++i)
55+
weight_data_2[i] = onnxruntime::BFloat16(static_cast<float>(i % dim) / dim);
56+
auto* weights_2 = builder.MakeInitializer<onnxruntime::BFloat16>({dim, dim}, weight_data_2);
57+
58+
auto* matmul_out_2 = builder.MakeIntermediate();
59+
builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2});
60+
61+
auto* output = builder.MakeOutput();
62+
Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output});
63+
cast2_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
64+
65+
builder.SetGraphOutputs();
66+
auto st = model.MainGraph().Resolve();
67+
if (st != Status::OK())
68+
throw std::runtime_error(st.ErrorMessage());
69+
return model;
70+
}
71+
72+
auto ProbeDevice(const std::string& device) {
73+
static std::map<std::string, bool> is_present;
74+
if (is_present.find(device) == is_present.end()) {
75+
Ort::SessionOptions sessionOptions;
76+
std::unordered_map<std::string, std::string> ov_options;
77+
ov_options["device_type"] = device;
78+
try {
79+
sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options);
80+
is_present[device] = true;
81+
} catch (...) {
82+
is_present[device] = false;
83+
}
84+
}
85+
return is_present[device];
86+
}
87+
} // namespace detail
88+
89+
namespace onnxruntime {
90+
namespace test {
91+
92+
// Verifies that a model containing bfloat16 Casts/initializers can be loaded
// by the OpenVINO EP on each available device, i.e. the implicit
// bfloat16->float16 conversion pass produces a session-creatable model.
TEST_P(OVEP_BF16_Tests, TestModelConversion) {
  Ort::SessionOptions sessionOptions;
  std::unordered_map<std::string, std::string> ov_options;
  const auto& device = GetParam();
  if (!::detail::ProbeDevice(device))
    GTEST_SKIP() << device + " is not available on this machine";

  ov_options["device_type"] = device;
  auto model = ::detail::ConstructModel();
  sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options);

  std::string model_data;
  model.ToProto().SerializeToString(&model_data);
  auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
  try {
    Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions);
  } catch (const Ort::Exception& ex) {
    // Surface the ORT error text instead of a bare FAIL() so a regression in
    // the conversion pass is diagnosable from the test log.
    FAIL() << "Session creation failed: " << ex.what();
  } catch (...) {
    FAIL() << "Session creation failed with a non-ORT exception";
  }
}
INSTANTIATE_TEST_SUITE_P(OVEP_Tests,
                         OVEP_BF16_Tests,
                         ::testing::Values("CPU", "GPU", "NPU"));
115+
} // namespace test
116+
} // namespace onnxruntime

0 commit comments

Comments
 (0)