
Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule `opt`. Iterates through `data` once,
- evaluating `loss(model, d...)` for each `d` in data.
+ evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`,
+ or else `loss(model, d)` for other `d`.

For example, with these definitions...
```
- data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple
+ data = [(x1, y1), (x2, y2), (x3, y3)]

loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument

```
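With the wording added above, `data` no longer has to yield tuples: a non-tuple element `d` is passed whole, as `loss(model, d)`. A minimal sketch of that case, on a Flux version that includes this change (the names `loss2`, `model`, `opt_state` and `data_vec` are illustrative, not from the original):

```julia
using Flux, LinearAlgebra

model = Dense(4 => 2)
opt_state = Flux.setup(Adam(), model)       # optimiser state, as the docstring requires

loss2(m, x) = norm(m(x))                    # model first, then one (non-tuple) data argument

data_vec = [rand(Float32, 4) for _ in 1:3]  # each d is an Array, not a Tuple
Flux.train!(loss2, model, data_vec, opt_state)  # calls loss2(model, d) for each d
```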
You can also write this loop yourself, if you need more flexibility.
For this reason `train!` is not highly extensible.
- It adds only a few featurs to the loop above:
+ It adds only a few features to the loop above:

* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.

@@ -88,9 +89,6 @@ It adds only a few featurs to the loop above:
  (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
* Instead of `loss` being a function which accepts only the data,
  now it must also accept the `model` itself, as the first argument.
- * `data` must iterate tuples, otherwise you get an error.
-   (Previously non-tuple types were not splatted into the loss.
-   Pass in `((d,) for d in data)` to simulate this.)
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
  such as `Adam()` without this step should give you a warning.
* Callback functions are not supported.
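As the docstring says, you can write this loop yourself when you need more flexibility. A minimal hand-rolled equivalent, sketched under stated assumptions (a small `model`, the `loss3` from the example above, and tuple `data`); it is not the exact library implementation, which uses `Zygote.withgradient` as shown in the hunks below:

```julia
using Flux, LinearAlgebra

model = Dense(2 => 1)
loss3(m, x, y) = norm(m(x) .- y)                 # as in the docstring example
data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:3]

opt_state = Flux.setup(Adam(), model)            # explicit optimiser state

for (i, d) in enumerate(data)
    l = loss3(model, d...)                       # evaluated separately here, for the finiteness check
    isfinite(l) || throw(DomainError("Loss is $l on data item $i"))  # mirrors the DomainError rule
    gs = Flux.gradient(m -> loss3(m, d...), model)
    Flux.update!(opt_state, model, gs[1])        # apply the optimisation rule
end
```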
@@ -100,9 +98,8 @@ function train!(loss, model, data, opt; cb = nothing)
  isnothing(cb) || error("""train! does not support callback functions.
                            For more control use a loop with `gradient` and `update!`.""")
  @withprogress for (i,d) in enumerate(data)
-   d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
-                           Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
-   l, gs = Zygote.withgradient(m -> loss(m, d...), model)
+   d_splat = d isa Tuple ? d : (d,)
+   l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
    if !isfinite(l)
      throw(DomainError("Loss is $l on data item $i, stopping training"))
    end
@@ -112,8 +109,8 @@ function train!(loss, model, data, opt; cb = nothing)
end

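One practical consequence of the non-tuple handling above: a `DataLoader` built from a single array yields plain arrays, so it can now be passed straight to `train!`. A sketch only, reusing the hypothetical `loss2`, `model`, and `opt_state` from the earlier example:

```julia
# Each item from this loader is a 4×10 Matrix, not a Tuple,
# so train! calls loss2(model, d) rather than splatting.
loader = Flux.DataLoader(rand(Float32, 4, 100); batchsize = 10)
Flux.train!(loss2, model, loader, opt_state)
```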
# This method lets you use Optimisers.Descent() without setup, when there is no state
- function train!(loss, model, data, rule::Optimisers.AbstractRule)
-   train!(loss, model, data, _rule_to_state(model, rule))
+ function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
+   train!(loss, model, data, _rule_to_state(model, rule); cb)
end
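A hedged usage sketch of this convenience method, reusing the hypothetical `loss3`, `model`, and `data` from the sketch above: the stateless rule is converted internally via `_rule_to_state`, and the new `cb` keyword is forwarded, so passing a callback still raises the explanatory error shown earlier.

```julia
import Optimisers  # assumes Optimisers.jl is available in your environment (it is a dependency of Flux)

# Pass a stateless rule directly, skipping Flux.setup.
Flux.train!(loss3, model, data, Optimisers.Descent(0.1))
```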

function _rule_to_state(model, rule::Optimisers.AbstractRule)