Merged
7 changes: 1 addition & 6 deletions docs/src/reference/training/reference.md
@@ -18,7 +18,7 @@ The available optimization rules are listed the [optimisation rules](@ref man-op

```@docs
Flux.Train.setup
Flux.Train.train!(loss, model, data, state)
Flux.Train.train!
Optimisers.update
Optimisers.update!
Optimisers.setup
@@ -36,11 +36,6 @@ julia> opt_state = Flux.setup(Adam(0), model);
julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
```

```@docs
Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
```


## Optimisation Modifiers

The state returned by `setup` can be modified to temporarily prevent training of
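For orientation, the Enzyme example above can be run end to end roughly as follows; the toy `Dense` model and random data are assumptions for illustration, not taken from the documentation page:

```julia
using Flux, Enzyme

model = Dense(2 => 1)                                   # assumed toy model
x1, y1 = rand(Float32, 2, 8), rand(Float32, 1, 8)       # assumed toy data

dup_model = Duplicated(model, Enzyme.make_zero(model))  # shadow copy holds Enzyme's gradient
opt_state = Flux.setup(Adam(0), dup_model)              # Adam(0) as in the doctest above

# With a Duplicated model, train! computes gradients with Enzyme instead of Zygote
Flux.train!((m, x, y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
```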
82 changes: 31 additions & 51 deletions ext/FluxEnzymeExt.jl
@@ -1,14 +1,12 @@
module FluxEnzymeExt

using Flux
import Flux.Train: _enzyme_train!

import Optimisers
import Functors
import Enzyme
using Enzyme: EnzymeCore, EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
using ProgressLogging: @withprogress, @logprogress

EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true

@@ -28,13 +26,13 @@ _trymake_duplicated(x) = EnzymeCore.Duplicated(x, EnzymeCore.make_zero(x))


function _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
    for x in args
        zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
        _check_mutable(x)
    end
    ad = Enzyme.set_runtime_activity(Reverse)
    Enzyme.autodiff(ad, Const(f), Active, args...)
    return map(_grad_or_nothing, args)
end
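As a rough standalone illustration of what `_enzyme_gradient` does, here is a sketch that mirrors the body above using plain Enzyme calls; the loss `f`, model and input are assumed toy definitions:

```julia
using Flux, Enzyme

f(m, x) = sum(abs2, m(x))                     # assumed toy loss
m = Dense(3 => 2)
x = randn(Float32, 3, 4)

dup = Duplicated(m, Enzyme.make_zero(m))      # dval starts at zero, as remake_zero! ensures above
ad = Enzyme.set_runtime_activity(Reverse)     # same runtime-activity reverse mode as above
Enzyme.autodiff(ad, Const(f), Active, dup, Const(x))

dup.dval                                      # gradient w.r.t. the model, cf. _grad_or_nothing
```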

_check_mutable(x::Const) = nothing
@@ -48,30 +46,30 @@ _grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

function _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
    for x in args
        zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
        _check_mutable(x)
    end

    # In order to support auxiliary outputs, we try different ways.

    ## Take I, doesn't allow for aux at all.
    ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
    _, result = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, args...)

    ## Take II, using split mode.
    ## This fails with RNNs https://github.com/EnzymeAD/Enzyme.jl/issues/2897
    # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
    # tape, result, shadow_result = forward(Const(f), args...)
    # reverse(Const(f), args..., _sensitivity(result), tape)

    ## Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
    ## This doesn't work with Reactant
    # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
    # ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
    # _, result = autodiff(ad, Const(_ref_loss!), Const, dup_loss, Const(f), args...)

    return (; val = result, grad = map(_grad_or_nothing, args))
end
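For context, a sketch of how this helper is typically reached through the public API, mirroring the `Flux.withgradient` call made by `train!` in `src/train.jl` below; the toy loss, model and input are assumptions, and Enzyme must be loaded for this extension to be active:

```julia
using Flux, Enzyme

loss(m, x) = sum(abs2, m(x))       # assumed toy loss
model = Dense(3 => 2)
x = randn(Float32, 3, 4)

# The (; val, grad) named tuple returned above destructures into loss and gradients:
l, gs = Flux.withgradient(m -> loss(m, x), AutoEnzyme(), model)
gs[1]                              # gradient with respect to model
```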

## for Take II above
@@ -94,22 +92,4 @@ end
# or else a Tuple or NamedTuple whose first element is a real number.""")


### Flux.Train, for train!

function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
@withprogress for (i,d) in enumerate(data)
d_splat = d isa Tuple ? d : (d,)
l, gs = Flux.withgradient(loss, AutoEnzyme(), model, map(Const, d_splat)...)
if !isfinite(l)
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
end
opt, model2 = Optimisers.update!(opt, model.val, model.dval)
model = Duplicated(model2, model.dval)

@logprogress Base.haslength(data) ? i/length(data) : nothing
end
end

end # FluxEnzymeExt
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -26,6 +26,7 @@ using Zygote.ForwardDiff: value
using EnzymeCore: EnzymeCore

@reexport using ADTypes # AutoZygote, AutoMooncake, etc...
using ADTypes: AbstractADType

@reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
default_device_rng,
3 changes: 1 addition & 2 deletions src/gradient.jl
@@ -48,13 +48,12 @@ julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
([2.0, 2.0, 2.0],)
```
"""
function gradient(f, adtype::ADTypes.AbstractADType, args...)
function gradient(f, adtype::AbstractADType, args...)
error("AD backend has to be loaded to use `gradient(f, AutoXXX(), args...)`.
Make sure to `using` the corresponding package, e.g. `using Mooncake` for `AutoMooncake()`.
Supported backends are $SUPPORTED_AD_BACKENDS.")
end


# Default gradient using Zygote
function gradient(f, args...; zero::Bool=true)
for a in args
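A brief sketch of the behaviour this fallback documents: the plain call keeps using Zygote, while a backend selector errors until the matching package is loaded (the toy loss here is an assumption):

```julia
using Flux

f(x) = sum(abs2, x)                   # assumed toy loss

Flux.gradient(f, [1.0, 2.0, 3.0])     # default Zygote path: ([2.0, 4.0, 6.0],)

# Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
# ^ hits the informative error above unless `using Mooncake` has been run first
```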
87 changes: 40 additions & 47 deletions src/train.jl
@@ -6,14 +6,11 @@ using Functors: fmap, fmapstructure
using ..Flux: Flux

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote
using EnzymeCore: Duplicated
using ADTypes: AbstractADType, AutoEnzyme, AutoZygote

export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote
using EnzymeCore: Duplicated

"""
opt_state = setup(rule, model)

Expand Down Expand Up @@ -49,7 +46,7 @@ function setup(rule::Optimisers.AbstractRule, model)
state = Optimisers.setup(rule, model)
# This check only needs foreach; using fmap caused https://github.com/FluxML/Flux.jl/issues/2144
fmapstructure(model, exclude = Optimisers.isnumeric) do x
Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
end
return state
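To illustrate the escape hatch mentioned in this error message, a hypothetical sketch: `MyWeights` is an invented wrapper type that really is safe to update in place, so its owner opts in by extending `Optimisers.maywrite`:

```julia
using Optimisers

struct MyWeights{T} <: AbstractVector{T}   # hypothetical custom parameter type
    data::Vector{T}
end
Base.size(w::MyWeights) = size(w.data)
Base.getindex(w::MyWeights, i::Int) = w.data[i]
Base.setindex!(w::MyWeights, v, i::Int) = (w.data[i] = v)

# `w .+= dw` really is OK for this type, so declare it writable as the message suggests:
Optimisers.maywrite(::MyWeights) = true
```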
@@ -63,15 +60,17 @@ Special method for use with Enzyme.jl, ignores the stored gradient.
setup(rule::Optimisers.AbstractRule, model::Duplicated) = setup(rule, model.val)
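A small sketch of this special case, assuming a toy model with Enzyme loaded: the optimiser state is built from the wrapped value, so it matches what `setup` would return for the bare model.

```julia
using Flux, Enzyme

m = Dense(2 => 1)
dup = Duplicated(m, Enzyme.make_zero(m))

Flux.setup(Adam(0.01), dup)   # same state tree as Flux.setup(Adam(0.01), m)
```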

"""
train!(loss, model, data, opt_state)
train!(loss, [adtype,] model, data, opt_state)

Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule encoded in `opt_state`.

Iterates through `data` once, evaluating for each `d in data` either
`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.

If `model` is an Enzyme.Duplicated and `Enzyme.jl` is loaded, gradients will be computed with Enzyme,
otherwise they will be computed with Zygote.
The optional argument `adtype` selects an automatic differentiation engine among the ones supported by
[`gradient`](@ref). If no `adtype` is given, Zygote is used by default, unless `model` is of type `Duplicated` from Enzyme.jl,
in which case Enzyme is used.

For example, with these definitions...
```
@@ -108,62 +107,56 @@ It adds only a few features to the loop above:
* Callback functions are not supported.
(But any code can be included in the above `for` loop.)
"""
function train!(loss, model, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
@withprogress for (i,d) in enumerate(data)
d_splat = d isa Tuple ? d : (d,)
function train!(loss, adtype::AbstractADType, model, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
@withprogress for (i,d) in enumerate(data)
d_splat = d isa Tuple ? d : (d,)

l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
l, gs = Flux.withgradient(m -> loss(m, d_splat...), adtype, model)

if !isfinite(l)
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
if !isfinite(l)
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
end

opt, model = _update!(opt, model, gs[1])

@logprogress Base.haslength(data) ? i/length(data) : nothing
end
end

opt, model = Optimisers.update!(opt, model, gs[1])
_update!(opt_state, model, grads) = Optimisers.update!(opt_state, model, grads)

@logprogress Base.haslength(data) ? i/length(data) : nothing
end
function _update!(opt_state, model::Duplicated, grad)
opt_state, model2 = Optimisers.update!(opt_state, model.val, grad)
return opt_state, Duplicated(model2, model.dval)
end


train!(loss, model, data, opt; cb = nothing) = train!(loss, AutoZygote(), model, data, opt; cb)

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
return train!(loss, model, data, _rule_to_state(model, rule); cb)
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
state = setup(rule, model)
@gensym warn_id
name = typeof(rule).name.name
fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
leaf
end
state
state = setup(rule, model)
@gensym warn_id
name = typeof(rule).name.name
fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
leaf
end
return state
end

"""
train!(loss, Duplicated(model), data, opt_state)

This method uses Enzyme.jl instead of Zygote.jl to compute the gradients,
but is otherwise the same as `train!(loss, model, data, opt_state)`.

Only available when Enzyme is loaded.

!!! compat "New"
This method was added in Flux 0.13.9.

"""
train!(loss, model::Duplicated, data, opt; cb = nothing) = _enzyme_train!(loss, model, data, opt; cb = nothing)

# FluxEnzymeExt defines more specific _enzyme_train!(loss, model::Duplicated, data, opt; cb)
_enzyme_train!(loss, model, data, opt; cb = nothing) = throw(ArgumentError("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded"))
train!(loss, model::Duplicated, data, opt; cb = nothing) = train!(loss, AutoEnzyme(), model, data, opt; cb)
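In other words, the `Duplicated` method is now a thin forwarder onto the `adtype` method; roughly, assuming `loss`, `model`, `data` and `opt_state` exist and Enzyme is loaded:

```julia
dup = Duplicated(model, Enzyme.make_zero(model))

# These two calls take the same code path:
Flux.train!(loss, dup, data, opt_state)
Flux.train!(loss, AutoEnzyme(), dup, data, opt_state)
```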

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb=nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
return train!(loss, model, data, _rule_to_state(model, rule); cb)
end

end # module Train