2525# Enzyme.jl overlays
2626const WITHIN_AUTODIFF = Ref (false )
2727
28- @reactant_overlay @noinline function Enzyme. within_autodiff ()
28+ @reactant_overlay function Enzyme. within_autodiff ()
2929 return WITHIN_AUTODIFF[]
3030end
3131
32- @reactant_overlay @noinline function Enzyme. autodiff_deferred (
32+ @reactant_overlay function Enzyme. autodiff_deferred (
3333 rmode:: Enzyme.Mode , f:: FA , args:: Vararg{Annotation,Nargs}
3434) where {FA<: Annotation ,Nargs}
3535 original_within_autodiff = WITHIN_AUTODIFF[]
4141 end
4242end
4343
44- @reactant_overlay @noinline function Enzyme. autodiff (
44+ @reactant_overlay function Enzyme. autodiff (
4545 rmode:: Enzyme.Mode , f:: FA , args:: Vararg{Annotation,Nargs}
4646) where {FA<: Annotation ,Nargs}
4747 original_within_autodiff = WITHIN_AUTODIFF[]
5353 end
5454end
5555
56- @reactant_overlay @noinline function Enzyme. autodiff_deferred (
56+ @reactant_overlay function Enzyme. autodiff_deferred (
5757 rmode:: Enzyme.Mode , f:: FA , rt:: Type{A} , args:: Vararg{Annotation,Nargs}
5858) where {FA<: Annotation ,A<: Annotation ,Nargs}
5959 original_within_autodiff = WITHIN_AUTODIFF[]
6565 end
6666end
6767
68- @reactant_overlay @noinline function Enzyme. autodiff (
68+ @reactant_overlay function Enzyme. autodiff (
6969 rmode:: Enzyme.Mode , f:: FA , rt:: Type{A} , args:: Vararg{Annotation,Nargs}
7070) where {FA<: Annotation ,A<: Annotation ,Nargs}
7171 original_within_autodiff = WITHIN_AUTODIFF[]
9292end
9393
9494# Random.jl overlays
95- @reactant_overlay @noinline function Random. default_rng ()
95+ @reactant_overlay function Random. default_rng ()
9696 return call_with_reactant (TracedRandom. default_rng)
9797end
9898
99- @reactant_overlay @noinline function TracedRandom. default_rng ()
99+ @reactant_overlay function TracedRandom. default_rng ()
100100 return ReactantRNG (
101101 promote_to (TracedRArray{UInt64,1 }, TracedRandom. make_seed ()), " DEFAULT"
102102 )
@@ -110,7 +110,7 @@ for randfun in (:rand, :randn, :randexp)
110110 overload_randfun! = Symbol (:overload_ , randfun!)
111111
112112 @eval begin
113- @reactant_overlay @noinline function Random. $ (randfun)(
113+ @reactant_overlay function Random. $ (randfun)(
114114 rng:: AbstractRNG , :: Type{T} , dims:: Dims
115115 ) where {T}
116116 if unwrapped_eltype (T) <: ReactantPrimitive
@@ -123,13 +123,13 @@ for randfun in (:rand, :randn, :randexp)
123123 return call_with_native (Random.$ (randfun), rng, T, dims)
124124 end
125125
126- @reactant_overlay @noinline function Random. $ (randfun)(
126+ @reactant_overlay function Random. $ (randfun)(
127127 rng:: AbstractRNG , dim1:: Integer , dims:: Integer...
128128 )
129129 return TracedRandom.$ (overload_randfun)(rng, dim1, dims... )
130130 end
131131
132- @reactant_overlay @noinline function Random. $ (randfun)(
132+ @reactant_overlay function Random. $ (randfun)(
133133 rng:: AbstractRNG , :: Type{T} , dim1:: Integer , dims:: Integer...
134134 ) where {T}
135135 if unwrapped_eltype (T) <: ReactantPrimitive
@@ -143,7 +143,7 @@ for randfun in (:rand, :randn, :randexp)
143143 end
144144
145145 # scalars
146- @reactant_overlay @noinline function Random. $ (randfun)(
146+ @reactant_overlay function Random. $ (randfun)(
147147 rng:: AbstractRNG , :: Type{T} = Float64
148148 ) where {T}
149149 if unwrapped_eltype (T) <: ReactantPrimitive
@@ -155,12 +155,10 @@ for randfun in (:rand, :randn, :randexp)
155155 end
156156
157157 # inplace
158- @reactant_overlay @noinline function Random. $ (randfun!)(
159- rng:: AbstractRNG , A:: AnyTracedRArray
160- )
158+ @reactant_overlay function Random. $ (randfun!)(rng:: AbstractRNG , A:: AnyTracedRArray )
161159 return call_with_native (TracedRandom.$ (overload_randfun!), rng, A)
162160 end
163- @reactant_overlay @noinline function Random. $ (randfun!)(A:: AnyTracedRArray )
161+ @reactant_overlay function Random. $ (randfun!)(A:: AnyTracedRArray )
164162 return TracedRandom.$ (overload_randfun!)(
165163 call_with_reactant (TracedRandom. default_rng), A
166164 )
@@ -176,7 +174,7 @@ for (cT, aT, bT) in (
176174 (:AbstractMatrix , :AbstractMatrix , :AbstractVecOrMat ),
177175)
178176 @eval begin
179- @reactant_overlay @noinline function LinearAlgebra. mul! (
177+ @reactant_overlay function LinearAlgebra. mul! (
180178 C:: CT , A:: AT , B:: BT , α:: Number , β:: Number
181179 ) where {CT<: $cT ,AT<: $aT ,BT<: $bT }
182180 A, B = aos_to_soa (A), aos_to_soa (B)
@@ -196,7 +194,7 @@ for (cT, aT, bT) in (
196194 end
197195
198196 # Needed mostly for 1.10 where 3-arg mul is often specialized
199- @reactant_overlay @noinline function LinearAlgebra. mul! (
197+ @reactant_overlay function LinearAlgebra. mul! (
200198 C:: CT , A:: AT , B:: BT
201199 ) where {CT<: $cT ,AT<: $aT ,BT<: $bT }
202200 call_with_reactant (LinearAlgebra. mul!, C, A, B, true , false )
@@ -206,7 +204,7 @@ for (cT, aT, bT) in (
206204end
207205
208206# Base overloads
209- @reactant_overlay @noinline function Base. _stack (dims:: Union{Integer,Colon} , iter)
207+ @reactant_overlay function Base. _stack (dims:: Union{Integer,Colon} , iter)
210208 if use_overlayed_version (iter)
211209 return call_with_native (TracedRArrayOverrides. overloaded_stack, dims, iter)
212210 else
@@ -223,15 +221,15 @@ end
223221end
224222
225223# # fixes #493
226- @reactant_overlay @noinline function Base. _unique_dims (A:: AbstractArray , dims:: Colon )
224+ @reactant_overlay function Base. _unique_dims (A:: AbstractArray , dims:: Colon )
227225 if use_overlayed_version (A)
228226 error (" Reactant doesn't have a `Base._unique_dims` with the current interpreter." )
229227 else
230228 call_with_native (Base. _unique_dims, A, dims)
231229 end
232230end
233231
234- @reactant_overlay @noinline function Base. mapreduce (
232+ @reactant_overlay function Base. mapreduce (
235233 f,
236234 op,
237235 A:: Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator} ;
248246 end
249247end
250248
251- @reactant_overlay @noinline function Base. map (f, x:: AbstractArray , ys:: AbstractArray... )
249+ @reactant_overlay function Base. map (f, x:: AbstractArray , ys:: AbstractArray... )
252250 if (
253251 use_overlayed_version (x) ||
254252 use_overlayed_version (f) ||
260258 end
261259end
262260
263- @reactant_overlay @noinline function Base. map! (
261+ @reactant_overlay function Base. map! (
264262 f, y:: AbstractArray , x:: AbstractArray , xs:: AbstractArray...
265263)
266264 if (
@@ -275,23 +273,23 @@ end
275273 end
276274end
277275
278- @reactant_overlay @noinline function Base. _all (f, x:: AbstractArray , dims)
276+ @reactant_overlay function Base. _all (f, x:: AbstractArray , dims)
279277 if use_overlayed_version (x) || use_overlayed_version (f)
280278 return call_with_native (TracedRArrayOverrides. overloaded_mapreduce, f, & , x; dims)
281279 else
282280 return call_with_native (Base. _all, CallWithReactant (f), x, dims)
283281 end
284282end
285283
286- @reactant_overlay @noinline function Base. _any (f, x:: AbstractArray , dims)
284+ @reactant_overlay function Base. _any (f, x:: AbstractArray , dims)
287285 if use_overlayed_version (x) || use_overlayed_version (f)
288286 return call_with_native (TracedRArrayOverrides. overloaded_mapreduce, f, | , x; dims)
289287 else
290288 return call_with_native (Base. _any, CallWithReactant (f), x, dims)
291289 end
292290end
293291
294- @reactant_overlay @noinline function Base. _getindex (
292+ @reactant_overlay function Base. _getindex (
295293 :: IndexLinear , x:: Array{T,N} , idxs:: Vararg{Any,N}
296294) where {T,N}
297295 if use_overlayed_version (idxs)
@@ -316,9 +314,7 @@ for (jlop, rop, default_pivot) in (
316314 (:cholesky! , :overloaded_cholesky , NoPivot),
317315)
318316 @eval begin
319- @reactant_overlay @noinline function LinearAlgebra. $ (jlop)(
320- x:: AbstractArray ; kwargs...
321- )
317+ @reactant_overlay function LinearAlgebra. $ (jlop)(x:: AbstractArray ; kwargs... )
322318 if use_overlayed_version (x)
323319 pivot = $ (default_pivot)()
324320 return call_with_native (
@@ -332,7 +328,7 @@ for (jlop, rop, default_pivot) in (
332328 end
333329 end
334330
335- @reactant_overlay @noinline function LinearAlgebra. $ (jlop)(
331+ @reactant_overlay function LinearAlgebra. $ (jlop)(
336332 x:: AbstractArray , pivot:: $ (default_pivot); kwargs...
337333 )
338334 if use_overlayed_version (x)
351347
352348for (jlop, rop) in ((:svd , :overloaded_svd ),)
353349 @eval begin
354- @reactant_overlay @noinline function LinearAlgebra. $ (jlop)(
355- x:: AbstractArray ; kwargs...
356- )
350+ @reactant_overlay function LinearAlgebra. $ (jlop)(x:: AbstractArray ; kwargs... )
357351 if use_overlayed_version (x)
358352 return call_with_native (
359353 TracedLinearAlgebra.$ (rop),
@@ -367,14 +361,14 @@ for (jlop, rop) in ((:svd, :overloaded_svd),)
367361 end
368362end
369363
370- @reactant_overlay @noinline function LinearAlgebra. dot (x:: AbstractArray , y:: AbstractArray )
364+ @reactant_overlay function LinearAlgebra. dot (x:: AbstractArray , y:: AbstractArray )
371365 if use_overlayed_version (x) || use_overlayed_version (y)
372366 return call_with_native (TracedLinearAlgebra. overloaded_dot, x, y)
373367 else
374368 return call_with_native (LinearAlgebra. dot, x, y)
375369 end
376370end
377- @reactant_overlay @noinline function LinearAlgebra. dot (
371+ @reactant_overlay function LinearAlgebra. dot (
378372 x:: AbstractVector , A:: AbstractMatrix , y:: AbstractVector
379373)
380374 if use_overlayed_version (x) || use_overlayed_version (A) || use_overlayed_version (y)
386380
387381# 3 arg multiplication is specialized in Base, but we can reorder the computation
388382# as an MLIR optimization
389- @reactant_overlay @noinline function Base.:(* )(
390- a:: AbstractArray , b:: AbstractArray , c:: AbstractArray
391- )
383+ @reactant_overlay function Base.:(* )(a:: AbstractArray , b:: AbstractArray , c:: AbstractArray )
392384 if use_overlayed_version ((a, b, c))
393385 ab = call_with_native (TracedLinearAlgebra. overloaded_mul, a, b)
394386 return call_with_native (TracedLinearAlgebra. overloaded_mul, ab, c)
397389 end
398390end
399391
400- @reactant_overlay @noinline function Base.:(* )(
392+ @reactant_overlay function Base.:(* )(
401393 a:: AbstractArray , b:: AbstractArray , c:: AbstractArray , d:: AbstractArray
402394)
403395 if use_overlayed_version ((a, b, c, d))
0 commit comments