Skip to content

Commit 61ba67f

Browse files
committed
fix: match gelu with paper implementation
1 parent e801f01 commit 61ba67f

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ for (jlop, hloop) in (
66
@eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
77
end
88

9+
# See https://github.com/EnzymeAD/Reactant.jl/issues/1420
10+
# Without this we will never fuse the gelu into gemm
11+
# Tanh-approximation GELU specialised for traced scalars:
#
#     gelu_tanh(x) = x * 0.5 * (1 + tanh(sqrt(2 / π) * (x + 0.044715 * x^3)))
#
# The expression is written out explicitly, in the same form as the paper, so
# that the emitted ops follow this exact pattern.
function NNlib.gelu_tanh(x::TracedRNumber)
    # `NNlib.oftf` presumably converts each constant to the floating-point
    # type of `x`, keeping the arithmetic in a single precision — confirm
    # against NNlib.
    α = NNlib.oftf(x, 0.044715)          # cubic-term coefficient
    half = NNlib.oftf(x, 0.5)
    λ = sqrt(NNlib.oftf(x, 2 / pi))      # sqrt(2 / π) scaling inside the tanh
    # `tanh` must be *called* on the scaled cubic polynomial; λ is the scale.
    return x * half * (1 + tanh(λ * (x + α * x^3)))
end
17+
918
function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
1019
x = T.(Reactant.materialize_traced_array(x))
1120
max_ = maximum(x; dims)

0 commit comments

Comments
 (0)