fix: match gelu with paper implementation #1421

Draft · wants to merge 4 commits into main

ext/ReactantNNlibExt/Implementations.jl (15 additions, 0 deletions)
@@ -6,6 +6,21 @@ for (jlop, hloop) in (
@eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
end

# See https://github.com/EnzymeAD/Reactant.jl/issues/1420
# Without this, the gelu will never be fused into the gemm
if isdefined(NNlib, :gelu_tanh)
    function NNlib.gelu_tanh(x::TracedRNumber)
        return Reactant.Ops.gelu(x, Reactant.NNLIB_GELU_APPROXIMATION[])
    end

    NNlib.gelu_erf(x::TracedRNumber) = Reactant.Ops.gelu(x, "NONE")
else
    # Older versions of NNlib do not have gelu_tanh (there, gelu is the tanh version)
    function NNlib.gelu(x::TracedRNumber)
        return Reactant.Ops.gelu(x, Reactant.NNLIB_GELU_APPROXIMATION[])
    end
end

function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
x = T.(Reactant.materialize_traced_array(x))
max_ = maximum(x; dims)
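
For reference, a minimal sketch of how this dispatch is exercised (the shapes and the layer function below are illustrative, not part of this diff). Broadcasting NNlib.gelu_tanh over a traced array hits the TracedRNumber method above, which emits a single enzymexla.ml_gelu via Ops.gelu so the pipeline can fuse it into the preceding gemm; gelu_erf maps to the exact ("NONE") variant.

using Reactant, NNlib

# Hypothetical example: a dense layer followed by gelu. When traced,
# the broadcasted gelu_tanh dispatches to the TracedRNumber method
# above and lowers to enzymexla.ml_gelu instead of a chain of
# elementwise stablehlo ops, which is what enables gemm fusion.
W = Reactant.to_rarray(randn(Float32, 128, 64))
x = Reactant.to_rarray(randn(Float32, 64, 32))

dense_gelu(W, x) = NNlib.gelu_tanh.(W * x)

compiled = @compile dense_gelu(W, x)
compiled(W, x)
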
src/Compiler.jl (7 additions, 0 deletions)
@@ -1623,6 +1623,7 @@ function compile_mlir!(
blas_int_width = sizeof(BLAS.BlasInt) * 8
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
blas_int_width=$blas_int_width}"
lower_enzymexla_ml_pass = "lower-enzymexla-ml"

if compile_options.optimization_passes === :all
run_pass_pipeline!(
@@ -1650,6 +1651,7 @@
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
lower_enzymexla_ml_pass,
jit,
]
else
@@ -1674,6 +1676,7 @@
kern,
raise_passes,
lower_enzymexla_linalg_pass,
lower_enzymexla_ml_pass,
jit,
]
end,
@@ -1863,6 +1866,7 @@
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
lower_enzymexla_ml_pass,
jit,
]
else
@@ -1884,6 +1888,7 @@
kern,
raise_passes,
lower_enzymexla_linalg_pass,
lower_enzymexla_ml_pass,
jit,
]
end,
@@ -1906,6 +1911,7 @@
enzyme_pass,
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
lower_enzymexla_linalg_pass,
lower_enzymexla_ml_pass,
jit,
]
else
@@ -1919,6 +1925,7 @@
kern,
raise_passes,
lower_enzymexla_linalg_pass,
lower_enzymexla_ml_pass,
jit,
]
end,
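
All four pipeline variants get the same addition: lower-enzymexla-ml is appended next to lower-enzymexla-linalg, always ahead of jit. A rough sketch of the recurring pattern (the backend and width values below are placeholders, not taken from this diff):

# Sketch: the pipelines built in compile_mlir! are comma-joined MLIR
# pass strings. The new pass must run after the optimization passes,
# so fusion still sees the high-level enzymexla.ml_gelu op, and
# before jit, so nothing unlowered reaches codegen.
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=cpu blas_int_width=64}"
lower_enzymexla_ml_pass = "lower-enzymexla-ml"
pipeline = join(["canonicalize", lower_enzymexla_linalg_pass, lower_enzymexla_ml_pass], ",")
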
src/Configuration.jl (8 additions, 0 deletions)
@@ -20,6 +20,8 @@ scope will use the provided values.
`ApproxTopK` for TPUs unless `fallback_approx_top_k_lowering` is set to `true`.
- `fallback_approx_top_k_lowering`: Whether to lower `Ops.approx_top_k` to
`stablehlo.top_k` if the XLA backend doesn't support `ApproxTopK`. Defaults to `true`.
- `nnlib_gelu_approximation`: Controls the approximation used for `NNlib.gelu_tanh`. Can
be `"TANH"` or `"SIGMOID"`. Defaults to `"SIGMOID"`.

### DotGeneral

@@ -38,6 +40,7 @@ function with_config(
convolution_precision=missing,
lower_partialsort_to_approx_top_k=missing,
fallback_approx_top_k_lowering=missing,
nnlib_gelu_approximation=missing,
)
config_vars = ()
dot_general_algorithm !== missing &&
@@ -58,13 +61,18 @@
FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering,
)
)
if nnlib_gelu_approximation !== missing
    @assert nnlib_gelu_approximation in ("TANH", "SIGMOID") "Invalid nnlib_gelu_approximation: $nnlib_gelu_approximation. Expected \"TANH\" or \"SIGMOID\"."
    config_vars = (config_vars..., NNLIB_GELU_APPROXIMATION => nnlib_gelu_approximation)
end

return ScopedValues.with(f, config_vars...)
end

# Lower to ApproxTopK
const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedValue(false)
const FALLBACK_APPROX_TOP_K_LOWERING = ScopedValue(true)
const NNLIB_GELU_APPROXIMATION = ScopedValue("SIGMOID")
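
A usage sketch for the new knob (the function and input below are illustrative): the scoped value is read at trace time, so wrapping compilation in with_config is enough to switch the lowering.

using Reactant, NNlib

gelu_layer(x) = NNlib.gelu_tanh.(x)
x = Reactant.to_rarray(randn(Float32, 16))

# Override the default "SIGMOID" lowering with the exact tanh
# approximation; the override only applies inside this block.
compiled_tanh = Reactant.with_config(; nnlib_gelu_approximation="TANH") do
    @compile gelu_layer(x)
end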

# DotGeneral Attributes Configuration
"""
src/Ops.jl (20 additions, 2 deletions)
@@ -3,7 +3,7 @@
# Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
module Ops
using ..MLIR: MLIR
-using ..MLIR.Dialects: stablehlo, chlo, enzyme
+using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
using ..Reactant:
Reactant,
TracedRArray,
@@ -3003,7 +3003,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
info_shape = batch_shape

-op = MLIR.Dialects.enzymexla.linalg_lu(
+op = enzymexla.linalg_lu(
x.mlir_data;
output=MLIR.IR.TensorType(output_shape, MLIR.IR.Type(unwrapped_eltype(T))),
pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(pT)),
@@ -3210,4 +3210,22 @@ end
end
end

@noinline function gelu(
    x::Union{TracedRArray{T,N},TracedRNumber{T}},
    approximation::String;
    location=mlir_stacktrace("gelu", @__FILE__, @__LINE__),
) where {T,N}
    @assert approximation in ("NONE", "TANH", "SIGMOID")

    res = MLIR.IR.result(
        enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approximation, location), 1
    )

    if x isa TracedRArray
        return TracedRArray{T,N}((), res, size(x))
    else
        return TracedRNumber{T}((), res)
    end
end

end # module Ops
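
A direct-call sketch for the new op (shapes illustrative): Ops.gelu validates the approximation string, emits enzymexla.ml_gelu, and rewraps the result to match the traced input type.

using Reactant

x = Reactant.to_rarray(randn(Float32, 8, 8))

# "NONE" requests the exact erf-based gelu; "TANH" and "SIGMOID"
# select the corresponding approximations.
exact_gelu(x) = Reactant.Ops.gelu(x, "NONE")

compiled = @compile exact_gelu(x)
compiled(x)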