diff --git a/Project.toml b/Project.toml
index 026eeb17..b33954ac 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.9.28"
+version = "0.9.29"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/impl/conv_im2col.jl b/src/impl/conv_im2col.jl
index 8c497954..1faf0d97 100644
--- a/src/impl/conv_im2col.jl
+++ b/src/impl/conv_im2col.jl
@@ -8,6 +8,50 @@ function kernel_index(w, h, d, cdims::ConvDims)
     return (kernel_w - w + 1, kernel_h - h + 1, kernel_d - d + 1)
 end
 
+# Internal. Pool of `col` scratch arrays shared by the im2col functions below.
+# `_col_memo` maps a key (callers use `(T, dims...)`) to a vector of idle arrays; every
+# access must hold `_col_memo_lock`. Both bindings are `const` so the compiler knows
+# their types and they cannot be rebound by accident.
+const _col_memo_lock = ReentrantLock()
+const _col_memo = Dict()
+
+# Internal. Take an idle array stored under `key`, or evaluate `default` if there is none.
+# `type` is the element type of the pool created the first time a key is seen.
+# The caller gets exclusive use of the returned array and should hand it back with
+# `@_return_col_to_memo` once they are finished.
+macro _memo_col(key, type, default)
+    return quote
+        cached = lock(_col_memo_lock) do
+            pool = get!(() -> $(esc(type))[], _col_memo, $(esc(key)))
+            isempty(pool) ? nothing : pop!(pool)
+        end
+        # Allocate the fallback outside the lock so other tasks are not blocked on it.
+        cached === nothing ? $(esc(default)) : cached
+    end
+end
+
+# Internal. Hand `val` back to the pool stored under `key` so a later call can reuse it.
+macro _return_col_to_memo(key, type, val)
+    return quote
+        lock(_col_memo_lock) do
+            # `get!` recreates the pool if the memo was emptied in the meantime.
+            push!(get!(() -> $(esc(type))[], _col_memo, $(esc(key))), $(esc(val)))
+        end
+    end
+end
+
+"""
+    free_scratchspace_memo!()
+
+Empties the memo holding arrays used for scratch space. Thread safe.
+"""
+function free_scratchspace_memo!()
+    lock(_col_memo_lock) do
+        empty!(_col_memo)
+    end
+    return nothing
+end
+
 """
     conv_im2col!(y, x, w, cdims, col=similar(x); alpha=1, beta=0)
 
@@ -19,15 +63,28 @@ by setting `alpha` to a nonunitary value, various gain factors can be applied.
 
 Note for the particularly performance-minded, you can provide a pre-allocated `col`,
 which should eliminate any need for large allocations within this method.
+By default, `col` will be memoized to reduce allocations between calls.
+The memo can be emptied at any time using [`free_scratchspace_memo!`](@ref).
+A `col` passed in by the caller is used as is and never stored in the memo.
 """
 function conv_im2col!(
     y::AbstractArray{T,5}, x::AbstractArray{T,5},
     w::AbstractArray{T,5}, cdims::DenseConvDims;
-    col::AbstractArray{T,3}=similar(x, im2col_dims(cdims)),
+    col::Union{AbstractArray{T,3},Nothing}=nothing,
     alpha::T=T(1), beta::T=T(0),
     ntasks::Int=nthreads()) where {T}
     check_dims(size(x), size(w), size(y), cdims)
 
+    # Memoize col to reduce allocations. We get exclusive use of the returned col.
+    # Only a col taken from the memo is handed back afterwards: a caller-provided col
+    # stays owned by the caller and must never be given to another call.
+    dims = im2col_dims(cdims)
+    key = (T, dims...)
+    col_from_memo = col === nothing
+    if col_from_memo
+        col = @_memo_col key AbstractArray{T,3} similar(x, dims)
+    end
+
     # COL * W -> Y
     # [M x K] * [K x N] -> [M x N]
     #
@@ -61,6 +118,12 @@ function conv_im2col!(
             end
         end
     end
+
+    # Return col to the memo so another function can use it, unless the caller owns it.
+    if col_from_memo
+        @_return_col_to_memo key AbstractArray{T,3} col
+    end
+
     return y
 end
 
@@ -74,10 +137,20 @@ See [`conv_im2col!`](@ref) for explanation of optional parameters.
 function ∇conv_filter_im2col!(
     dw::AbstractArray{T,5}, x::AbstractArray{T,5},
     dy::AbstractArray{T,5}, cdims::DenseConvDims;
-    col::AbstractArray{T,3} = similar(dw, ∇filter_im2col_dims(cdims)),
+    col::Union{AbstractArray{T,3},Nothing}=nothing,
     alpha::T=T(1), beta::T=T(0)) where {T}
     check_dims(size(x), size(dw), size(dy), cdims)
 
+    # Memoize col to reduce allocations. We get exclusive use of the returned col.
+    # Only a col taken from the memo is handed back afterwards: a caller-provided col
+    # stays owned by the caller and must never be given to another call.
+    dims = ∇filter_im2col_dims(cdims)
+    key = (T, dims...)
+    col_from_memo = col === nothing
+    if col_from_memo
+        col = @_memo_col key AbstractArray{T,3} similar(dw, dims)
+    end
+
     # COL' * dY -> dW
     # [M x K] * [K x N] -> [M x N]
     #
@@ -114,6 +187,12 @@ function ∇conv_filter_im2col!(
         # to `1.0` from this point on.
         beta = T(1)
     end
+
+    # Return col to the memo so another function can use it, unless the caller owns it.
+    if col_from_memo
+        @_return_col_to_memo key AbstractArray{T,3} col
+    end
+
     return dw
 end
 
diff --git a/test/conv.jl b/test/conv.jl
index badb2c5f..182cfc76 100644
--- a/test/conv.jl
+++ b/test/conv.jl
@@ -997,3 +997,24 @@ end
 end
 
 end
+
+@testset "free_scratchspace_memo" begin
+    NNlib._col_memo[1] = 2
+    @test !isempty(NNlib._col_memo)
+    NNlib.free_scratchspace_memo!()
+    @test isempty(NNlib._col_memo)
+end
+
+@testset "conv_im2col! with caller-provided col" begin
+    x = rand(Float32, 5, 5, 1, 1, 1)
+    w = rand(Float32, 2, 2, 1, 1, 1)
+    cdims = DenseConvDims(x, w)
+    y = zeros(Float32, NNlib.output_size(cdims)..., 1, 1)
+    col = similar(x, NNlib.im2col_dims(cdims))
+    NNlib.free_scratchspace_memo!()
+    # Passing `col` explicitly must work, and must not put the caller's array in the memo.
+    NNlib.conv_im2col!(y, x, w, cdims; col=col)
+    @test isempty(NNlib._col_memo)
+    @test y ≈ NNlib.conv_im2col!(zero(y), x, w, cdims)
+    NNlib.free_scratchspace_memo!()
+end