
Conversation

@CarloLucibello
Member

@CarloLucibello CarloLucibello commented Dec 26, 2025

Related to #2640
cc @gdalle

In this PR we do the following

  • Add the new dependency ADTypes.
  • Introduce the methods Flux.gradient(f, adtype, x...) and Flux.withgradient(f, adtype, x...), where adtype should be one of AutoZygote(), AutoMooncake(), AutoEnzyme(), or AutoFiniteDifferences() (see the usage sketch below this list).
  • Revamp the test_gradients util to be AD-backend agnostic. Also address this comment in #2600 (Enable other reactant tests), making the test function more robust.
  • Add a Mooncake extension and corresponding tests
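
For illustration, here is a minimal usage sketch of the new API. The backend is selected by passing an ADTypes struct; the exact return format of withgradient below is an assumption (mirroring Zygote's convention), not something this list specifies.

using Flux, ADTypes, Statistics

model = Dense(3 => 2)
x = randn(Float32, 3, 16)
loss(m, x) = mean(abs2, m(x))

# Select the AD backend through an ADTypes struct instead of a package-specific call.
grads = Flux.gradient(loss, AutoZygote(), model, x)

# withgradient additionally returns the value of the loss.
val_and_grads = Flux.withgradient(loss, AutoZygote(), model, x)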

TODO:

  • news
  • docs
  • more tests
  • reactant

@CarloLucibello CarloLucibello marked this pull request as draft December 26, 2025 11:29
@codecov

codecov bot commented Dec 26, 2025

Codecov Report

❌ Patch coverage is 0% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 31.91%. Comparing base (2a91836) to head (faaedc8).

Files with missing lines    Patch %    Lines
src/gradient.jl             0.00%      11 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2645      +/-   ##
==========================================
- Coverage   32.09%   31.91%   -0.18%     
==========================================
  Files          32       32              
  Lines        2038     2049      +11     
==========================================
  Hits          654      654              
- Misses       1384     1395      +11     


@CarloLucibello CarloLucibello marked this pull request as ready for review December 28, 2025 06:24
@gdalle gdalle left a comment

This looks promising!
Note that without an amortized preparation step, Mooncake will be very slow. Hopefully that can be fixed by chalk-lab/Mooncake.jl#900
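
For reference, DI's amortized preparation pattern looks like this (a standalone sketch, not Flux code; gradient with respect to a plain vector for simplicity):

import DifferentiationInterface as DI
using ADTypes, Mooncake

f(x) = sum(abs2, x)
x = randn(Float32, 10)
backend = AutoMooncake(; config=nothing)

# Pay the rule-derivation cost once...
prep = DI.prepare_gradient(f, backend, x)

# ...then reuse it across calls with same-typed inputs.
g = DI.gradient(f, prep, backend, x)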

@wsmoses
Contributor

wsmoses commented Dec 29, 2025

If you're thinking of extending the API to support more general ADTypes, I think it would likely be better to have the core function be something like update_with_gradient!(model, optimizer, ADType), which fuses the update and the gradient.

This will be required, for example, to successfully leverage Reactant (see Lux's train state for examples).

@CarloLucibello
Member Author

If you're thinking of extending the API to support more general ADTypes, I think it would likely be better to have the core function be something like update_with_gradient!(model, optimizer, ADType), which fuses the update and the gradient.

We can add a training_step!(loss, adtype, model, data, opt_state).
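
For concreteness, a rough sketch of what such a step could look like, built on the withgradient API from this PR; the name training_step!, the argument order, and the assumed return format of withgradient are placeholders, not an actual Flux API:

using Flux, Optimisers, ADTypes

function training_step!(loss, adtype, model, data, opt_state)
    # Compute loss and gradient with the selected backend
    # (assumed return: value plus a tuple of per-argument gradients).
    val, grads = Flux.withgradient(m -> loss(m, data...), adtype, model)
    # Apply the optimiser update in place.
    opt_state, model = Optimisers.update!(opt_state, model, grads[1])
    return val, model, opt_state
end

model = Dense(2 => 1)
opt_state = Optimisers.setup(Optimisers.Adam(), model)
data = (randn(Float32, 2, 8), randn(Float32, 1, 8))
mseloss(m, x, y) = Flux.mse(m(x), y)

val, model, opt_state = training_step!(mseloss, AutoZygote(), model, data, opt_state)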

Contributor

@wsmoses wsmoses left a comment

Can DI be an extension here instead of a proper dependency? I worry a bit that users of Flux and non-DI-based code can end up in version-resolution hell, since DI needs to mark the version compat of all autodiff packages simultaneously. I recall this causing issues for Diffractor in the past.

You can still have the DI extension loaded as a requirement of any of the other extensions here.
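
For illustration, a Project.toml fragment along these lines; the UUID placeholders and the extension names are illustrative, not Flux's actual layout:

[weakdeps]
DifferentiationInterface = "<uuid>"   # real package UUIDs would go here
Enzyme = "<uuid>"
Mooncake = "<uuid>"

[extensions]
# Each backend extension requires DI together with the backend package,
# so DI is only loaded when one of the backends is.
FluxEnzymeExt = ["DifferentiationInterface", "Enzyme"]
FluxMooncakeExt = ["DifferentiationInterface", "Mooncake"]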

@wsmoses
Contributor

wsmoses commented Dec 29, 2025

If you're thinking of extending the API to support more general ADTypes, I think it would likely be better to have the core function be something like update_with_gradient!(model, optimizer, ADType), which fuses the update and the gradient.

We can add a training_step!(loss, adtype, model, data, opt_state).

Sounds good to me. Somewhat relatedly, for that (and perhaps gradient) you might need an additional argument for either a Reactant-compiled thunk or a DI preparation.


@wsmoses
Contributor

wsmoses commented Dec 29, 2025

Concurrently (though separately): for more general things than Zygote, the error I mentioned here (#2600 (comment)) back in May about the structural equivalence check is still blocking (and was indeed what was blocking me, at least, from getting Flux to work for the past six months).

@CarloLucibello
Member Author

CarloLucibello commented Jan 2, 2026

@wsmoses do you know why the Enzyme tests are erroring on recurrent layers?
https://github.com/FluxML/Flux.jl/actions/runs/20655148072/job/59306597399?pr=2645#step:6:1264

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

@wsmoses do you know why the Enzyme tests are erroring on recurrent layers?

https://github.com/FluxML/Flux.jl/actions/runs/20655148072/job/59306597399?pr=2645#step:6:1264

It appears you made a differentiated function type unstable

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

@wsmoses do you know why the Enzyme tests are erroring on recurrent layers?

https://github.com/FluxML/Flux.jl/actions/runs/20655148072/job/59306597399?pr=2645#step:6:1264

Somewhat relatedly, as the changes to Enzyme here will equally change/break the use of Enzyme within Reactant, can we fix that before merging here? Otherwise Reactant CI (which does test Flux), as well as other users, may be broken by this PR.

@CarloLucibello
Member Author

It appears you made a differentiated function type unstable

I didn't change any function

@CarloLucibello
Member Author

Somewhat relatedly,

Since Reactant tests are still broken on master and in both #2600 and #2609, it shouldn't block this PR.

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

It appears you made a differentiated function type unstable

I didn't change any function

Well, CI does seem to disagree (it complains that the relevant function returns Any via inference).

src/gradient.jl Outdated
return _enzyme_withgradient(f, _make_duplicated(x); zero=true)
end
function withgradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}) where {F,N}
return _enzyme_withgradient(f, map(_make_duplicated, x)...; zero=true)
Member Author

@wsmoses the Enzyme tests work when using DI, but fail when introducing this specialized method.

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

It appears you made a differentiated function type unstable

I didn't change any function

Well, CI does seem to disagree (it complains that the relevant function returns Any via inference).

That does seem to be an anonymous function, and it does seem like you changed both the functions being differentiated and the handlers in test_utils, in addition to the Enzyme support utils.

@CarloLucibello
Member Author

CarloLucibello commented Jan 2, 2026

That does seem to be an anonymous function, and it does seem like you changed both the functions being differentiated and the handlers in test_utils, in addition to the Enzyme support utils.

Recurrent layers were already type unstable prior to this PR, and the loss was already a closure passed as the default argument, loss = (f, xs...) -> mean(f(xs...)). So I still don't understand what the relevant change is that broke things.

Another weird thing is that DI is actually more robust than the direct call to Enzyme we are using, since commenting out the Enzyme specializations

function withgradient(f::F, adtype::AutoEnzyme, x) where {F}
    return _enzyme_withgradient(f, _make_duplicated(x); zero=true)
end
function withgradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}) where {F,N}
    return _enzyme_withgradient(f, map(_make_duplicated, x)...; zero=true)
end

the error goes away.

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

Yeah, so that specific error indicates that type inference inside the method deduced a concrete type (Float64) whereas type inference at the call site found an abstract type (Any). Likely what's happening with the DI wrapper is that it's deducing Any for both (which Enzyme works perfectly well with). It's better to retain the type-stable version, though we ought to figure out where/why it's lost. In that particular section of the code (sorry, this is a bit internals-specific), at the moment we support the inner type being less known than the outer type (or equal, of course), but not the other way round. So something changed due to this PR via type inference, though the specifics are a bit unclear (I can try to take a look and debug this more deeply in a bit).
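
For reference, a generic way to check what inference concludes for the default loss closure used by test_gradients; this is a standalone illustration, not the PR's test code:

using Statistics

# The default loss in test_gradients is a closure like this one.
loss = (f, xs...) -> mean(f(xs...))

f = identity               # stand-in for a Flux layer
x = randn(Float32, 3)

# Ask inference what the closure returns for these argument types:
Base.return_types(loss, typeof.((f, x)))   # e.g. [Float32] if stable, [Any] if not

# For a line-by-line view of where stability is lost:
# @code_warntype loss(f, x)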

@wsmoses
Contributor

wsmoses commented Jan 2, 2026

Somewhat relatedly,

Since Reactant tests are still broken on master and in both #2600 and #2609, it shouldn't block this PR.

Actually, using the Fix1 suggestion I made in the PR, alongside the test_leaf fix here, does make it pass!

@CarloLucibello
Member Author

Yeah, so that specific error indicates that type inference inside the method deduced a concrete type (Float64) whereas type inference at the call site found an abstract type (Any). Likely what's happening with the DI wrapper is that it's deducing Any for both (which Enzyme works perfectly well with). It's better to retain the type-stable version, though we ought to figure out where/why it's lost. In that particular section of the code (sorry, this is a bit internals-specific), at the moment we support the inner type being less known than the outer type (or equal, of course), but not the other way round. So something changed due to this PR via type inference, though the specifics are a bit unclear (I can try to take a look and debug this more deeply in a bit).

I'll stick with DI then. If you want to experiment with the direct Enzyme call, just uncomment the AutoEnzyme specialization for withgradient.

@wsmoses
Contributor

wsmoses commented Jan 4, 2026

Yeah, so that specific error indicates that type inference inside the method deduced a concrete type (Float64) whereas type inference at the call site found an abstract type (Any). Likely what's happening with the DI wrapper is that it's deducing Any for both (which Enzyme works perfectly well with). It's better to retain the type-stable version, though we ought to figure out where/why it's lost. In that particular section of the code (sorry, this is a bit internals-specific), at the moment we support the inner type being less known than the outer type (or equal, of course), but not the other way round. So something changed due to this PR via type inference, though the specifics are a bit unclear (I can try to take a look and debug this more deeply in a bit).

I'll stick with DI then. If you want to experiment with the direct Enzyme call, just uncomment the AutoEnzyme specialization for withgradient.

Use of DI will break Reactant (specifically, DI's internal preparation has been shown to cause issues in this sort of setup, isn't considered supported from the Enzyme/Reactant side at the moment, and is likely to break without warning).

@CarloLucibello
Member Author

@wsmoses
Contributor

wsmoses commented Jan 4, 2026

You're using Reactant on DI, which, as I mentioned in my last comment, isn't supported =/

@CarloLucibello
Member Author

So something changed due to this PR via type inference, though the specifics are a bit unclear (I can try to take a look and debug this more deeply in a bit).

can you help then? I wouldn't know how to fix it

@wsmoses
Contributor

wsmoses commented Jan 4, 2026

If you want to move forward here without causing breakage before I can sort out the cause of the type-instability mismatch, I think your options are:

  • Fix the type instability introduced in this PR relative to main (I'm not quite sure why it's not inferred, but I presume the standard @code_typed and friends can help).
  • Explicitly don't use DI for Reactant compilation. I opened SciML/ADTypes.jl#140 (Introduce AutoDI wrapper type for DifferentiationInterface dispatch) earlier to distinguish DI uses of an adtype from non-DI versions, so once that's merged you can use AutoEnzyme from within Reactant and AutoDI{AutoEnzyme} for the other tests in the meantime.

@wsmoses
Contributor

wsmoses commented Jan 4, 2026

So something changed due to this PR via type inference, though the specifics are a bit unclear (I can try to take a look and debug this more deeply in a bit).

can you help then? I wouldn't know how to fix it

Yeah I can take a look but it probably won't be today

@gdalle

gdalle commented Jan 6, 2026

You're using Reactant on DI, which, as I mentioned in my last comment, isn't supported =/

See JuliaDiff/DifferentiationInterface.jl#918 if you want to keep track of the feature, which is reasonably high on my to-do list

@wsmoses
Contributor

wsmoses commented Jan 6, 2026

That PR is distinct: it has DI on the outside of a Reactant compile. What doesn't work here is a Reactant compile of DI.

@gdalle

gdalle commented Jan 6, 2026

That PR is distinct: it has DI on the outside of a Reactant compile. What doesn't work here is a Reactant compile of DI.

Sorry for the misunderstanding. Can you provide an MWE of Reactant failing to compile something with DI? I think I might know where a possible issue comes from, and if so it would be easy to fix

@CarloLucibello
Member Author

CarloLucibello commented Jan 7, 2026

Can you provide an MWE of Reactant failing to compile something with DI?

    using Flux, Enzyme, Reactant, Statistics, LinearAlgebra, Random, MLDataDevices
    import DifferentiationInterface as DI
    dev = reactant_device()
    m = LayerNorm(2)
    x = randn(Float32, 2)
    mr, xr = (m, x) |> dev

    loss(m, x...) = mean(m(x...))
    @jit loss(mr, xr)
    DI.gradient(args -> loss(args...), AutoEnzyme(), (m, x))
    @jit DI.gradient(args -> loss(args...), AutoEnzyme(), (mr, xr))  # this is the call that fails to compile

@gdalle

gdalle commented Jan 7, 2026

I'll take a look

@gdalle

gdalle commented Jan 9, 2026

I tried to make it work with native Enzyme+Reactant but I didn't succeed. @wsmoses am I holding the thing wrong?

using Flux, Enzyme, Reactant, Statistics, LinearAlgebra, Random, MLDataDevices

dev = reactant_device()
m = LayerNorm(2)
x = randn(Float32, 2)
mr, xr = (m, x) |> dev

loss(m, x::X) where {X} = mean(m(x))

loss(m, x)  # works
Enzyme.gradient(Enzyme.Reverse, splat(loss), (m, x))  # works
Enzyme.gradient(Enzyme.Reverse, loss, m, x)  # works

@jit loss(mr, xr)  # works
@jit Enzyme.gradient(Enzyme.Reverse, splat(loss), (mr, xr))  # fails
@jit Enzyme.gradient(Enzyme.Reverse, loss, mr, xr)  # fails

Error message:

julia> @jit Enzyme.gradient(Enzyme.Reverse, splat(loss), (mr, xr))
ERROR: MethodError: no method matching act_from_type(::Type{MixedDuplicated{Tuple{LayerNorm{}, Reactant.TracedRArray{}}}}, ::Bool, ::Bool)
The function `act_from_type` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  act_from_type(::Type{<:Active}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:203
  act_from_type(::Type{<:DuplicatedNoNeed}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:223
  act_from_type(::Type{<:Duplicated}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:210
  ...

Stacktrace:
  [1] act_from_type (repeats 2 times)
    @ ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:200 [inlined]
  [2] overload_autodiff(::ReverseMode{…}, f::Const{…}, ::Type{…}, args::MixedDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:335
  [3] autodiff(rmode::ReverseMode{…}, f::Const{…}, rt::Type{…}, args::MixedDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Overlay.jl:36
  [4] autodiff
    @ ~/.julia/packages/Enzyme/QsBMf/src/Enzyme.jl:542 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:327 [inlined]
  [6] gradient
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
  [7] (::Nothing)(none::typeof(Enzyme.gradient), none::ReverseMode{…}, none::Base.Splat{…}, none::Tuple{…}, none::Tuple{})
    @ Reactant ./<missing>:0
  [8] GenericMemory
    @ ./boot.jl:516 [inlined]
  [9] IdDict
    @ ./iddict.jl:31 [inlined]
 [10] IdDict
    @ ./iddict.jl:49 [inlined]
 [11] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/RpjpI/src/EnzymeCore.jl:587 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:324 [inlined]
 [13] gradient
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
 [14] call_with_reactant(::typeof(Enzyme.gradient), ::ReverseMode{…}, ::Base.Splat{…}, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/LcOwI/src/utils.jl:0
 [15] make_mlir_fn(f::typeof(Enzyme.gradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:348
 [16] make_mlir_fn
    @ ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:277 [inlined]
 [17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(Enzyme.gradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1733
 [18] compile_mlir!
    @ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1695 [inlined]
 [19] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3690
 [20] compile_xla
    @ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3662 [inlined]
 [21] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3766
 [22] top-level scope
    @ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:2835
Some type information was truncated. Use `show(err)` to see complete types.

julia> @jit Enzyme.gradient(Enzyme.Reverse, loss, mr, xr)
ERROR: MethodError: no method matching act_from_type(::Type{MixedDuplicated{LayerNorm{typeof(identity), Flux.Scale{}, Float32, 1}}}, ::Bool, ::Bool)
The function `act_from_type` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  act_from_type(::Type{<:Active}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:203
  act_from_type(::Type{<:DuplicatedNoNeed}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:223
  act_from_type(::Type{<:Duplicated}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:210
  ...

Stacktrace:
  [1] act_from_type (repeats 2 times)
    @ ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:200 [inlined]
  [2] overload_autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::MixedDuplicated{…}, ::Duplicated{…})
    @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Enzyme.jl:335
  [3] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::MixedDuplicated{…}, ::Duplicated{…})
    @ Reactant ~/.julia/packages/Reactant/LcOwI/src/Overlay.jl:36
  [4] autodiff
    @ ~/.julia/packages/Enzyme/QsBMf/src/Enzyme.jl:542 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:287 [inlined]
  [6] gradient
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
  [7] (::Nothing)(none::typeof(Enzyme.gradient), none::ReverseMode{…}, none::typeof(loss), none::LayerNorm{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
  [8] GenericMemory
    @ ./boot.jl:516 [inlined]
  [9] IdDict
    @ ./iddict.jl:31 [inlined]
 [10] IdDict
    @ ./iddict.jl:49 [inlined]
 [11] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/RpjpI/src/EnzymeCore.jl:587 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:324 [inlined]
 [13] gradient
    @ ~/.julia/packages/Enzyme/QsBMf/src/sugar.jl:274 [inlined]
 [14] call_with_reactant(::typeof(Enzyme.gradient), ::ReverseMode{…}, ::typeof(loss), ::LayerNorm{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/LcOwI/src/utils.jl:0
 [15] make_mlir_fn(f::typeof(Enzyme.gradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:348
 [16] make_mlir_fn
    @ ~/.julia/packages/Reactant/LcOwI/src/TracedUtils.jl:277 [inlined]
 [17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(Enzyme.gradient), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1733
 [18] compile_mlir!
    @ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:1695 [inlined]
 [19] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3690
 [20] compile_xla
    @ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3662 [inlined]
 [21] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:3766
 [22] top-level scope
    @ ~/.julia/packages/Reactant/LcOwI/src/Compiler.jl:2835
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Contributor

wsmoses commented Jan 9, 2026

Reactant doesn't support MixedDuplicated at the moment, which is why the custom Enzyme handling here (which will also be faster regardless) needs to be kept, instead of DI's or Enzyme's generic handler, which doesn't understand Flux's calling convention.
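
As a rough illustration (not the PR's actual helpers), this is the kind of explicit activity annotation meant here; it sidesteps the MixedDuplicated annotation that Enzyme.gradient chose in the traces above:

using Flux, Enzyme, Statistics

m = LayerNorm(2)
x = randn(Float32, 2)
loss(m, x) = mean(m(x))

dm = Enzyme.make_zero(m)   # shadow structure the gradient is accumulated into

# Annotating the model explicitly as Duplicated keeps the call off the
# MixedDuplicated path that Reactant cannot currently handle.
Enzyme.autodiff(Enzyme.Reverse, loss, Active, Duplicated(m, dm), Const(x))

dm  # now holds the gradient with respect to m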

@gdalle

gdalle commented Jan 13, 2026

Might want to rename the PR if DI is no longer used

@CarloLucibello CarloLucibello changed the title from "extend gradient through DifferentiationInterface" to "extend gradient to take an ADType argument" Jan 15, 2026
@gdalle

gdalle commented Jan 15, 2026

Out of curiosity, ignoring Enzyme, what's missing from DI for the purposes of Flux+Mooncake? Is the reasoning simply "there are only two backends you care about now so you might as well use them manually"? @CarloLucibello

@CarloLucibello
Member Author

Yes, that plus the fact that an extension was needed (for different reasons) for each backend in any case.

@CarloLucibello CarloLucibello merged commit 43b0a91 into master Jan 17, 2026
6 of 9 checks passed