Skip to content

Commit 1d4576f

Browse files
author
Gaurav Shukla
authored
Revert "[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape" (#12)
This reverts commit e6877af.
1 parent e6877af commit 1d4576f

2 files changed

Lines changed: 27 additions & 85 deletions

File tree

onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp

Lines changed: 27 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,8 @@ Status NodeImporter::ImportAll() {
638638

639639
for (const auto &node : nodes) {
640640
if (torch_mlir_onnx::failed(ImportNode(node))) {
641-
return failure();
641+
return SetError("Failed to import node '" + node.name() +
642+
"': " + "(node:\n" + node.DebugString() + "\n)");
642643
}
643644
}
644645

@@ -727,8 +728,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
727728
if (found_it == nv_map_.end()) {
728729
std::string msg = "Non topologically produced ONNX node input '";
729730
msg.append(input_name);
730-
msg.append("': ");
731-
msg.append(node.DebugString());
731+
msg.append("'");
732732
return SetError(std::move(msg));
733733
}
734734
input_values.push_back(found_it->second);
@@ -739,9 +739,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
739739
for (auto &output_name : node.output()) {
740740
const onnx::TypeProto *type_proto =
741741
graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
742-
if (!type_proto) {
743-
return SetError("Failed to obtain TypeProto for tensor");
744-
}
742+
if (!type_proto)
743+
return failure();
745744

746745
MlirType t = cc_.ConvertTypeProto(*type_proto);
747746
if (mlirTypeIsNull(t))
@@ -907,83 +906,38 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
907906
return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
908907
/*encoding*/ {nullptr});
909908
};
910-
const bool has_raw_data = tensor_proto.has_raw_data();
911909
MlirAttribute splat_attr = {nullptr};
912-
size_t out_size;
913910
switch (tensor_proto.data_type()) {
914-
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
915-
const float *data = nullptr;
916-
if (has_raw_data) {
917-
data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
918-
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
919-
tensor_proto.DebugString());
920-
}
911+
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
921912
splat_attr = mlirDenseElementsAttrFloatSplatGet(
922-
tensorTypeFor(mlirF32TypeGet(context_)),
923-
has_raw_data ? data[0] : tensor_proto.float_data(0));
913+
tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
924914
break;
925-
}
926-
case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
927-
const int32_t *data = nullptr;
928-
if (has_raw_data) {
929-
data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
930-
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
931-
tensor_proto.DebugString());
932-
}
933-
splat_attr = mlirDenseElementsAttrInt32SplatGet(
915+
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
916+
splat_attr = mlirDenseElementsAttrFloatSplatGet(
934917
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
935-
has_raw_data ? data[0] : tensor_proto.int32_data(0));
918+
tensor_proto.int32_data(0));
936919
break;
937-
}
938-
case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
939-
const int64_t *data = nullptr;
940-
if (has_raw_data) {
941-
data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
942-
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
943-
tensor_proto.DebugString());
944-
}
945-
splat_attr = mlirDenseElementsAttrInt64SplatGet(
920+
case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
921+
splat_attr = mlirDenseElementsAttrFloatSplatGet(
946922
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
947-
has_raw_data ? data[0] : tensor_proto.int64_data(0));
923+
tensor_proto.int64_data(0));
948924
break;
949-
}
950-
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
951-
const double *data = nullptr;
952-
if (has_raw_data) {
953-
data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
954-
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
955-
tensor_proto.DebugString());
956-
}
957-
splat_attr = mlirDenseElementsAttrDoubleSplatGet(
958-
tensorTypeFor(mlirF64TypeGet(context_)),
959-
has_raw_data ? data[0] : tensor_proto.double_data(0));
925+
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
926+
splat_attr = mlirDenseElementsAttrFloatSplatGet(
927+
tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
960928
break;
961-
}
962-
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
963-
const uint64_t *data = nullptr;
964-
if (has_raw_data) {
965-
data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
966-
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
967-
tensor_proto.DebugString());
968-
}
969-
splat_attr = mlirDenseElementsAttrUInt64SplatGet(
929+
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
930+
splat_attr = mlirDenseElementsAttrFloatSplatGet(
970931
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
971-
has_raw_data ? data[0] : tensor_proto.uint64_data(0));
932+
tensor_proto.uint64_data(0));
972933
break;
973-
}
974-
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
975-
const uint32_t *data = nullptr;
976-
if (has_raw_data) {
977-
data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
978-
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
979-
tensor_proto.DebugString());
980-
}
981-
splat_attr = mlirDenseElementsAttrUInt32SplatGet(
934+
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
935+
// Special case: inline data is stored in uint64.
936+
splat_attr = mlirDenseElementsAttrFloatSplatGet(
982937
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
983-
has_raw_data ? data[0] : tensor_proto.float_data(0));
938+
tensor_proto.uint64_data(0));
984939
break;
985940
}
986-
}
987941

988942
if (mlirAttributeIsNull(splat_attr)) {
989943
std::string message =
@@ -1004,7 +958,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
1004958
toMlirNamedAttribute("value", splat_attr));
1005959
MlirValue result = mlirOperationGetResult(op, 0);
1006960

1007-
auto inserted = nv_map_.emplace(node.output(0), result);
961+
// Export to the nv_map.
962+
auto inserted = nv_map_.insert(std::make_pair(name, result));
1008963
if (!inserted.second) {
1009964
std::string msg = "Multiple nodes produced a value for '";
1010965
msg.append(name);
@@ -1018,17 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
1018973

1019974
Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
1020975
std::vector<int64_t> &shape) {
1021-
const onnx::TensorProto *tensor =
1022-
graph_info_.graph_viewer().GetConstantInitializer(name, false);
1023-
if (!tensor) {
1024-
std::string msg = "Could not find the immediate shape tensor ";
1025-
msg.append(name);
1026-
msg.append(" in constant graph initializers. It was possibly produced "
1027-
"dynamically.");
1028-
return SetError(msg);
1029-
}
1030-
const onnx::TensorProto &tp = *tensor;
1031-
976+
const onnx::TensorProto &tp =
977+
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
1032978
shape.clear();
1033979

1034980
// Since this is being interpreted as a shape, we only support some limited

onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,6 @@ class GraphInfo {
9393
return nullptr;
9494
}
9595

96-
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
97-
value_info_map() {
98-
return value_info_map_;
99-
}
10096
std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
10197
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
10298
input_map() {

0 commit comments

Comments
 (0)