Skip to content

Commit 10dcd88

Browse files
authored
Merge branch 'main' into fixupl
2 parents 194b3ce + ed4ece3 commit 10dcd88

7 files changed

Lines changed: 114 additions & 54 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
3-
version = "0.2.261"
3+
version = "0.2.262"
44
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
55

66
[workspace]
@@ -149,7 +149,7 @@ PythonCall = "0.9.25"
149149
Random = "1.10"
150150
Random123 = "1.7"
151151
ReactantCore = "0.1.18"
152-
Reactant_jll = "0.0.379"
152+
Reactant_jll = "0.0.381"
153153
ScopedValues = "1.3.0"
154154
Scratch = "1.3"
155155
Serialization = "1.10"

ext/ReactantCUDAExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ end
575575
function recudaconvert(arg)
576576
return adapt(ReactantKernelAdaptor(), arg)
577577
end
578-
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
578+
Reactant.@reactant_overlay function CUDA.cudaconvert(arg)
579579
return recudaconvert(arg)
580580
end
581581

@@ -1125,7 +1125,7 @@ function mlir_extract_roots_from_value!(
11251125
end
11261126
end
11271127

1128-
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
1128+
Reactant.@reactant_overlay function (func::LLVMFunc{F,tt})(
11291129
args...;
11301130
convert=Val(true),
11311131
blocks::CuDim=1,
@@ -1600,7 +1600,7 @@ function _convert_bf16_value(
16001600
return src_val
16011601
end
16021602

1603-
Reactant.@reactant_overlay @noinline function CUDA.cufunction(
1603+
Reactant.@reactant_overlay function CUDA.cufunction(
16041604
f::F, tt::TT=Tuple{}; kwargs...
16051605
) where {F,TT}
16061606
res = Base.@lock CUDACore.cufunction_lock begin

ext/ReactantMPIExt/Overrides.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
using Reactant: @reactant_overlay, TracedRArray, TracedRNumber
22

3-
# @reactant_overlay @noinline function MPI.Init(; kwargs...)
3+
# @reactant_overlay function MPI.Init(; kwargs...)
44
# if !isempty(kwargs)
55
# @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs...
66
# end
77
# return Ops.init()
88
# end
99

10-
# @reactant_overlay @noinline function MPI.Finalize(; kwargs...)
10+
# @reactant_overlay function MPI.Finalize(; kwargs...)
1111
# return Ops.finalize()
1212
# end
1313

14-
@reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm)
14+
@reactant_overlay function MPI.Comm_rank(comm::MPI.Comm)
1515
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
1616
return Ops.comm_rank()
1717
end
1818

19-
@reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm)
19+
@reactant_overlay function MPI.Comm_size(comm::MPI.Comm)
2020
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
2121
return Ops.comm_size()
2222
end
2323

24-
@reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm)
24+
@reactant_overlay function MPI.Barrier(comm::MPI.Comm)
2525
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
2626
return Ops.barrier()
2727
end

ext/ReactantNNlibExt/Overlay.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
@reactant_overlay @noinline function NNlib.conv!(y, x, w, cdims::DenseConvDims; kwargs...)
1+
@reactant_overlay function NNlib.conv!(y, x, w, cdims::DenseConvDims; kwargs...)
22
if any(Reactant.use_overlayed_version, (y, x, w))
33
overloaded_conv!(y, x, w, cdims; kwargs...)
44
else
55
Reactant.call_with_native(NNlib.conv!, y, x, w, cdims; kwargs...)
66
end
77
end
88

9-
@reactant_overlay @noinline function NNlib.maxpool!(y, x, pdims::NNlib.PoolDims; kwargs...)
9+
@reactant_overlay function NNlib.maxpool!(y, x, pdims::NNlib.PoolDims; kwargs...)
1010
if any(Reactant.use_overlayed_version, (y, x))
1111
overloaded_maxpool!(y, x, pdims; kwargs...)
1212
else
1313
Reactant.call_with_native(NNlib.maxpool!, y, x, pdims; kwargs...)
1414
end
1515
end
1616

17-
@reactant_overlay @noinline function NNlib.meanpool!(y, x, pdims::NNlib.PoolDims; kwargs...)
17+
@reactant_overlay function NNlib.meanpool!(y, x, pdims::NNlib.PoolDims; kwargs...)
1818
if any(Reactant.use_overlayed_version, (y, x))
1919
overloaded_meanpool!(y, x, pdims; kwargs...)
2020
else
2121
Reactant.call_with_native(NNlib.meanpool!, y, x, pdims; kwargs...)
2222
end
2323
end
2424

25-
@reactant_overlay @noinline function NNlib.∇conv_filter!(
25+
@reactant_overlay function NNlib.∇conv_filter!(
2626
dw, x, dy, cdims::NNlib.DenseConvDims; kwargs...
2727
)
2828
if any(Reactant.use_overlayed_version, (dw, x, dy))
@@ -32,7 +32,7 @@ end
3232
end
3333
end
3434

35-
@reactant_overlay @noinline function NNlib.∇conv_data!(
35+
@reactant_overlay function NNlib.∇conv_data!(
3636
dx, dy, w, cdims::NNlib.DenseConvDims; kwargs...
3737
)
3838
if any(Reactant.use_overlayed_version, (dx, dy, w))

src/Interpreter.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ import Core.Compiler:
1818

1919
Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE)
2020

21-
function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def)
21+
macro reactant_overlay(def)
22+
def = Expr(:macrocall, Symbol("@noinline"), __source__, def)
2223
return Base.Experimental.var"@overlay"(
2324
__source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def
2425
)

src/Overlay.jl

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ end
2525
# Enzyme.jl overlays
2626
const WITHIN_AUTODIFF = Ref(false)
2727

28-
@reactant_overlay @noinline function Enzyme.within_autodiff()
28+
@reactant_overlay function Enzyme.within_autodiff()
2929
return WITHIN_AUTODIFF[]
3030
end
3131

32-
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
32+
@reactant_overlay function Enzyme.autodiff_deferred(
3333
rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs}
3434
) where {FA<:Annotation,Nargs}
3535
original_within_autodiff = WITHIN_AUTODIFF[]
@@ -41,7 +41,7 @@ end
4141
end
4242
end
4343

44-
@reactant_overlay @noinline function Enzyme.autodiff(
44+
@reactant_overlay function Enzyme.autodiff(
4545
rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs}
4646
) where {FA<:Annotation,Nargs}
4747
original_within_autodiff = WITHIN_AUTODIFF[]
@@ -53,7 +53,7 @@ end
5353
end
5454
end
5555

56-
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
56+
@reactant_overlay function Enzyme.autodiff_deferred(
5757
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
5858
) where {FA<:Annotation,A<:Annotation,Nargs}
5959
original_within_autodiff = WITHIN_AUTODIFF[]
@@ -65,7 +65,7 @@ end
6565
end
6666
end
6767

68-
@reactant_overlay @noinline function Enzyme.autodiff(
68+
@reactant_overlay function Enzyme.autodiff(
6969
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
7070
) where {FA<:Annotation,A<:Annotation,Nargs}
7171
original_within_autodiff = WITHIN_AUTODIFF[]
@@ -92,11 +92,11 @@ end
9292
end
9393

9494
# Random.jl overlays
95-
@reactant_overlay @noinline function Random.default_rng()
95+
@reactant_overlay function Random.default_rng()
9696
return call_with_reactant(TracedRandom.default_rng)
9797
end
9898

99-
@reactant_overlay @noinline function TracedRandom.default_rng()
99+
@reactant_overlay function TracedRandom.default_rng()
100100
return ReactantRNG(
101101
promote_to(TracedRArray{UInt64,1}, TracedRandom.make_seed()), "DEFAULT"
102102
)
@@ -110,7 +110,7 @@ for randfun in (:rand, :randn, :randexp)
110110
overload_randfun! = Symbol(:overload_, randfun!)
111111

112112
@eval begin
113-
@reactant_overlay @noinline function Random.$(randfun)(
113+
@reactant_overlay function Random.$(randfun)(
114114
rng::AbstractRNG, ::Type{T}, dims::Dims
115115
) where {T}
116116
if unwrapped_eltype(T) <: ReactantPrimitive
@@ -123,13 +123,13 @@ for randfun in (:rand, :randn, :randexp)
123123
return call_with_native(Random.$(randfun), rng, T, dims)
124124
end
125125

126-
@reactant_overlay @noinline function Random.$(randfun)(
126+
@reactant_overlay function Random.$(randfun)(
127127
rng::AbstractRNG, dim1::Integer, dims::Integer...
128128
)
129129
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
130130
end
131131

132-
@reactant_overlay @noinline function Random.$(randfun)(
132+
@reactant_overlay function Random.$(randfun)(
133133
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
134134
) where {T}
135135
if unwrapped_eltype(T) <: ReactantPrimitive
@@ -143,7 +143,7 @@ for randfun in (:rand, :randn, :randexp)
143143
end
144144

145145
# scalars
146-
@reactant_overlay @noinline function Random.$(randfun)(
146+
@reactant_overlay function Random.$(randfun)(
147147
rng::AbstractRNG, ::Type{T}=Float64
148148
) where {T}
149149
if unwrapped_eltype(T) <: ReactantPrimitive
@@ -155,12 +155,10 @@ for randfun in (:rand, :randn, :randexp)
155155
end
156156

157157
# inplace
158-
@reactant_overlay @noinline function Random.$(randfun!)(
159-
rng::AbstractRNG, A::AnyTracedRArray
160-
)
158+
@reactant_overlay function Random.$(randfun!)(rng::AbstractRNG, A::AnyTracedRArray)
161159
return call_with_native(TracedRandom.$(overload_randfun!), rng, A)
162160
end
163-
@reactant_overlay @noinline function Random.$(randfun!)(A::AnyTracedRArray)
161+
@reactant_overlay function Random.$(randfun!)(A::AnyTracedRArray)
164162
return TracedRandom.$(overload_randfun!)(
165163
call_with_reactant(TracedRandom.default_rng), A
166164
)
@@ -176,7 +174,7 @@ for (cT, aT, bT) in (
176174
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
177175
)
178176
@eval begin
179-
@reactant_overlay @noinline function LinearAlgebra.mul!(
177+
@reactant_overlay function LinearAlgebra.mul!(
180178
C::CT, A::AT, B::BT, α::Number, β::Number
181179
) where {CT<:$cT,AT<:$aT,BT<:$bT}
182180
A, B = aos_to_soa(A), aos_to_soa(B)
@@ -196,7 +194,7 @@ for (cT, aT, bT) in (
196194
end
197195

198196
# Needed mostly for 1.10 where 3-arg mul is often specialized
199-
@reactant_overlay @noinline function LinearAlgebra.mul!(
197+
@reactant_overlay function LinearAlgebra.mul!(
200198
C::CT, A::AT, B::BT
201199
) where {CT<:$cT,AT<:$aT,BT<:$bT}
202200
call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
@@ -206,7 +204,7 @@ for (cT, aT, bT) in (
206204
end
207205

208206
# Base overloads
209-
@reactant_overlay @noinline function Base._stack(dims::Union{Integer,Colon}, iter)
207+
@reactant_overlay function Base._stack(dims::Union{Integer,Colon}, iter)
210208
if use_overlayed_version(iter)
211209
return call_with_native(TracedRArrayOverrides.overloaded_stack, dims, iter)
212210
else
@@ -223,15 +221,15 @@ end
223221
end
224222

225223
## fixes #493
226-
@reactant_overlay @noinline function Base._unique_dims(A::AbstractArray, dims::Colon)
224+
@reactant_overlay function Base._unique_dims(A::AbstractArray, dims::Colon)
227225
if use_overlayed_version(A)
228226
error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.")
229227
else
230228
call_with_native(Base._unique_dims, A, dims)
231229
end
232230
end
233231

234-
@reactant_overlay @noinline function Base.mapreduce(
232+
@reactant_overlay function Base.mapreduce(
235233
f,
236234
op,
237235
A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator};
@@ -248,7 +246,7 @@ end
248246
end
249247
end
250248

251-
@reactant_overlay @noinline function Base.map(f, x::AbstractArray, ys::AbstractArray...)
249+
@reactant_overlay function Base.map(f, x::AbstractArray, ys::AbstractArray...)
252250
if (
253251
use_overlayed_version(x) ||
254252
use_overlayed_version(f) ||
@@ -260,7 +258,7 @@ end
260258
end
261259
end
262260

263-
@reactant_overlay @noinline function Base.map!(
261+
@reactant_overlay function Base.map!(
264262
f, y::AbstractArray, x::AbstractArray, xs::AbstractArray...
265263
)
266264
if (
@@ -275,23 +273,23 @@ end
275273
end
276274
end
277275

278-
@reactant_overlay @noinline function Base._all(f, x::AbstractArray, dims)
276+
@reactant_overlay function Base._all(f, x::AbstractArray, dims)
279277
if use_overlayed_version(x) || use_overlayed_version(f)
280278
return call_with_native(TracedRArrayOverrides.overloaded_mapreduce, f, &, x; dims)
281279
else
282280
return call_with_native(Base._all, CallWithReactant(f), x, dims)
283281
end
284282
end
285283

286-
@reactant_overlay @noinline function Base._any(f, x::AbstractArray, dims)
284+
@reactant_overlay function Base._any(f, x::AbstractArray, dims)
287285
if use_overlayed_version(x) || use_overlayed_version(f)
288286
return call_with_native(TracedRArrayOverrides.overloaded_mapreduce, f, |, x; dims)
289287
else
290288
return call_with_native(Base._any, CallWithReactant(f), x, dims)
291289
end
292290
end
293291

294-
@reactant_overlay @noinline function Base._getindex(
292+
@reactant_overlay function Base._getindex(
295293
::IndexLinear, x::Array{T,N}, idxs::Vararg{Any,N}
296294
) where {T,N}
297295
if use_overlayed_version(idxs)
@@ -316,9 +314,7 @@ for (jlop, rop, default_pivot) in (
316314
(:cholesky!, :overloaded_cholesky, NoPivot),
317315
)
318316
@eval begin
319-
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
320-
x::AbstractArray; kwargs...
321-
)
317+
@reactant_overlay function LinearAlgebra.$(jlop)(x::AbstractArray; kwargs...)
322318
if use_overlayed_version(x)
323319
pivot = $(default_pivot)()
324320
return call_with_native(
@@ -332,7 +328,7 @@ for (jlop, rop, default_pivot) in (
332328
end
333329
end
334330

335-
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
331+
@reactant_overlay function LinearAlgebra.$(jlop)(
336332
x::AbstractArray, pivot::$(default_pivot); kwargs...
337333
)
338334
if use_overlayed_version(x)
@@ -351,9 +347,7 @@ end
351347

352348
for (jlop, rop) in ((:svd, :overloaded_svd),)
353349
@eval begin
354-
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
355-
x::AbstractArray; kwargs...
356-
)
350+
@reactant_overlay function LinearAlgebra.$(jlop)(x::AbstractArray; kwargs...)
357351
if use_overlayed_version(x)
358352
return call_with_native(
359353
TracedLinearAlgebra.$(rop),
@@ -367,14 +361,14 @@ for (jlop, rop) in ((:svd, :overloaded_svd),)
367361
end
368362
end
369363

370-
@reactant_overlay @noinline function LinearAlgebra.dot(x::AbstractArray, y::AbstractArray)
364+
@reactant_overlay function LinearAlgebra.dot(x::AbstractArray, y::AbstractArray)
371365
if use_overlayed_version(x) || use_overlayed_version(y)
372366
return call_with_native(TracedLinearAlgebra.overloaded_dot, x, y)
373367
else
374368
return call_with_native(LinearAlgebra.dot, x, y)
375369
end
376370
end
377-
@reactant_overlay @noinline function LinearAlgebra.dot(
371+
@reactant_overlay function LinearAlgebra.dot(
378372
x::AbstractVector, A::AbstractMatrix, y::AbstractVector
379373
)
380374
if use_overlayed_version(x) || use_overlayed_version(A) || use_overlayed_version(y)
@@ -386,9 +380,7 @@ end
386380

387381
# 3 arg multiplication is specialized in Base, but we can reorder the computation
388382
# as an MLIR optimization
389-
@reactant_overlay @noinline function Base.:(*)(
390-
a::AbstractArray, b::AbstractArray, c::AbstractArray
391-
)
383+
@reactant_overlay function Base.:(*)(a::AbstractArray, b::AbstractArray, c::AbstractArray)
392384
if use_overlayed_version((a, b, c))
393385
ab = call_with_native(TracedLinearAlgebra.overloaded_mul, a, b)
394386
return call_with_native(TracedLinearAlgebra.overloaded_mul, ab, c)
@@ -397,7 +389,7 @@ end
397389
end
398390
end
399391

400-
@reactant_overlay @noinline function Base.:(*)(
392+
@reactant_overlay function Base.:(*)(
401393
a::AbstractArray, b::AbstractArray, c::AbstractArray, d::AbstractArray
402394
)
403395
if use_overlayed_version((a, b, c, d))

0 commit comments

Comments
 (0)