@@ -638,8 +638,7 @@ Status NodeImporter::ImportAll() {
638638
639639 for (const auto &node : nodes) {
640640 if (torch_mlir_onnx::failed (ImportNode (node))) {
641- return SetError (" Failed to import node '" + node.name () +
642- " ': " + " (node:\n " + node.DebugString () + " \n )" );
641+ return failure ();
643642 }
644643 }
645644
@@ -728,7 +727,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
728727 if (found_it == nv_map_.end ()) {
729728 std::string msg = " Non topologically produced ONNX node input '" ;
730729 msg.append (input_name);
731- msg.append (" '" );
730+ msg.append (" ': " );
731+ msg.append (node.DebugString ());
732732 return SetError (std::move (msg));
733733 }
734734 input_values.push_back (found_it->second );
@@ -739,8 +739,9 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
739739 for (auto &output_name : node.output ()) {
740740 const onnx::TypeProto *type_proto =
741741 graph_info_.graph_viewer ().GetNodeArg (output_name)->TypeAsProto ();
742- if (!type_proto)
743- return failure ();
742+ if (!type_proto) {
743+ return SetError (" Failed to obtain TypeProto for tensor" );
744+ }
744745
745746 MlirType t = cc_.ConvertTypeProto (*type_proto);
746747 if (mlirTypeIsNull (t))
@@ -906,38 +907,83 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
906907 return mlirRankedTensorTypeGet (shape.size (), shape.data (), element_type,
907908 /* encoding*/ {nullptr });
908909 };
910+ const bool has_raw_data = tensor_proto.has_raw_data ();
909911 MlirAttribute splat_attr = {nullptr };
912+ size_t out_size;
910913 switch (tensor_proto.data_type ()) {
911- case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
914+ case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
915+ const float *data = nullptr ;
916+ if (has_raw_data) {
917+ data = graph_info_.GetOptionalRawData <float >(tensor_proto, out_size);
918+ ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
919+ tensor_proto.DebugString ());
920+ }
912921 splat_attr = mlirDenseElementsAttrFloatSplatGet (
913- tensorTypeFor (mlirF32TypeGet (context_)), tensor_proto.float_data (0 ));
922+ tensorTypeFor (mlirF32TypeGet (context_)),
923+ has_raw_data ? data[0 ] : tensor_proto.float_data (0 ));
914924 break ;
915- case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
916- splat_attr = mlirDenseElementsAttrFloatSplatGet (
925+ }
926+ case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
927+ const int32_t *data = nullptr ;
928+ if (has_raw_data) {
929+ data = graph_info_.GetOptionalRawData <int32_t >(tensor_proto, out_size);
930+ ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
931+ tensor_proto.DebugString ());
932+ }
933+ splat_attr = mlirDenseElementsAttrInt32SplatGet (
917934 tensorTypeFor (mlirIntegerTypeSignedGet (context_, 32 )),
918- tensor_proto.int32_data (0 ));
935+ has_raw_data ? data[ 0 ] : tensor_proto.int32_data (0 ));
919936 break ;
920- case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
921- splat_attr = mlirDenseElementsAttrFloatSplatGet (
937+ }
938+ case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
939+ const int64_t *data = nullptr ;
940+ if (has_raw_data) {
941+ data = graph_info_.GetOptionalRawData <int64_t >(tensor_proto, out_size);
942+ ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
943+ tensor_proto.DebugString ());
944+ }
945+ splat_attr = mlirDenseElementsAttrInt64SplatGet (
922946 tensorTypeFor (mlirIntegerTypeSignedGet (context_, 64 )),
923- tensor_proto.int64_data (0 ));
947+ has_raw_data ? data[ 0 ] : tensor_proto.int64_data (0 ));
924948 break ;
925- case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
926- splat_attr = mlirDenseElementsAttrFloatSplatGet (
927- tensorTypeFor (mlirF64TypeGet (context_)), tensor_proto.double_data (0 ));
949+ }
950+ case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
951+ const double *data = nullptr ;
952+ if (has_raw_data) {
953+ data = graph_info_.GetOptionalRawData <double >(tensor_proto, out_size);
954+ ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
955+ tensor_proto.DebugString ());
956+ }
957+ splat_attr = mlirDenseElementsAttrDoubleSplatGet (
958+ tensorTypeFor (mlirF64TypeGet (context_)),
959+ has_raw_data ? data[0 ] : tensor_proto.double_data (0 ));
928960 break ;
929- case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
930- splat_attr = mlirDenseElementsAttrFloatSplatGet (
961+ }
962+ case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
963+ const uint64_t *data = nullptr ;
964+ if (has_raw_data) {
965+ data = graph_info_.GetOptionalRawData <uint64_t >(tensor_proto, out_size);
966+ ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
967+ tensor_proto.DebugString ());
968+ }
969+ splat_attr = mlirDenseElementsAttrUInt64SplatGet (
931970 tensorTypeFor (mlirIntegerTypeUnsignedGet (context_, 64 )),
932- tensor_proto.uint64_data (0 ));
971+ has_raw_data ? data[ 0 ] : tensor_proto.uint64_data (0 ));
933972 break ;
934- case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
935- // Special case: inline data is stored in uint64.
936- splat_attr = mlirDenseElementsAttrFloatSplatGet (
973+ }
974+ case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
975+ const uint32_t *data = nullptr ;
976+ if (has_raw_data) {
977+ data = graph_info_.GetOptionalRawData <uint32_t >(tensor_proto, out_size);
978+ ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
979+ tensor_proto.DebugString ());
980+ }
981+ splat_attr = mlirDenseElementsAttrUInt32SplatGet (
937982 tensorTypeFor (mlirIntegerTypeUnsignedGet (context_, 32 )),
938- tensor_proto.uint64_data (0 ));
983+ has_raw_data ? data[ 0 ] : tensor_proto.float_data (0 ));
939984 break ;
940985 }
986+ }
941987
942988 if (mlirAttributeIsNull (splat_attr)) {
943989 std::string message =
@@ -958,8 +1004,7 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
9581004 toMlirNamedAttribute (" value" , splat_attr));
9591005 MlirValue result = mlirOperationGetResult (op, 0 );
9601006
961- // Export to the nv_map.
962- auto inserted = nv_map_.insert (std::make_pair (name, result));
1007+ auto inserted = nv_map_.emplace (node.output (0 ), result);
9631008 if (!inserted.second ) {
9641009 std::string msg = " Multiple nodes produced a value for '" ;
9651010 msg.append (name);
@@ -973,8 +1018,17 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
9731018
9741019Status NodeImporter::GetImmediateShapeTensor (const std::string &name,
9751020 std::vector<int64_t > &shape) {
976- const onnx::TensorProto &tp =
977- *graph_info_.graph_viewer ().GetConstantInitializer (name, false );
1021+ const onnx::TensorProto *tensor =
1022+ graph_info_.graph_viewer ().GetConstantInitializer (name, false );
1023+ if (!tensor) {
1024+ std::string msg = " Could not find the immediate shape tensor " ;
1025+ msg.append (name);
1026+ msg.append (" in constant graph initializers. It was possibly produced "
1027+ " dynamically." );
1028+ return SetError (msg);
1029+ }
1030+ const onnx::TensorProto &tp = *tensor;
1031+
9781032 shape.clear ();
9791033
9801034 // Since this is being interpreted as a shape, we only support some limited
0 commit comments