@@ -9,6 +9,8 @@ Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
99to a vector, and returns also a function which reverses this transformation.
1010Differentiable.
1111
12+ See also [`destructure!`](@ref).
13+
1214# Example
1315```jldoctest
1416julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
@@ -31,6 +33,36 @@ function destructure(x)
3133 flat, Restructure (x, off, len)
3234end
3335
36+ """
37+ destructure!(model) -> vector, reconstructor
38+
39+ This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
40+ Requires that all trainable parameters in the model be mutable arrays!
41+
42+ # Example
43+ ```jldoctest
44+ julia> m = (x=[1.0, 2.0], y=(sin, Float32[3.0 4.0], cos))
45+
46+ julia> v, re! = destructure!(m)
47+ ([1.0, 2.0, 3.0, 4.0], Restructure!(NamedTuple, ..., 4))
48+
49+ julia> m === re!([3, 5, 7, 9]) # mutates the original m, and returns it
50+ true
51+
52+ julia> m
53+ (x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
54+ ```
55+ """
56+ function destructure! (x)
57+ flat, off, len = _flatten (x)
58+ flat, Restructure! (x, off, len)
59+ end
60+
61+ # function destructure!(flat::AbstractVector, x)
62+ # flat, off, len = _flatten!(flat, x)
63+ # flat, Restructure!(x, off, len)
64+ # end
65+
3466"""
3567 Restructure(Model, ..., length)
3668
@@ -55,12 +87,20 @@ struct Restructure{T,S}
5587 model:: T
5688 offsets:: S
5789 length:: Int
90+ mutate:: Bool
5891end
59- (re:: Restructure )(flat:: AbstractVector ) = _rebuild (re. model, re. offsets, flat, re. length)
92+ Restructure (model, offsets, length) = Restructure (model, offsets, length, false )
93+ Restructure! (model, offsets, length) = Restructure (model, offsets, length, true )
94+
95+ (re:: Restructure )(flat:: AbstractVector ) = re. mutate ? _rebuild! (re. model, re. offsets, flat, re. length) : _rebuild (re. model, re. offsets, flat, re. length)
6096(re:: Restructure )(x, flat:: AbstractVector ) = re (flat)(x)
61- Base. show (io:: IO , re:: Restructure{T} ) where T = print (io, " Restructure(" , T. name. name, " , ..., " , re. length, " )" )
6297Base. length (re:: Restructure ) = re. length
6398
99+ function Base. show (io:: IO , re:: Restructure{T} ) where T
100+ print (io, " Restructure" , re. mutate ? " !" : " " )
101+ print (io, " (" , T. name. name, " , ..., " , re. length, " )" )
102+ end
103+
64104# This flattens a model, and returns a web of offsets for later use:
65105function _flatten (x)
66106 isnumeric (x) && return vcat (_vec (x)), 0 , length (x) # trivial case
@@ -75,6 +115,17 @@ function _flatten(x)
75115 isempty (arrays) && return Bool[], off, 0
76116 reduce (vcat, arrays), off, len[]
77117end
118+ # function _flatten!(flat, x)
119+ # isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case
120+ # len = Ref(0)
121+ # off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
122+ # o = len[]
123+ # copyto!(flat, o, _vec(y))
124+ # len[] = o + length(y)
125+ # o
126+ # end
127+ # flat, off, len[]
128+ # end
78129
79130struct _TrainableStructWalk <: AbstractWalk end
80131
@@ -97,10 +148,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trai
97148 _getat (y, o, flat)
98149 end
99150end
151+ # (mutating version, same arguments & same return)
152+ function _rebuild! (x, off, flat:: AbstractVector , len = length (flat); walk = _Trainable_biwalk (), kw... )
153+ len == length (flat) || throw (DimensionMismatch (" Rebuild expected a vector of length $len , got $(length (flat)) " ))
154+ fmap (x, off; exclude = isnumeric, walk, kw... ) do y, o
155+ copyto! (y, _getat (y, o, flat, view))
156+ end
157+ x
158+ end
100159
101- _getat (y:: Number , o:: Int , flat:: AbstractVector ) = ProjectTo (y)(flat[o + 1 ])
102- _getat (y:: AbstractArray , o:: Int , flat:: AbstractVector ) =
103- ProjectTo (y)(reshape (flat[ o .+ (1 : length (y))] , axes (y))) # ProjectTo is just correcting eltypes
160+ _getat (y:: Number , o:: Int , flat:: AbstractVector , _ ... ) = ProjectTo (y)(flat[o + 1 ])
161+ _getat (y:: AbstractArray , o:: Int , flat:: AbstractVector , get = getindex ) =
162+ ProjectTo (y)(reshape (get ( flat, o .+ (1 : length (y))) , axes (y))) # ProjectTo is just correcting eltypes
104163
105164struct _Trainable_biwalk <: AbstractWalk end
106165
@@ -135,6 +194,10 @@ function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
135194 _rebuild_back (dx) = (NoT, NoT, NoT, _grad! (x, unthunk (dx), off, _zero (flat)), NoT)
136195 _rebuild (x, off, flat, len; kw... ), _rebuild_back
137196end
197+ function ChainRulesCore. rrule (:: typeof (_rebuild!), x, off, flat, len; kw... )
198+ _rebuild!_back (dx) = (NoT, NoT, NoT, _grad! (x, unthunk (dx), off, _zero (flat)), NoT)
199+ _rebuild! (x, off, flat, len; kw... ), _rebuild!_back
200+ end
138201
139202_zero (x) = map! (zero, similar (x, float (eltype (x))), x) # mutable zero array for _grad!
140203ChainRulesCore. @non_differentiable _zero (x)
0 commit comments