 abstract type AbstractNormalizationLayer{affine, track_stats} <: AbstractExplicitLayer end
 
+@inline _affine(l::AbstractNormalizationLayer{A, T}) where {A, T} = A
+@inline _track_stats(l::AbstractNormalizationLayer{A, T}) where {A, T} = T
+
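The two `@inline` accessors added above read the `affine` and `track_stats` flags straight out of the type parameters, which is what lets every method below dispatch on plain `BatchNorm`/`GroupNorm` instead of carrying `where {affine, track_stats}` clauses. A quick REPL sketch of the intended behavior, using the `BatchNorm` constructor documented below:

    julia> bn = BatchNorm(4)                 # affine=true, track_stats=true by default
    julia> _affine(bn), _track_stats(bn)
    (true, true)

    julia> _affine(BatchNorm(4; affine=false))
    false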
36"""
47 BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32,
58 affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0)
@@ -93,28 +96,25 @@ function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, init_scale=
                                        chs, init_bias, init_scale)
 end
 
-function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, l::BatchNorm)
+    if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
         return (scale=nothing, bias=nothing)
     end
 end
 
-function initialstates(rng::AbstractRNG,
-                       l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    return if track_stats
-        (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
-         training=Val(true))
+function initialstates(rng::AbstractRNG, l::BatchNorm)
+    if _track_stats(l)
+        return (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
+                training=Val(true))
     else
-        (running_mean=nothing, running_var=nothing, training=Val(true))
+        return (running_mean=nothing, running_var=nothing, training=Val(true))
     end
 end
 
-parameterlength(l::BatchNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
-function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    return (track_stats ? 2 * l.chs : 0) + 1
-end
+parameterlength(l::BatchNorm) = _affine(l) ? (l.chs * 2) : 0
+statelength(l::BatchNorm) = (_track_stats(l) ? 2 * l.chs : 0) + 1
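With the refactor, parameter and state bookkeeping is unchanged in substance; only the dispatch is simpler. For example (a sketch, assuming the initializers are called directly with a stdlib RNG):

    using Random
    rng = Random.default_rng()
    bn  = BatchNorm(8)
    ps  = initialparameters(rng, bn)   # (scale=ones32(rng, 8), bias=zeros32(rng, 8))
    st  = initialstates(rng, bn)       # running_mean, running_var, training=Val(true)
    parameterlength(bn)                # 16 == 2 * chs, since _affine(bn) is true
    statelength(bn)                    # 17 == 2 * chs + 1 (the +1 is the training flag)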
 function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
@@ -127,42 +127,42 @@ function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
     return x_normalized, st
 end
 
-function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, 4},
-                                                       CuArray{T, 5}}, ps,
-                                              st::NamedTuple) where {
-                                                                     T <: Union{Float32, Float64},
-                                                                     affine, track_stats}
+# Forwards to NNlibCUDA's `batchnorm` with every argument positional
+function _batchnorm(scale, bias, x, running_mean, running_var, momentum, epsilon, training)
+    return batchnorm(scale, bias, x, running_mean, running_var, momentum; eps=epsilon,
+                     training=training)
+end
+
+function (BN::BatchNorm)(x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}}, ps,
+                         st::NamedTuple) where {T <: Union{Float32, Float64}}
     # NNlibCUDA silently updates running_mean and running_var, so we copy them first
     if istraining(st)
-        running_mean2 = track_stats ? copy(st.running_mean) : nothing
-        running_var2 = track_stats ? copy(st.running_var) : nothing
+        running_mean2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_mean) : nothing
+        running_var2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_var) : nothing
     else
-        if track_stats
-            running_mean2 = copy(st.running_mean)
-            running_var2 = copy(st.running_var)
+        if _track_stats(BN)
+            running_mean2 = _copy_autodiff_barrier(st.running_mean)
+            running_var2 = _copy_autodiff_barrier(st.running_var)
         else
             N = ndims(x)
             reduce_dims = collect([1:(N - 2); N])
             running_mean2 = mean(x; dims=reduce_dims)
             running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false)
         end
     end
-    res = BN.activation.(batchnorm(affine ? ps.scale : nothing, affine ? ps.bias : nothing,
-                                   x, running_mean2, running_var2, BN.momentum;
-                                   eps=BN.epsilon, training=istraining(st)))
-    if track_stats
+    res = BN.activation.(_batchnorm(_affine(BN) ? ps.scale : nothing,
+                                    _affine(BN) ? ps.bias : nothing, x, running_mean2,
+                                    running_var2, BN.momentum, BN.epsilon, istraining(st)))
+    if _track_stats(BN)
         st = merge(st, (running_mean=running_mean2, running_var=running_var2))
     end
     return res, st
 end
 
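`_copy_autodiff_barrier` is introduced here in place of the old bare `copy` calls but is not defined in this hunk. The name suggests a copy wrapped in an AD barrier, so that the mutation NNlibCUDA performs on the running statistics stays out of the differentiated program. A hypothetical one-liner in that spirit, assuming `ChainRulesCore` is available:

    using ChainRulesCore: ignore_derivatives

    # Hypothetical sketch: copy the array outside of the AD tape
    _copy_autodiff_barrier(x) = copy(ignore_derivatives(x))
    _copy_autodiff_barrier(::Nothing) = nothing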
-function Base.show(io::IO, l::BatchNorm{affine, track_stats}) where {affine, track_stats}
+function Base.show(io::IO, l::BatchNorm)
     print(io, "BatchNorm($(l.chs)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine)")
-    print(io, ", track_stats=$(track_stats)")
+    print(io, ", affine=$(_affine(l))")
+    print(io, ", track_stats=$(_track_stats(l))")
     return print(io, ")")
 end
 
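Since `_affine` and `_track_stats` recover the same values the old type-parameter destructuring did, the printed form is unchanged; `BatchNorm(4, relu)` still displays as:

    BatchNorm(4, relu, affine=true, track_stats=true)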
@@ -281,29 +281,26 @@ function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias
                        groups)
 end
 
-function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, l::GroupNorm)
+    if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
         return (scale=nothing, bias=nothing)
     end
 end
 
-function initialstates(rng::AbstractRNG,
-                       l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    return if track_stats
+function initialstates(rng::AbstractRNG, l::GroupNorm)
+    return if _track_stats(l)
         (running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups),
          training=Val(true))
     else
         (running_mean=nothing, running_var=nothing, training=Val(true))
     end
 end
 
-parameterlength(l::GroupNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
+parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
 
-function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    return (track_stats ? 2 * l.groups : 0) + 1
-end
+statelength(l::GroupNorm) = (_track_stats(l) ? 2 * l.groups : 0) + 1
 function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     sz = size(x)
@@ -318,11 +315,11 @@ function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
     return reshape(x_normalized, sz), st
 end
 
-function Base.show(io::IO, l::GroupNorm{affine, track_stats}) where {affine, track_stats}
+function Base.show(io::IO, l::GroupNorm)
     print(io, "GroupNorm($(l.chs), $(l.groups)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine)")
-    print(io, ", track_stats=$(track_stats)")
+    print(io, ", affine=$(_affine(l))")
+    print(io, ", track_stats=$(_track_stats(l))")
     return print(io, ")")
 end
 
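The `GroupNorm` changes mirror the `BatchNorm` ones, with one wrinkle worth noting: scale and bias remain per channel, while the running statistics are per group, which is why `statelength` counts `l.groups` rather than `l.chs`. A quick sanity check (a sketch, using the constructor from this hunk):

    gn = GroupNorm(6, 3)      # 6 channels split into 3 groups
    parameterlength(gn)       # 12 == 2 * chs
    statelength(gn)           # 7  == 2 * groups + 1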