Update BiasGelu fusion and related ops #23518

Merged: 3 commits, Jan 31, 2025
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
-<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float or half tensors.</dd>
</dl>

4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -912,11 +912,11 @@ Do not modify directly.*
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|DynamicTimeWarping|*in* input:**F**<br> *out* output:**I**|1+|**F** = tensor(float)<br/> **I** = tensor(int32)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
-|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
+|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -30,6 +30,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
+REGISTER_KERNEL_TYPED(double)

using namespace ONNX_NAMESPACE;

9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -25,10 +25,14 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);
+
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
@@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
-class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
@@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
+BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
+BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
+BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
@@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
-BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/elementwise.h
@@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp<ElementwiseParams<T>> {
}

ELEMENTWISE_FWD_DECL(FastGeLU, float);
+ELEMENTWISE_FWD_DECL(FastGeLU, double);
ELEMENTWISE_FWD_DECL(FastGeLU, half);
ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16);

ELEMENTWISE_FWD_DECL(GeLU, float);
+ELEMENTWISE_FWD_DECL(GeLU, double);
ELEMENTWISE_FWD_DECL(GeLU, half);
ELEMENTWISE_FWD_DECL(GeLU, BFloat16);

@@ -4,5 +4,6 @@
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"

ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
+ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
@@ -3,6 +3,7 @@

#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"

+ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -11,10 +11,13 @@ namespace contrib {
namespace rocm {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
@@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
@@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
@@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(0, "X", "input tensor", "T")
.Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
.Output(0, "Y", "output tensor", "T")
-.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
+.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// fastgelu(x) =
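For reference, the truncated function-body comment above (`// fastgelu(x) =`) refers to the standard tanh approximation of GELU that FastGelu computes (a well-known form, not copied from this diff):

$$\mathrm{fastgelu}(x) = 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)$$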
13 changes: 11 additions & 2 deletions onnxruntime/core/optimizer/bias_gelu_fusion.cc
@@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}

const Node& next_node = (*next_node_itr);
-if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
+
+bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain);
+if (!(is_onnx_gelu ||
+graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
@@ -72,14 +75,20 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

+bool is_approximate = is_fast_gelu;
+if (is_onnx_gelu) {
+const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate");
+is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh");
+}
+
if (graph.NodeProducesGraphOutput(node)) {
continue;
}

Node& add_node = node;
Node& gelu_node = const_cast<Node&>(next_node);
std::string op_type = "BiasGelu";
-if (is_fast_gelu) op_type = "FastGelu";
+if (is_approximate) op_type = "FastGelu";

Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
op_type,
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/gelu.cc
@@ -23,6 +23,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

template <typename T>
@@ -80,6 +81,7 @@ namespace contrib::cuda {

REGISTER_CONTRIB_KERNEL_TYPED(float)
REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
+REGISTER_CONTRIB_KERNEL_TYPED(BFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(double)

#undef REGISTER_CONTRIB_KERNEL_TYPED
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
@@ -40,6 +40,7 @@ Status LaunchGeluKernel(

SPECIALIZED_GELU_IMPL(float);
SPECIALIZED_GELU_IMPL(half);
+SPECIALIZED_GELU_IMPL(BFloat16);
SPECIALIZED_GELU_IMPL(double);

#undef SPECIALIZED_GELU_IMPL
40 changes: 39 additions & 1 deletion onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -389,7 +389,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
#ifdef USE_CUDA
-int min_cuda_architecture = 530;
+int min_cuda_architecture = 800;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
@@ -440,5 +440,43 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
}
#endif

+// CUDA and ROCm only for double type.
+#if defined(USE_CUDA) || defined(USE_ROCM)
+TEST(FastGeluTest, FastGeluWithBias_Double) {
+OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);
+
+int batch_size = 1;
+int sequence_length = 2;
+int hidden_size = 4;
+
+std::vector<double> X = {
+0.8, -0.5, 0.0, 1.0,
+0.5, 0.2, 0.3, -0.6};
+
+std::vector<double> B = {
+-0.5, 0.6, 1.2, 2.1};
+
+std::vector<double> Y = {
+0.185371, 0.053983, 1.061703, 3.097373,
+0.000000, 0.630432, 1.399572, 1.399572};
+
+std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+std::vector<int64_t> bias_dims = {hidden_size};
+std::vector<int64_t> output_dims = input_dims;
+
+tester.AddInput<double>("X", input_dims, X);
+tester.AddInput<double>("bias", bias_dims, B);
+tester.AddOutput<double>("Y", output_dims, Y);
+
+std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+execution_providers.push_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+execution_providers.push_back(DefaultRocmExecutionProvider());
+#endif
+tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+#endif
+
} // namespace test
} // namespace onnxruntime
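As a sanity check on the expected outputs in the new FastGeluWithBias_Double test, here is a minimal sketch (assuming Python with numpy, which are not part of this PR) that recomputes them from the tanh approximation above; it reproduces the Y values to the six decimals used in the test.

```python
# Recompute the expected outputs of FastGeluWithBias_Double using the
# standard tanh approximation of GELU (same constants as above).
import numpy as np


def fast_gelu(x: np.ndarray) -> np.ndarray:
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


x = np.array([[0.8, -0.5, 0.0, 1.0],
              [0.5, 0.2, 0.3, -0.6]])
bias = np.array([-0.5, 0.6, 1.2, 2.1])

# Expected: [[0.185371, 0.053983, 1.061703, 3.097373],
#            [0.000000, 0.630432, 1.399572, 1.399572]]
print(np.round(fast_gelu(x + bias), 6))
```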
40 changes: 40 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4781,6 +4781,46 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
}

+TEST_F(GraphTransformationTests, BiasOnnxGeluTest) {
+constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_gelu_fusion.onnx";
+std::shared_ptr<Model> p_model;
+ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+Graph& graph = p_model->MainGraph();
+
+onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
+ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+ASSERT_TRUE(op_to_count["Add"] == 0);
+ASSERT_TRUE(op_to_count["Gelu"] == 0);
+ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 0);
+ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
+}
+
+TEST_F(GraphTransformationTests, BiasOnnxFastGeluTest) {
+constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_fast_gelu_fusion.onnx";
+std::shared_ptr<Model> p_model;
+ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+Graph& graph = p_model->MainGraph();
+
+onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
+ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
+ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
+ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+
+std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+ASSERT_TRUE(op_to_count["Add"] == 0);
+ASSERT_TRUE(op_to_count["Gelu"] == 0);
+ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
+ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 0);
+}
+
// BiasGelu allows input switching based on input dimensions.
// This test validates the input edges are plugged correct in the optimized graph.
TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
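The two new tests load fusion/bias_onnx_gelu_fusion.onnx and fusion/bias_onnx_fast_gelu_fusion.onnx, which are not shown in this diff. Purely as an illustration of the Add → ONNX Gelu (opset 20) pattern that the updated BiasGeluFusion now recognizes, here is a hypothetical generator sketch using the onnx Python package; the graph, tensor names, shapes, and initializer values are assumptions and may not match the actual checked-in test data.

```python
# Hypothetical sketch of a test-model generator for the Add -> Gelu (opset 20)
# pattern. All names, shapes, and values here are illustrative assumptions.
import onnx
from onnx import TensorProto, helper


def make_bias_gelu_model(approximate: str) -> onnx.ModelProto:
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", 4])
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", 4])
    # 1-D bias broadcast over the last dimension, the shape BiasGeluFusion expects.
    bias = helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4])
    add = helper.make_node("Add", ["X", "bias"], ["add_out"], name="add")
    # approximate="none" should fuse to com.microsoft.BiasGelu,
    # approximate="tanh" to com.microsoft.FastGelu, per the updated fusion logic.
    gelu = helper.make_node("Gelu", ["add_out"], ["Y"], name="gelu", approximate=approximate)
    graph = helper.make_graph([add, gelu], "bias_onnx_gelu", [x], [y], initializer=[bias])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])


if __name__ == "__main__":
    onnx.save(make_bias_gelu_model("none"), "bias_onnx_gelu_fusion.onnx")
    onnx.save(make_bias_gelu_model("tanh"), "bias_onnx_fast_gelu_fusion.onnx")
```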
Binary file modified onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx