
Commit b5e5d72

Make normalization more AD friendly (#148)
1 parent 2c529ed commit b5e5d72

File tree

4 files changed (+66 lines, -52 lines)

Project.toml
src/autodiff.jl
src/layers/normalize.jl
src/utils.jl

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "0.4.20"
+version = "0.4.21"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/autodiff.jl

Lines changed: 22 additions & 8 deletions
@@ -10,17 +10,23 @@ ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
 ChainRulesCore.@non_differentiable check_use_cuda()
 ChainRulesCore.@non_differentiable istraining(::Any)
 ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
+ChainRulesCore.@non_differentiable _affine(::Any)
+ChainRulesCore.@non_differentiable _track_stats(::Any)
+ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)

 # NNlib Functions
-function ChainRulesCore.rrule(::typeof(batchnorm), g::CuArray{T}, b::CuArray{T},
-                              x::Union{CuArray{T, 4}, CuArray{T, 5}}, running_mean,
-                              running_var, momentum; kwargs...) where {T <: CUDNNFloat}
-    y = batchnorm(g, b, x, running_mean, running_var, momentum; kwargs...)
-    function batchnorm_pullback(dy)
-        dg, db, dx = ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kwargs...)
-        return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent()
+function ChainRulesCore.rrule(::typeof(_batchnorm), g::CuArray{T}, b::CuArray{T},
+                              x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}},
+                              running_mean, running_var, momentum, epsilon,
+                              training) where {T <: CUDNNFloat}
+    y = _batchnorm(g, b, x, running_mean, running_var, momentum, epsilon, training)
+    function _batchnorm_pullback(dy)
+        dg, db, dx = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum;
+                                eps=epsilon, training=training)
+        return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent(), NoTangent(),
+               NoTangent()
     end
-    return y, batchnorm_pullback
+    return y, _batchnorm_pullback
 end

 function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractArray{T, N},
@@ -59,6 +65,9 @@ function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
         dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
         return (NoTangent(), dnt1, dnt2)
     end
+    function merge_pullback(dy::Union{NoTangent, ZeroTangent})
+        return (NoTangent(), NoTangent(), NoTangent())
+    end
     return y, merge_pullback
 end
@@ -89,6 +98,11 @@ function ChainRulesCore.rrule(::typeof(collect), v::Vector)
     return y, collect_pullback
 end

+function ChainRulesCore.rrule(::typeof(copy), x)
+    copy_pullback(dy) = (NoTangent(), dy)
+    return copy(x), copy_pullback
+end
+
 # Zygote Fixes
 function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
     return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
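
The pattern in this file is: configuration helpers (_affine, _track_stats, _copy_autodiff_barrier) are declared @non_differentiable so Zygote treats them as constants, and the keyword-based batchnorm kernel is wrapped in a positional _batchnorm so a single rrule can hand back NoTangent() for every non-array argument. Below is a minimal standalone sketch of that rrule shape, using hypothetical names (_use_bias, _scale_shift) and assuming all arrays share one shape; it is not Lux code.

using ChainRulesCore

# Hypothetical configuration query; @non_differentiable keeps AD from tracing it.
_use_bias(flag::Bool) = flag
ChainRulesCore.@non_differentiable _use_bias(::Any)

# Hypothetical positional wrapper, mirroring how _batchnorm wraps batchnorm above.
_scale_shift(w, b, x, eps) = w .* x ./ (1 + eps) .+ b

function ChainRulesCore.rrule(::typeof(_scale_shift), w, b, x, eps)
    y = _scale_shift(w, b, x, eps)
    function _scale_shift_pullback(dy)
        dy = unthunk(dy)               # materialize a possibly thunked cotangent
        dw = dy .* x ./ (1 + eps)      # gradients assume w, b, x all have the same shape
        db = dy
        dx = dy .* w ./ (1 + eps)
        # One tangent per positional argument; eps is a hyperparameter, so NoTangent().
        return NoTangent(), dw, db, dx, NoTangent()
    end
    return y, _scale_shift_pullback
end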

src/layers/normalize.jl

Lines changed: 40 additions & 43 deletions
@@ -1,5 +1,8 @@
 abstract type AbstractNormalizationLayer{affine, track_stats} <: AbstractExplicitLayer end

+@inline _affine(l::AbstractNormalizationLayer{A, T}) where {A, T} = A
+@inline _track_stats(l::AbstractNormalizationLayer{A, T}) where {A, T} = T
+
 """
     BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32,
               affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0)
@@ -93,28 +96,25 @@ function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, init_scale=
                      chs, init_bias, init_scale)
 end

-function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, l::BatchNorm)
+    if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
         return (scale=nothing, bias=nothing)
     end
 end

-function initialstates(rng::AbstractRNG,
-                       l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    return if track_stats
-        (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
-         training=Val(true))
+function initialstates(rng::AbstractRNG, l::BatchNorm)
+    if _track_stats(l)
+        return (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
+                training=Val(true))
     else
-        (running_mean=nothing, running_var=nothing, training=Val(true))
+        return (running_mean=nothing, running_var=nothing, training=Val(true))
     end
 end

-parameterlength(l::BatchNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
-function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    return (track_stats ? 2 * l.chs : 0) + 1
-end
+parameterlength(l::BatchNorm) = _affine(l) ? (l.chs * 2) : 0
+statelength(l::BatchNorm) = (_track_stats(l) ? 2 * l.chs : 0) + 1

 function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
@@ -127,42 +127,42 @@ function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
     return x_normalized, st
 end

-function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, 4},
-                                                        CuArray{T, 5}}, ps,
-                                              st::NamedTuple) where {
-                                                                     T <:
-                                                                     Union{Float32, Float64
-                                                                           }, affine,
-                                                                     track_stats}
+function _batchnorm(scale, bias, x, running_mean, running_var, momentum, epsilon, training)
+    return batchnorm(scale, bias, x, running_mean, running_var, momentum; eps=epsilon,
+                     training=training)
+end
+
+function (BN::BatchNorm)(x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}}, ps,
+                         st::NamedTuple) where {T <: Union{Float32, Float64}}
     # NNlibCUDA silently updates running_mean and running_var so copying them
     if istraining(st)
-        running_mean2 = track_stats ? copy(st.running_mean) : nothing
-        running_var2 = track_stats ? copy(st.running_var) : nothing
+        running_mean2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_mean) : nothing
+        running_var2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_var) : nothing
     else
-        if track_stats
-            running_mean2 = copy(st.running_mean)
-            running_var2 = copy(st.running_var)
+        if _track_stats(BN)
+            running_mean2 = _copy_autodiff_barrier(st.running_mean)
+            running_var2 = _copy_autodiff_barrier(st.running_var)
         else
             N = ndims(x)
             reduce_dims = collect([1:(N - 2); N])
             running_mean2 = mean(x; dims=reduce_dims)
            running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false)
         end
     end
-    res = BN.activation.(batchnorm(affine ? ps.scale : nothing, affine ? ps.bias : nothing,
-                                   x, running_mean2, running_var2, BN.momentum;
-                                   eps=BN.epsilon, training=istraining(st)))
-    if track_stats
+    res = BN.activation.(_batchnorm(_affine(BN) ? ps.scale : nothing,
+                                    _affine(BN) ? ps.bias : nothing, x, running_mean2,
+                                    running_var2, BN.momentum, BN.epsilon, istraining(st)))
+    if _track_stats(BN)
         st = merge(st, (running_mean=running_mean2, running_var=running_var2))
     end
     return res, st
 end

-function Base.show(io::IO, l::BatchNorm{affine, track_stats}) where {affine, track_stats}
+function Base.show(io::IO, l::BatchNorm)
     print(io, "BatchNorm($(l.chs)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine)")
-    print(io, ", track_stats=$(track_stats)")
+    print(io, ", affine=$(_affine(l))")
+    print(io, ", track_stats=$(_track_stats(l))")
     return print(io, ")")
 end
@@ -281,29 +281,26 @@ function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias
                      groups)
 end

-function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, l::GroupNorm)
+    if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
         return (scale=nothing, bias=nothing)
     end
 end

-function initialstates(rng::AbstractRNG,
-                       l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    return if track_stats
+function initialstates(rng::AbstractRNG, l::GroupNorm)
+    return if _track_stats(l)
         (running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups),
          training=Val(true))
     else
         (running_mean=nothing, running_var=nothing, training=Val(true))
     end
 end

-parameterlength(l::GroupNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
+parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0

-function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    return (track_stats ? 2 * l.groups : 0) + 1
-end
+statelength(l::GroupNorm) = (_track_stats(l) ? 2 * l.groups : 0) + 1

 function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     sz = size(x)
@@ -318,11 +315,11 @@ function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
     return reshape(x_normalized, sz), st
 end

-function Base.show(io::IO, l::GroupNorm{affine, track_stats}) where {affine, track_stats}
+function Base.show(io::IO, l::GroupNorm)
     print(io, "GroupNorm($(l.chs), $(l.groups)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine)")
-    print(io, ", track_stats=$(track_stats)")
+    print(io, ", affine=$(_affine(l))")
+    print(io, ", track_stats=$(_track_stats(l))")
     return print(io, ")")
 end
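
These changes swap dispatch on the {affine, track_stats} type parameters for the _affine/_track_stats accessors, so method signatures no longer carry where {affine, track_stats} clauses that AD has to specialize on. A rough self-contained sketch of that idiom, using a made-up DemoNorm type rather than the real layers:

# Parametric layer carrying compile-time flags, as AbstractNormalizationLayer does.
abstract type AbstractDemoLayer{affine, track_stats} end

struct DemoNorm{affine, track_stats} <: AbstractDemoLayer{affine, track_stats}
    chs::Int
end

# Accessors recover the flags from the type; in Lux these are additionally marked
# @non_differentiable so they stay out of the AD graph.
@inline _affine(::AbstractDemoLayer{A, T}) where {A, T} = A
@inline _track_stats(::AbstractDemoLayer{A, T}) where {A, T} = T

# Methods can now take plain ::DemoNorm arguments instead of parameter-laden signatures.
parameterlength(l::DemoNorm) = _affine(l) ? 2 * l.chs : 0
statelength(l::DemoNorm) = (_track_stats(l) ? 2 * l.chs : 0) + 1

l = DemoNorm{true, false}(8)
parameterlength(l)   # 16  (scale + bias)
statelength(l)       # 1   (just the training flag)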

src/utils.jl

Lines changed: 3 additions & 0 deletions
@@ -189,3 +189,6 @@ Split up `x` into `N` equally sized chunks (along dimension `1`).

 # Val utilities
 get_known(::Val{T}) where {T} = T
+
+# Copy and don't allow gradient propagation
+_copy_autodiff_barrier(x) = copy(x)
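
Combined with the @non_differentiable _copy_autodiff_barrier(::Any) rule added in src/autodiff.jl, this helper copies the running statistics in the forward pass while blocking any gradient from flowing into them. A small usage sketch (the loss function below is illustrative only, not part of the package):

using ChainRulesCore, Zygote

# Same definition as in src/utils.jl plus the rule from src/autodiff.jl.
_copy_autodiff_barrier(x) = copy(x)
ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)

# The copied statistics participate in the forward computation...
loss(x, stats) = sum(x .- _copy_autodiff_barrier(stats))

dx, dstats = Zygote.gradient(loss, ones(3), ones(3))
# ...but receive no tangent: dx == [1.0, 1.0, 1.0] and dstats === nothing.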
