extend gradient to take an ADType argument
#2645
Conversation
Codecov Report
❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #2645      +/-   ##
==========================================
- Coverage   32.09%   31.91%   -0.18%
==========================================
  Files          32       32
  Lines        2038     2049      +11
==========================================
  Hits          654      654
- Misses       1384     1395      +11
```
gdalle left a comment:
This looks promising!
Note that without an amortized preparation step, Mooncake will be very slow. Hopefully that can be fixed by chalk-lab/Mooncake.jl#900
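For reference, a minimal sketch of what an amortized preparation step looks like through DI (`f`, `x`, and `backend` below are placeholders, not code from this PR):

```julia
using DifferentiationInterface
import Mooncake
using ADTypes: AutoMooncake

f(x) = sum(abs2, x)                 # stand-in for a model's loss
x = rand(Float32, 16)
backend = AutoMooncake(; config=nothing)

# Expensive: done once per function/argument type.
prep = prepare_gradient(f, backend, x)

# Cheap: reuses the cached preparation on every subsequent call.
grad = gradient(f, prep, backend, x)
```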
If you're thinking of extending the API to support more general ADTypes, I think it would likely be better to have the core function be something like `update_with_gradient!(model, optimizer, ADType)`, which fuses the update and the gradient. This will be required, for example, to successfully leverage Reactant (see Lux's TrainState for examples); a rough sketch follows below.
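A rough sketch of the kind of fused API being suggested; the name `update_with_gradient!` comes from the comment above, and everything else (the DI call, the Optimisers-based update) is an illustrative assumption rather than an existing Flux function:

```julia
import DifferentiationInterface as DI
import Optimisers

# Hypothetical fused step: compute loss and gradient with the requested AD
# backend and immediately apply the optimiser update, so that a compiler such
# as Reactant could trace the whole step as one unit.
function update_with_gradient!(loss, opt_state, model, adtype, data...)
    l, grads = DI.value_and_gradient(m -> loss(m, data...), adtype, model)
    opt_state, model = Optimisers.update!(opt_state, model, grads)
    return l, opt_state, model
end
```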
we can add a
wsmoses left a comment:
Can DI be an extension here instead of a proper dependency? I worry a bit that users of Flux and non-DI-based code can end up in version resolution hell, since DI needs to mark the version compat of all autodiff packages simultaneously. I recall this causing issues for Diffractor in the past.

You can still have the DI extension loaded as a requirement of any of the other extensions here.
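For reference, a sketch of how a weak-dependency layout could look, with the DI-based methods living in an extension module; the module name `FluxDIExt` and the method body are hypothetical, not taken from this PR:

```julia
# ext/FluxDIExt.jl -- loaded only when DifferentiationInterface is installed,
# via [weakdeps] and [extensions] entries in Flux's Project.toml.
module FluxDIExt

using Flux
import DifferentiationInterface as DI
using ADTypes: AbstractADType

# Generic fallback routing any ADTypes backend through DI; backend-specific
# extensions (Enzyme, Mooncake, ...) can still define tighter methods.
function Flux.gradient(f, adtype::AbstractADType, x...)
    return DI.gradient(xs -> f(xs...), adtype, x)
end

end # module
```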
Sounds good to me. Somewhat relatedly, for that (and perhaps for `gradient`), you might need an additional argument for either a Reactant-compiled thunk or a DI preparation.
Concurrently (though separately): for more general things than Zygote, the error I mentioned here (#2600 (comment)) back in May about the structural equivalence check is still blocking (and was indeed what was blocking me, at least, from getting Flux to work for the past six months).
@wsmoses do you know why the Enzyme tests are erroring on recurrent layers?
It appears you made a differentiated function type unstable.
Somewhat relatedly, since the changes to Enzyme here will equally change/break the use of Enzyme within Reactant, can we fix that before merging here? Otherwise Reactant CI (which does test Flux), as well as other users, may be broken by this PR.
I didn't change any function.
Well, CI does seem to disagree (it complains that the relevant function returns `Any` via inference).
src/gradient.jl (outdated)

```julia
    return _enzyme_withgradient(f, _make_duplicated(x); zero=true)
end
function withgradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}) where {F,N}
    return _enzyme_withgradient(f, map(_make_duplicated, x)...; zero=true)
```
@wsmoses the Enzyme tests work when using DI, but fail when introducing this specialized method.
That does seem to be an anonymous function, and it does seem like you changed both the functions being differentiated and the handlers in test_utils, in addition to the Enzyme support utils.
Recurrent layers were type unstable already prior to this PR, and the loss was already a closure passed as a default argument. Another weird thing is that DI is actually more robust than the direct call to Enzyme we are using: commenting out the Enzyme specializations

```julia
function withgradient(f::F, adtype::AutoEnzyme, x) where {F}
    return _enzyme_withgradient(f, _make_duplicated(x); zero=true)
end
function withgradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}) where {F,N}
    return _enzyme_withgradient(f, map(_make_duplicated, x)...; zero=true)
end
```

makes the error go away.
Yeah, so that specific error indicates that type inference inside the method deduced a concrete type (Float64) whereas type inference at the call site found an abstract type (Any). Likely what's happening from the DI wrapper is that it's deducing Any for both (and Enzyme works perfectly well on that). It's better to retain the type-stable version, though we ought to figure out where/why it's lost. In that particular section of the code (sorry, this is a bit internals-specific), at the moment we support the inner type being less known than the outer type (or equal, of course), but not the other way round. So something changed due to this PR via type inference, though the specifics are a bit unclear (I can try to take a look and debug this more deeply in a bit).
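Not from this PR, but a quick way to check this kind of inner vs call-site inference mismatch locally; `loss`, `m`, and `x` below are placeholders:

```julia
using Flux, Test

m = Dense(2 => 2)
x = randn(Float32, 2)
loss(model, x) = sum(model(x))

# What inference concludes for these concrete argument types
# (ideally a single concrete type such as Float32, not Any).
Base.return_types(loss, Tuple{typeof(m), typeof(x)})

# Throws if the call site cannot infer a concrete return type.
Test.@inferred loss(m, x)
```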
I'll stick with DI then. If you want to experiment with the direct Enzyme call, just uncomment the AutoEnzyme specializations.
Use of DI will break Reactant (specifically, DI's internal preparation has been shown to cause issues in this sort of setup, isn't considered supported from the Enzyme/Reactant side at the moment, and is likely to break without warning).
You're using Reactant on DI, which, as I mentioned in my last comment, isn't supported =/
Can you help then? I wouldn't know how to fix it.
If you want to move forward here without causing breakage before I can sort out the cause of the type-instability mismatch, I think your options are:
Yeah, I can take a look, but it probably won't be today.
See JuliaDiff/DifferentiationInterface.jl#918 if you want to keep track of the feature, which is reasonably high on my to-do list.
That PR is distinct: that PR has DI on the outside of a Reactant compile. What doesn't work here is a Reactant compile of DI.
Sorry for the misunderstanding. Can you provide an MWE of Reactant failing to compile something with DI? I think I might know where a possible issue comes from, and if so it would be easy to fix.
```julia
using Flux, Enzyme, Reactant, Statistics, LinearAlgebra, Random, MLDataDevices
import DifferentiationInterface as DI

dev = reactant_device()
m = LayerNorm(2)
x = randn(Float32, 2)
mr, xr = (m, x) |> dev
loss(m, x...) = mean(m(x...))

@jit loss(mr, xr)
DI.gradient(args -> loss(args...), AutoEnzyme(), (m, x))
@jit DI.gradient(args -> loss(args...), AutoEnzyme(), (mr, xr))
```
I'll take a look.
I tried to make it work with native Enzyme+Reactant but I didn't succeed. @wsmoses am I holding the thing wrong?

```julia
using Flux, Enzyme, Reactant, Statistics, LinearAlgebra, Random, MLDataDevices

dev = reactant_device()
m = LayerNorm(2)
x = randn(Float32, 2)
mr, xr = (m, x) |> dev
loss(m, x::X) where {X} = mean(m(x))

loss(m, x)                                                    # works
Enzyme.gradient(Enzyme.Reverse, splat(loss), (m, x))          # works
Enzyme.gradient(Enzyme.Reverse, loss, m, x)                   # works
@jit loss(mr, xr)                                             # works
@jit Enzyme.gradient(Enzyme.Reverse, splat(loss), (mr, xr))   # fails
@jit Enzyme.gradient(Enzyme.Reverse, loss, mr, xr)            # fails
```

Error message:

```
julia> @jit Enzyme.gradient(Enzyme.Reverse, splat(loss), (mr, xr))
ERROR: MethodError: no method matching act_from_type(::Type{MixedDuplicated{Tuple{LayerNorm{…}, Reactant.TracedRArray{…}}}}, ::Bool, ::Bool)
The function `act_from_type` exists, but no method is defined for this combination of argument types.
Closest candidates are:
act_from_type(::Type{<:Active}, ::Any, ::Any)
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:203
act_from_type(::Type{<:DuplicatedNoNeed}, ::Any, ::Any)
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:223
act_from_type(::Type{<:Duplicated}, ::Any, ::Any)
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:210
...
Stacktrace:
[1] act_from_type (repeats 2 times)
@ ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:200 [inlined]
[2] overload_autodiff(::ReverseMode{…}, f::Const{…}, ::Type{…}, args::MixedDuplicated{…})
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:335
[3] autodiff(rmode::ReverseMode{…}, f::Const{…}, rt::Type{…}, args::MixedDuplicated{…})
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Overlay.jl:36
[4] autodiff
@ ~/.julia/packages/Enzyme/QsBMf/src/Enzyme.jl:542 [inlined]
[5] macro expansion
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:327 [inlined]
[6] gradient
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
[7] (::Nothing)(none::typeof(Enzyme.gradient), none::ReverseMode{…}, none::Base.Splat{…}, none::Tuple{…}, none::Tuple{})
@ Reactant ./<missing>:0
[8] GenericMemory
@ ./boot.jl:516 [inlined]
[9] IdDict
@ ./iddict.jl:31 [inlined]
[10] IdDict
@ ./iddict.jl:49 [inlined]
[11] make_zero (repeats 2 times)
@ ~/.julia/packages/EnzymeCore/RpjpI/src/EnzymeCore.jl:587 [inlined]
[12] macro expansion
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:324 [inlined]
[13] gradient
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
[14] call_with_reactant(::typeof(Enzyme.gradient), ::ReverseMode{…}, ::Base.Splat{…}, ::Tuple{…})
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/utils.jl:0
[15] make_mlir_fn(f::typeof(Enzyme.gradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:348
[16] make_mlir_fn
@ ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:277 [inlined]
[17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(Enzyme.gradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1733
[18] compile_mlir!
@ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1695 [inlined]
[19] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3690
[20] compile_xla
@ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3662 [inlined]
[21] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3766
[22] top-level scope
@ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:2835
Some type information was truncated. Use `show(err)` to see complete types.
julia> @jit Enzyme.gradient(Enzyme.Reverse, loss, mr, xr)
ERROR: MethodError: no method matching act_from_type(::Type{MixedDuplicated{LayerNorm{typeof(identity), Flux.Scale{…}, Float32, 1}}}, ::Bool, ::Bool)
The function `act_from_type` exists, but no method is defined for this combination of argument types.
Closest candidates are:
act_from_type(::Type{<:Active}, ::Any, ::Any)
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:203
act_from_type(::Type{<:DuplicatedNoNeed}, ::Any, ::Any)
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:223
act_from_type(::Type{<:Duplicated}, ::Any, ::Any)
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:210
...
Stacktrace:
[1] act_from_type (repeats 2 times)
@ ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:200 [inlined]
[2] overload_autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::MixedDuplicated{…}, ::Duplicated{…})
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:335
[3] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::MixedDuplicated{…}, ::Duplicated{…})
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/Overlay.jl:36
[4] autodiff
@ ~/.julia/packages/Enzyme/QsBMf/src/Enzyme.jl:542 [inlined]
[5] macro expansion
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:287 [inlined]
[6] gradient
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
[7] (::Nothing)(none::typeof(Enzyme.gradient), none::ReverseMode{…}, none::typeof(loss), none::LayerNorm{…}, none::Tuple{…})
@ Reactant ./<missing>:0
[8] GenericMemory
@ ./boot.jl:516 [inlined]
[9] IdDict
@ ./iddict.jl:31 [inlined]
[10] IdDict
@ ./iddict.jl:49 [inlined]
[11] make_zero (repeats 2 times)
@ ~/.julia/packages/EnzymeCore/RpjpI/src/EnzymeCore.jl:587 [inlined]
[12] macro expansion
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:324 [inlined]
[13] gradient
@ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
[14] call_with_reactant(::typeof(Enzyme.gradient), ::ReverseMode{…}, ::typeof(loss), ::LayerNorm{…}, ::Reactant.TracedRArray{…})
@ Reactant ~/.julia/packages/Reactant/LcOwI/src/utils.jl:0
[15] make_mlir_fn(f::typeof(Enzyme.gradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:348
[16] make_mlir_fn
@ ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:277 [inlined]
[17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(Enzyme.gradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1733
[18] compile_mlir!
@ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1695 [inlined]
[19] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3690
[20] compile_xla
@ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3662 [inlined]
[21] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3766
[22] top-level scope
@ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:2835
Some type information was truncated. Use `show(err)` to see complete types.
```
Reactant doesn't support MixedDuplicated at the moment, which is why the custom Enzyme handling here (which will also be faster regardless) needs to be kept, instead of DI's or Enzyme's generic handler, which doesn't understand the calling convention of Flux.
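As a sketch of the point about activity annotations (not the PR's exact code, and in the spirit of the `_make_duplicated` helper shown earlier rather than its actual implementation): pre-wrapping the model in `Duplicated` with a zeroed shadow fixes its activity explicitly, instead of letting the `Enzyme.gradient` convenience wrapper fall back to `MixedDuplicated` for a struct holding both arrays and scalars.

```julia
using Flux, Enzyme

m = LayerNorm(2)
x = randn(Float32, 2)
loss(m, x) = sum(m(x))

# Duplicated(value, shadow): gradients w.r.t. the model accumulate into dm.dval.
dm = Duplicated(m, Enzyme.make_zero(m))
Enzyme.autodiff(Enzyme.Reverse, loss, Active, dm, Const(x))
grads = dm.dval
```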
Might want to rename the PR if DI is no longer used.
Out of curiosity, ignoring Enzyme, what's missing from DI for the purposes of Flux+Mooncake? Is the reasoning simply "there are only two backends you care about now so you might as well use them manually"? @CarloLucibello
Yes, that plus the fact that an extension was needed (for different reasons) for each backend in any case.
Related to #2640
cc @gdalle
In this PR we do the following:
- Add a dependency on ADTypes.
- Extend `Flux.gradient(f, adtype, x...)` and `Flux.withgradient(f, adtype, x...)`, where `adtype` should be one of `AutoZygote()`, `AutoMooncake()`, `AutoEnzyme()`, `AutoFiniteDifferences`.
- Rework the `test_gradients` util to be AD backend agnostic. Also address the comment in "Enable other reactant tests" #2600 (comment), making the test function more robust.

TODO:
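A usage sketch of the interface described above; the backend objects come from ADTypes, `model` and `loss` are placeholders, and the return convention is assumed to match the existing `Flux.withgradient`:

```julia
using Flux
using ADTypes: AutoZygote

model = Dense(2 => 1)
x = randn(Float32, 2, 8)
y = randn(Float32, 1, 8)
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# Same semantics as Flux.gradient(loss, model, x, y), with an explicit backend.
grads = Flux.gradient(loss, AutoZygote(), model, x, y)

# withgradient returns the loss value together with the gradients.
out = Flux.withgradient(loss, AutoZygote(), model, x, y)
out.val, out.grad
```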