diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 2290030073e5c..571d26bc1903b 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float), tensor(float16), tensor(bfloat16)
+- T : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)
- Constrain input and output types to float or half tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 98dcb777422bc..f565c607c4b37 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -912,11 +912,11 @@ Do not modify directly.*
|DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)|
|EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
-|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
+|FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
index 8b8e4e267f895..3a16f16466ed3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -30,6 +30,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
+REGISTER_KERNEL_TYPED(double)
using namespace ONNX_NAMESPACE;
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 21bd5eb91c20f..cbe4d87dbf398 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -25,10 +25,14 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);
+
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
@@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
-class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
@@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise.h b/onnxruntime/contrib_ops/rocm/bert/elementwise.h
index dd78f47f48b9b..768295767835a 100644
--- a/onnxruntime/contrib_ops/rocm/bert/elementwise.h
+++ b/onnxruntime/contrib_ops/rocm/bert/elementwise.h
@@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp> {
}
ELEMENTWISE_FWD_DECL(FastGeLU, float);
+ELEMENTWISE_FWD_DECL(FastGeLU, double);
ELEMENTWISE_FWD_DECL(FastGeLU, half);
ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16);
ELEMENTWISE_FWD_DECL(GeLU, float);
+ELEMENTWISE_FWD_DECL(GeLU, double);
ELEMENTWISE_FWD_DECL(GeLU, half);
ELEMENTWISE_FWD_DECL(GeLU, BFloat16);
diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu
index 3e9534f459338..c2a670ea76aca 100644
--- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu
+++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_fastgelu.cu
@@ -4,5 +4,6 @@
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
+ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
diff --git a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu
index dafc524405eb9..97f0f74640c6e 100644
--- a/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu
+++ b/onnxruntime/contrib_ops/rocm/bert/elementwise_impl/impl_gelu.cu
@@ -3,6 +3,7 @@
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
+ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
index 4284b4254f485..7dbb24463961e 100644
--- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -11,10 +11,13 @@ namespace contrib {
namespace rocm {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
@@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
@@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo,
// BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index f2a2a52f8334f..d8f8d28a47621 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(0, "X", "input tensor", "T")
.Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
.Output(0, "Y", "output tensor", "T")
- .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
+ .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// fastgelu(x) =
diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc
index b90143e8f6121..9e72e6fbeb7b3 100644
--- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc
+++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc
@@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
const Node& next_node = (*next_node_itr);
- if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
+
+ bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain);
+ if (!(is_onnx_gelu ||
+ graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
@@ -72,6 +75,12 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
+ bool is_approximate = is_fast_gelu;
+ if (is_onnx_gelu) {
+ const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate");
+ is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh");
+ }
+
if (graph.NodeProducesGraphOutput(node)) {
continue;
}
@@ -79,7 +88,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& add_node = node;
Node& gelu_node = const_cast(next_node);
std::string op_type = "BiasGelu";
- if (is_fast_gelu) op_type = "FastGelu";
+ if (is_approximate) op_type = "FastGelu";
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
op_type,
diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.cc b/onnxruntime/core/providers/cuda/tensor/gelu.cc
index 67b2fad373a7f..69c16fc927430 100644
--- a/onnxruntime/core/providers/cuda/tensor/gelu.cc
+++ b/onnxruntime/core/providers/cuda/tensor/gelu.cc
@@ -23,6 +23,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)
template
@@ -80,6 +81,7 @@ namespace contrib::cuda {
REGISTER_CONTRIB_KERNEL_TYPED(float)
REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
+REGISTER_CONTRIB_KERNEL_TYPED(BFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(double)
#undef REGISTER_CONTRIB_KERNEL_TYPED
diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
index 3f96da38b37bb..9bf513b7de6f8 100644
--- a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
@@ -40,6 +40,7 @@ Status LaunchGeluKernel(
SPECIALIZED_GELU_IMPL(float);
SPECIALIZED_GELU_IMPL(half);
+SPECIALIZED_GELU_IMPL(BFloat16);
SPECIALIZED_GELU_IMPL(double);
#undef SPECIALIZED_GELU_IMPL
diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc
index a7d751f4472fc..497b8b5fd6cc7 100644
--- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc
+++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -389,7 +389,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
#ifdef USE_CUDA
- int min_cuda_architecture = 530;
+ int min_cuda_architecture = 800;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
@@ -440,5 +440,43 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
}
#endif
+// CUDA and ROCm only for double type.
+#if defined(USE_CUDA) || defined(USE_ROCM)
+TEST(FastGeluTest, FastGeluWithBias_Double) {
+ OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);
+
+ int batch_size = 1;
+ int sequence_length = 2;
+ int hidden_size = 4;
+
+ std::vector X = {
+ 0.8, -0.5, 0.0, 1.0,
+ 0.5, 0.2, 0.3, -0.6};
+
+ std::vector B = {
+ -0.5, 0.6, 1.2, 2.1};
+
+ std::vector Y = {
+ 0.185371, 0.053983, 1.061703, 3.097373,
+ 0.000000, 0.630432, 1.399572, 1.399572};
+
+ std::vector input_dims = {batch_size, sequence_length, hidden_size};
+ std::vector bias_dims = {hidden_size};
+ std::vector output_dims = input_dims;
+
+ tester.AddInput("X", input_dims, X);
+ tester.AddInput("bias", bias_dims, B);
+ tester.AddOutput("Y", output_dims, Y);
+
+ std::vector> execution_providers;
+#ifdef USE_CUDA
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+ execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+#endif
+
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index e069f6ef2432a..42ddd8cd37c81 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4781,6 +4781,46 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
}
+TEST_F(GraphTransformationTests, BiasOnnxGeluTest) {
+ constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_gelu_fusion.onnx";
+ std::shared_ptr p_model;
+ ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+ Graph& graph = p_model->MainGraph();
+
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+ const InlinedHashSet no_limit_empty_ep_list = {};
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+ std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2));
+ ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+ std::map op_to_count = CountOpsInGraph(graph);
+ ASSERT_TRUE(op_to_count["Add"] == 0);
+ ASSERT_TRUE(op_to_count["Gelu"] == 0);
+ ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 0);
+ ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
+}
+
+TEST_F(GraphTransformationTests, BiasOnnxFastGeluTest) {
+ constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_fast_gelu_fusion.onnx";
+ std::shared_ptr p_model;
+ ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+ Graph& graph = p_model->MainGraph();
+
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+ const InlinedHashSet no_limit_empty_ep_list = {};
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+ std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+ ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2));
+ ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+ std::map op_to_count = CountOpsInGraph(graph);
+ ASSERT_TRUE(op_to_count["Add"] == 0);
+ ASSERT_TRUE(op_to_count["Gelu"] == 0);
+ ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
+ ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 0);
+}
+
// BiasGelu allows input switching based on input dimensions.
// This test validates the input edges are plugged correct in the optimized graph.
TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx
index 5adf9e1de1a28..f6df398e43c5e 100644
Binary files a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx and b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py
index 96a52c73f40e6..710019ab24ec5 100644
--- a/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py
+++ b/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py
@@ -29,3 +29,40 @@
model = helper.make_model(graph)
onnx.save(model, r"bias_gelu_fusion.onnx")
+
+graph = helper.make_graph(
+ [
+ helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"),
+ helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu"),
+ ],
+ "Gelu_Add_Fusion", # name
+ [ # inputs
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
+ helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]),
+ ],
+ [ # outputs
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
+ ],
+)
+
+model = helper.make_model(graph)
+onnx.save(model, r"bias_onnx_gelu_fusion.onnx")
+
+
+graph = helper.make_graph(
+ [
+ helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"),
+ helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu", approximate="tanh"),
+ ],
+ "Gelu_Add_Fusion", # name
+ [ # inputs
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
+ helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]),
+ ],
+ [ # outputs
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
+ ],
+)
+
+model = helper.make_model(graph)
+onnx.save(model, r"bias_onnx_fast_gelu_fusion.onnx")
diff --git a/onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx
new file mode 100644
index 0000000000000..d449450d8c9f8
--- /dev/null
+++ b/onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx
@@ -0,0 +1,21 @@
+
+:½
+
+X
+Badd0_outadd0"Add
+1
+add0_outYgelu"Gelu*
+approximate"tanh Gelu_Add_FusionZ#
+X
+
+batch
+seqlen
+Z
+B
+
+b#
+Y
+
+batch
+seqlen
+B
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx
new file mode 100644
index 0000000000000..06c5c1ff7b423
--- /dev/null
+++ b/onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx
@@ -0,0 +1,20 @@
+
+:₯
+
+X
+Badd0_outadd0"Add
+
+add0_outYgelu"GeluGelu_Add_FusionZ#
+X
+
+batch
+seqlen
+Z
+B
+
+b#
+Y
+
+batch
+seqlen
+B
\ No newline at end of file