Skip to content

Commit e6877af

Browse files
authored
[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape
1 parent c45ebb0 commit e6877af

2 files changed

Lines changed: 85 additions & 27 deletions

File tree

onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -638,8 +638,7 @@ Status NodeImporter::ImportAll() {
638638

639639
for (const auto &node : nodes) {
640640
if (torch_mlir_onnx::failed(ImportNode(node))) {
641-
return SetError("Failed to import node '" + node.name() +
642-
"': " + "(node:\n" + node.DebugString() + "\n)");
641+
return failure();
643642
}
644643
}
645644

@@ -728,7 +727,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
728727
if (found_it == nv_map_.end()) {
729728
std::string msg = "Non topologically produced ONNX node input '";
730729
msg.append(input_name);
731-
msg.append("'");
730+
msg.append("': ");
731+
msg.append(node.DebugString());
732732
return SetError(std::move(msg));
733733
}
734734
input_values.push_back(found_it->second);
@@ -739,8 +739,9 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
739739
for (auto &output_name : node.output()) {
740740
const onnx::TypeProto *type_proto =
741741
graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
742-
if (!type_proto)
743-
return failure();
742+
if (!type_proto) {
743+
return SetError("Failed to obtain TypeProto for tensor");
744+
}
744745

745746
MlirType t = cc_.ConvertTypeProto(*type_proto);
746747
if (mlirTypeIsNull(t))
@@ -906,38 +907,83 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
906907
return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
907908
/*encoding*/ {nullptr});
908909
};
910+
const bool has_raw_data = tensor_proto.has_raw_data();
909911
MlirAttribute splat_attr = {nullptr};
912+
size_t out_size;
910913
switch (tensor_proto.data_type()) {
911-
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
914+
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
915+
const float *data = nullptr;
916+
if (has_raw_data) {
917+
data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
918+
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
919+
tensor_proto.DebugString());
920+
}
912921
splat_attr = mlirDenseElementsAttrFloatSplatGet(
913-
tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
922+
tensorTypeFor(mlirF32TypeGet(context_)),
923+
has_raw_data ? data[0] : tensor_proto.float_data(0));
914924
break;
915-
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
916-
splat_attr = mlirDenseElementsAttrFloatSplatGet(
925+
}
926+
case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
927+
const int32_t *data = nullptr;
928+
if (has_raw_data) {
929+
data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
930+
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
931+
tensor_proto.DebugString());
932+
}
933+
splat_attr = mlirDenseElementsAttrInt32SplatGet(
917934
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
918-
tensor_proto.int32_data(0));
935+
has_raw_data ? data[0] : tensor_proto.int32_data(0));
919936
break;
920-
case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
921-
splat_attr = mlirDenseElementsAttrFloatSplatGet(
937+
}
938+
case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
939+
const int64_t *data = nullptr;
940+
if (has_raw_data) {
941+
data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
942+
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
943+
tensor_proto.DebugString());
944+
}
945+
splat_attr = mlirDenseElementsAttrInt64SplatGet(
922946
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
923-
tensor_proto.int64_data(0));
947+
has_raw_data ? data[0] : tensor_proto.int64_data(0));
924948
break;
925-
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
926-
splat_attr = mlirDenseElementsAttrFloatSplatGet(
927-
tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
949+
}
950+
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
951+
const double *data = nullptr;
952+
if (has_raw_data) {
953+
data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
954+
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
955+
tensor_proto.DebugString());
956+
}
957+
splat_attr = mlirDenseElementsAttrDoubleSplatGet(
958+
tensorTypeFor(mlirF64TypeGet(context_)),
959+
has_raw_data ? data[0] : tensor_proto.double_data(0));
928960
break;
929-
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
930-
splat_attr = mlirDenseElementsAttrFloatSplatGet(
961+
}
962+
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
963+
const uint64_t *data = nullptr;
964+
if (has_raw_data) {
965+
data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
966+
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
967+
tensor_proto.DebugString());
968+
}
969+
splat_attr = mlirDenseElementsAttrUInt64SplatGet(
931970
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
932-
tensor_proto.uint64_data(0));
971+
has_raw_data ? data[0] : tensor_proto.uint64_data(0));
933972
break;
934-
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
935-
// Special case: inline data is stored in uint64.
936-
splat_attr = mlirDenseElementsAttrFloatSplatGet(
973+
}
974+
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
975+
const uint32_t *data = nullptr;
976+
if (has_raw_data) {
977+
data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
978+
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
979+
tensor_proto.DebugString());
980+
}
981+
splat_attr = mlirDenseElementsAttrUInt32SplatGet(
937982
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
938-
tensor_proto.uint64_data(0));
983+
has_raw_data ? data[0] : tensor_proto.float_data(0));
939984
break;
940985
}
986+
}
941987

942988
if (mlirAttributeIsNull(splat_attr)) {
943989
std::string message =
@@ -958,8 +1004,7 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
9581004
toMlirNamedAttribute("value", splat_attr));
9591005
MlirValue result = mlirOperationGetResult(op, 0);
9601006

961-
// Export to the nv_map.
962-
auto inserted = nv_map_.insert(std::make_pair(name, result));
1007+
auto inserted = nv_map_.emplace(node.output(0), result);
9631008
if (!inserted.second) {
9641009
std::string msg = "Multiple nodes produced a value for '";
9651010
msg.append(name);
@@ -973,8 +1018,17 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
9731018

9741019
Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
9751020
std::vector<int64_t> &shape) {
976-
const onnx::TensorProto &tp =
977-
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
1021+
const onnx::TensorProto *tensor =
1022+
graph_info_.graph_viewer().GetConstantInitializer(name, false);
1023+
if (!tensor) {
1024+
std::string msg = "Could not find the immediate shape tensor ";
1025+
msg.append(name);
1026+
msg.append(" in constant graph initializers. It was possibly produced "
1027+
"dynamically.");
1028+
return SetError(msg);
1029+
}
1030+
const onnx::TensorProto &tp = *tensor;
1031+
9781032
shape.clear();
9791033

9801034
// Since this is being interpreted as a shape, we only support some limited

onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ class GraphInfo {
9393
return nullptr;
9494
}
9595

96+
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
97+
value_info_map() {
98+
return value_info_map_;
99+
}
96100
std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
97101
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
98102
input_map() {

0 commit comments

Comments
 (0)