diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 2290030073e5c..571d26bc1903b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16), tensor(bfloat16)
+
T : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)
Constrain input and output types to float or half tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 98dcb777422bc..f565c607c4b37 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -912,11 +912,11 @@ Do not modify directly.* |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| -|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| +|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 8b8e4e267f895..3a16f16466ed3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -30,6 +30,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) REGISTER_KERNEL_TYPED(BFloat16) +REGISTER_KERNEL_TYPED(double) using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 21bd5eb91c20f..cbe4d87dbf398 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -25,10 +25,14 @@ namespace onnxruntime { namespace contrib { namespace cuda { class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample); + class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu); class CUDA_MS_OP_CLASS_NAME(1, BiasGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu); @@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv); -class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul); class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul); @@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h index dd78f47f48b9b..768295767835a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise.h +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise.h @@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp> { } ELEMENTWISE_FWD_DECL(FastGeLU, float); +ELEMENTWISE_FWD_DECL(FastGeLU, double); ELEMENTWISE_FWD_DECL(FastGeLU, half); ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16); ELEMENTWISE_FWD_DECL(GeLU, float); +ELEMENTWISE_FWD_DECL(GeLU, double); ELEMENTWISE_FWD_DECL(GeLU, half); ELEMENTWISE_FWD_DECL(GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu index 3e9534f459338..c2a670ea76aca 100644 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu @@ -4,5 +4,6 @@ #include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float); +ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double); ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half); ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu index dafc524405eb9..97f0f74640c6e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu +++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu @@ -3,6 +3,7 @@ #include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh" +ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double); ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float); ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half); ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16); diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 4284b4254f485..7dbb24463961e 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -11,10 +11,13 @@ namespace contrib { namespace rocm { class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); @@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); @@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f2a2a52f8334f..d8f8d28a47621 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(0, "X", "input tensor", "T") .Input(1, "bias", "bias tensor", "T", OpSchema::Optional) .Output(0, "Y", "output tensor", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput) .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) { // fastgelu(x) = diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc index b90143e8f6121..9e72e6fbeb7b3 100644 --- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc @@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } const Node& next_node = (*next_node_itr); - if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) || + + bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain); + if (!(is_onnx_gelu || + graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) || graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; @@ -72,6 +75,12 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } + bool is_approximate = is_fast_gelu; + if (is_onnx_gelu) { + const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate"); + is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh"); + } + if (graph.NodeProducesGraphOutput(node)) { continue; } @@ -79,7 +88,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& add_node = node; Node& gelu_node = const_cast(next_node); std::string op_type = "BiasGelu"; - if (is_fast_gelu) op_type = "FastGelu"; + if (is_approximate) op_type = "FastGelu"; Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type), op_type, diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.cc b/onnxruntime/core/providers/cuda/tensor/gelu.cc index 67b2fad373a7f..69c16fc927430 100644 --- a/onnxruntime/core/providers/cuda/tensor/gelu.cc +++ b/onnxruntime/core/providers/cuda/tensor/gelu.cc @@ -23,6 +23,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) REGISTER_KERNEL_TYPED(double) template @@ -80,6 +81,7 @@ namespace contrib::cuda { REGISTER_CONTRIB_KERNEL_TYPED(float) REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16) +REGISTER_CONTRIB_KERNEL_TYPED(BFloat16) REGISTER_CONTRIB_KERNEL_TYPED(double) #undef REGISTER_CONTRIB_KERNEL_TYPED diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu index 3f96da38b37bb..9bf513b7de6f8 100644 --- a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu @@ -40,6 +40,7 @@ Status LaunchGeluKernel( SPECIALIZED_GELU_IMPL(float); SPECIALIZED_GELU_IMPL(half); +SPECIALIZED_GELU_IMPL(BFloat16); SPECIALIZED_GELU_IMPL(double); #undef SPECIALIZED_GELU_IMPL diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index a7d751f4472fc..497b8b5fd6cc7 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -389,7 +389,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { #if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA - int min_cuda_architecture = 530; + int min_cuda_architecture = 800; if (!HasCudaEnvironment(min_cuda_architecture)) { LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; return; @@ -440,5 +440,43 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) { } #endif +// CUDA and ROCm only for double type. +#if defined(USE_CUDA) || defined(USE_ROCM) +TEST(FastGeluTest, FastGeluWithBias_Double) { + OpTester tester("FastGelu", 1, onnxruntime::kMSDomain); + + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + + std::vector X = { + 0.8, -0.5, 0.0, 1.0, + 0.5, 0.2, 0.3, -0.6}; + + std::vector B = { + -0.5, 0.6, 1.2, 2.1}; + + std::vector Y = { + 0.185371, 0.053983, 1.061703, 3.097373, + 0.000000, 0.630432, 1.399572, 1.399572}; + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector bias_dims = {hidden_size}; + std::vector output_dims = input_dims; + + tester.AddInput("X", input_dims, X); + tester.AddInput("bias", bias_dims, B); + tester.AddOutput("Y", output_dims, Y); + + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e069f6ef2432a..42ddd8cd37c81 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4781,6 +4781,46 @@ TEST_F(GraphTransformationTests, BiasGeluTest) { ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1); } +TEST_F(GraphTransformationTests, BiasOnnxGeluTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_gelu_fusion.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Gelu"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1); +} + +TEST_F(GraphTransformationTests, BiasOnnxFastGeluTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_fast_gelu_fusion.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + const InlinedHashSet no_limit_empty_ep_list = {}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Gelu"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 0); +} + // BiasGelu allows input switching based on input dimensions. // This test validates the input edges are plugged correct in the optimized graph. TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx index 5adf9e1de1a28..f6df398e43c5e 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx and b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py index 96a52c73f40e6..710019ab24ec5 100644 --- a/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py @@ -29,3 +29,40 @@ model = helper.make_model(graph) onnx.save(model, r"bias_gelu_fusion.onnx") + +graph = helper.make_graph( + [ + helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"), + helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu"), + ], + "Gelu_Add_Fusion", # name + [ # inputs + helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]), + ], + [ # outputs + helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]), + ], +) + +model = helper.make_model(graph) +onnx.save(model, r"bias_onnx_gelu_fusion.onnx") + + +graph = helper.make_graph( + [ + helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"), + helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu", approximate="tanh"), + ], + "Gelu_Add_Fusion", # name + [ # inputs + helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]), + ], + [ # outputs + helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]), + ], +) + +model = helper.make_model(graph) +onnx.save(model, r"bias_onnx_fast_gelu_fusion.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx new file mode 100644 index 0000000000000..d449450d8c9f8 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx @@ -0,0 +1,21 @@ + +:½ + +X +Badd0_outadd0"Add +1 +add0_outYgelu"Gelu* + approximate"tanh Gelu_Add_FusionZ# +X + +batch +seqlen +€Z +B +  +€b# +Y + +batch +seqlen +€B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx new file mode 100644 index 0000000000000..06c5c1ff7b423 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx @@ -0,0 +1,20 @@ + +:₯ + +X +Badd0_outadd0"Add + +add0_outYgelu"GeluGelu_Add_FusionZ# +X + +batch +seqlen +€Z +B +  +€b# +Y + +batch +seqlen +€B \ No newline at end of file