@@ -638,7 +638,8 @@ Status NodeImporter::ImportAll() {
638638
639639 for (const auto &node : nodes) {
640640 if (torch_mlir_onnx::failed (ImportNode (node))) {
641- return failure ();
641+ return SetError (" Failed to import node '" + node.name () +
642+ " ': " + " (node:\n " + node.DebugString () + " \n )" );
642643 }
643644 }
644645
@@ -727,8 +728,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
727728 if (found_it == nv_map_.end ()) {
728729 std::string msg = " Non topologically produced ONNX node input '" ;
729730 msg.append (input_name);
730- msg.append (" ': " );
731- msg.append (node.DebugString ());
731+ msg.append (" '" );
732732 return SetError (std::move (msg));
733733 }
734734 input_values.push_back (found_it->second );
@@ -739,9 +739,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
739739 for (auto &output_name : node.output ()) {
740740 const onnx::TypeProto *type_proto =
741741 graph_info_.graph_viewer ().GetNodeArg (output_name)->TypeAsProto ();
742- if (!type_proto) {
743- return SetError (" Failed to obtain TypeProto for tensor" );
744- }
742+ if (!type_proto)
743+ return failure ();
745744
746745 MlirType t = cc_.ConvertTypeProto (*type_proto);
747746 if (mlirTypeIsNull (t))
@@ -907,83 +906,38 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
907906 return mlirRankedTensorTypeGet (shape.size (), shape.data (), element_type,
908907 /* encoding*/ {nullptr });
909908 };
910- const bool has_raw_data = tensor_proto.has_raw_data ();
911909 MlirAttribute splat_attr = {nullptr };
912- size_t out_size;
913910 switch (tensor_proto.data_type ()) {
914- case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
915- const float *data = nullptr ;
916- if (has_raw_data) {
917- data = graph_info_.GetOptionalRawData <float >(tensor_proto, out_size);
918- ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
919- tensor_proto.DebugString ());
920- }
911+ case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
921912 splat_attr = mlirDenseElementsAttrFloatSplatGet (
922- tensorTypeFor (mlirF32TypeGet (context_)),
923- has_raw_data ? data[0 ] : tensor_proto.float_data (0 ));
913+ tensorTypeFor (mlirF32TypeGet (context_)), tensor_proto.float_data (0 ));
924914 break ;
925- }
926- case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
927- const int32_t *data = nullptr ;
928- if (has_raw_data) {
929- data = graph_info_.GetOptionalRawData <int32_t >(tensor_proto, out_size);
930- ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
931- tensor_proto.DebugString ());
932- }
933- splat_attr = mlirDenseElementsAttrInt32SplatGet (
915+ case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
916+ splat_attr = mlirDenseElementsAttrFloatSplatGet (
934917 tensorTypeFor (mlirIntegerTypeSignedGet (context_, 32 )),
935- has_raw_data ? data[ 0 ] : tensor_proto.int32_data (0 ));
918+ tensor_proto.int32_data (0 ));
936919 break ;
937- }
938- case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
939- const int64_t *data = nullptr ;
940- if (has_raw_data) {
941- data = graph_info_.GetOptionalRawData <int64_t >(tensor_proto, out_size);
942- ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
943- tensor_proto.DebugString ());
944- }
945- splat_attr = mlirDenseElementsAttrInt64SplatGet (
920+ case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
921+ splat_attr = mlirDenseElementsAttrFloatSplatGet (
946922 tensorTypeFor (mlirIntegerTypeSignedGet (context_, 64 )),
947- has_raw_data ? data[ 0 ] : tensor_proto.int64_data (0 ));
923+ tensor_proto.int64_data (0 ));
948924 break ;
949- }
950- case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
951- const double *data = nullptr ;
952- if (has_raw_data) {
953- data = graph_info_.GetOptionalRawData <double >(tensor_proto, out_size);
954- ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
955- tensor_proto.DebugString ());
956- }
957- splat_attr = mlirDenseElementsAttrDoubleSplatGet (
958- tensorTypeFor (mlirF64TypeGet (context_)),
959- has_raw_data ? data[0 ] : tensor_proto.double_data (0 ));
925+ case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
926+ splat_attr = mlirDenseElementsAttrFloatSplatGet (
927+ tensorTypeFor (mlirF64TypeGet (context_)), tensor_proto.double_data (0 ));
960928 break ;
961- }
962- case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
963- const uint64_t *data = nullptr ;
964- if (has_raw_data) {
965- data = graph_info_.GetOptionalRawData <uint64_t >(tensor_proto, out_size);
966- ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
967- tensor_proto.DebugString ());
968- }
969- splat_attr = mlirDenseElementsAttrUInt64SplatGet (
929+ case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
930+ splat_attr = mlirDenseElementsAttrFloatSplatGet (
970931 tensorTypeFor (mlirIntegerTypeUnsignedGet (context_, 64 )),
971- has_raw_data ? data[ 0 ] : tensor_proto.uint64_data (0 ));
932+ tensor_proto.uint64_data (0 ));
972933 break ;
973- }
974- case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
975- const uint32_t *data = nullptr ;
976- if (has_raw_data) {
977- data = graph_info_.GetOptionalRawData <uint32_t >(tensor_proto, out_size);
978- ORT_ENFORCE (data, " GetOptionalRawData() returned null for tensor proto: " ,
979- tensor_proto.DebugString ());
980- }
981- splat_attr = mlirDenseElementsAttrUInt32SplatGet (
934+ case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
935+ // Special case: inline data is stored in uint64.
936+ splat_attr = mlirDenseElementsAttrFloatSplatGet (
982937 tensorTypeFor (mlirIntegerTypeUnsignedGet (context_, 32 )),
983- has_raw_data ? data[ 0 ] : tensor_proto.float_data (0 ));
938+ tensor_proto.uint64_data (0 ));
984939 break ;
985940 }
986- }
987941
988942 if (mlirAttributeIsNull (splat_attr)) {
989943 std::string message =
@@ -1004,7 +958,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
1004958 toMlirNamedAttribute (" value" , splat_attr));
1005959 MlirValue result = mlirOperationGetResult (op, 0 );
1006960
1007- auto inserted = nv_map_.emplace (node.output (0 ), result);
961+ // Export to the nv_map.
962+ auto inserted = nv_map_.insert (std::make_pair (name, result));
1008963 if (!inserted.second ) {
1009964 std::string msg = " Multiple nodes produced a value for '" ;
1010965 msg.append (name);
@@ -1018,17 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
1018973
1019974Status NodeImporter::GetImmediateShapeTensor (const std::string &name,
1020975 std::vector<int64_t > &shape) {
1021- const onnx::TensorProto *tensor =
1022- graph_info_.graph_viewer ().GetConstantInitializer (name, false );
1023- if (!tensor) {
1024- std::string msg = " Could not find the immediate shape tensor " ;
1025- msg.append (name);
1026- msg.append (" in constant graph initializers. It was possibly produced "
1027- " dynamically." );
1028- return SetError (msg);
1029- }
1030- const onnx::TensorProto &tp = *tensor;
1031-
976+ const onnx::TensorProto &tp =
977+ *graph_info_.graph_viewer ().GetConstantInitializer (name, false );
1032978 shape.clear ();
1033979
1034980 // Since this is being interpreted as a shape, we only support some limited
0 commit comments