|
4 | 4 | #include "qdq_scales_fix.h" |
5 | 5 | #include "core/providers/openvino/ov_protobuf_utils.h" |
6 | 6 | #include "core/framework/ort_value.h" |
7 | | -#include "core/common/float16.h" |
8 | 7 |
|
9 | 8 | #include <fstream> |
10 | 9 | #include <list> |
@@ -955,60 +954,5 @@ Status Transform(const GraphViewer& src_graph_viewer, |
955 | 954 | return status; |
956 | 955 | } |
957 | 956 | } // namespace qdq_scales_fix |
958 | | - |
959 | | -namespace bfloat16_fix { |
960 | | -void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { |
961 | | - for (auto& const_node : gen_graph.original_graph.Nodes()) { |
962 | | - auto node = const_cast<ONNX_NAMESPACE::Node*>(const_node); |
963 | | - if (node->OpType() == "Cast") { |
964 | | - for (auto& [name, const_attribute] : node->GetAttributes()) { |
965 | | - auto& attribute = const_cast<ONNX_NAMESPACE::AttributeProto&>(const_attribute); |
966 | | - if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT) |
967 | | - if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) |
968 | | - attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); |
969 | | - } |
970 | | - } |
971 | | - for (auto& output : node->OutputDefs()) { |
972 | | - auto& output_proto = const_cast<ONNX_NAMESPACE::TypeProto&>(output->ToProto().type()); |
973 | | - if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) |
974 | | - output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); |
975 | | - } |
976 | | - } |
977 | | - |
978 | | - for (auto& node : gen_graph.original_graph.Nodes()) { |
979 | | - for (auto& input_def : node->InputDefs()) { |
980 | | - ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(gen_graph.original_graph, input_def->Name())); |
981 | | - } |
982 | | - } |
983 | | - |
984 | | - const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors(); |
985 | | - for (auto& [key, const_tensor_proto] : init_set) { |
986 | | - auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(const_tensor_proto); |
987 | | - auto dt = tensor_proto->data_type(); |
988 | | - if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { |
989 | | - auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast<std::uint16_t*>(tensor_proto->mutable_raw_data()->data()) : nullptr; |
990 | | - if (raw_data) { |
991 | | - tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); |
992 | | - std::int64_t size = 1; |
993 | | - for (int i = 0; i < tensor_proto->dims_size(); ++i) |
994 | | - size *= tensor_proto->dims()[i]; |
995 | | - for (std::int64_t i = 0; i < size; ++i) { |
996 | | - raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val; |
997 | | - } |
998 | | - } |
999 | | - } |
1000 | | - } |
1001 | | -} |
1002 | | - |
1003 | | -Status Transform(const GraphViewer& src_graph_viewer, |
1004 | | - const logging::Logger& logger, |
1005 | | - /*out*/ std::unique_ptr<onnxruntime::Model>& model) { |
1006 | | - auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model); |
1007 | | - auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph()); |
1008 | | - |
1009 | | - replace_bf16_with_fp16(g); |
1010 | | - return status; |
1011 | | -} |
1012 | | -} // namespace bfloat16_fix |
1013 | 957 | } // namespace openvino_ep |
1014 | 958 | } // namespace onnxruntime |
0 commit comments