From 8967126e835c8baa86b5a626b5ba33581a28cbc3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 7 Feb 2020 08:07:47 -0500 Subject: [PATCH 1/2] fix localmem and add dynamic mem support --- src/KernelAbstractions.jl | 28 +++++++++++++++---- src/backends/cpu.jl | 10 +++++++ src/backends/cuda.jl | 26 +++++++++++++++++- src/macros.jl | 42 +++++++++++++++++++++++++++- test/localmem.jl | 58 +++++++++++++++++++++++++++++++++++++++ test/macros.jl | 26 ++++++++++++++++++ test/runtests.jl | 8 ++++++ 7 files changed, 190 insertions(+), 8 deletions(-) create mode 100644 test/localmem.jl create mode 100644 test/macros.jl diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 650585b00..d20abb008 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -79,16 +79,26 @@ Query the workgroupsize on the device. """ function groupsize end -const shmem_id = Ref(0) - """ @localmem T dims """ macro localmem(T, dims) - id = (shmem_id[]+= 1) + # Stay in sync with CUDAnative + id = gensym("static_shmem") + quote + $SharedMemory($(esc(T)), Val($(esc(dims))), Val($(QuoteNode(id)))) + end +end +""" + @dynamic_localmem T N + @dynamic_localmem T (workgroupsize)->N +""" +macro dynamic_localmem(T, N, previous...) + # Stay in sync with CUDAnative + id = gensym("dynamic_shmem") quote - $SharedMemory($(esc(T)), Val($(esc(dims))), Val($id)) + $DynamicSharedMemory($(esc(T)), $(esc(N)), Val($(QuoteNode(id))), $(map(esc, previous))...) end end @@ -196,6 +206,8 @@ struct Kernel{Device, WorkgroupSize<:_Size, NDRange<:_Size, Fun} f::Fun end +function allocator end + workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize ndrange(::Kernel{D, WorkgroupSize, NDRange}) where {D, WorkgroupSize,NDRange} = NDRange @@ -281,11 +293,15 @@ include("macros.jl") ### function Scratchpad(::Type{T}, ::Val{Dims}) where {T, Dims} - throw(MethodError(ScratchArray, (T, Val(Dims)))) + throw(MethodError(Scratchpad, (T, Val(Dims)))) end function SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id} - throw(MethodError(ScratchArray, (T, Val(Dims), Val(Id)))) + throw(MethodError(SharedMemory, (T, Val(Dims), Val(Id)))) +end + +function DynamicSharedMemory(::Type{T}, N, ::Val{Id}, previous...) where {T, Id} + throw(MethodError(DynamicSharedMemory, (T, N, Val(id), previous))) end function __synchronize() diff --git a/src/backends/cpu.jl b/src/backends/cpu.jl index c5b674eb6..c5178e3b5 100644 --- a/src/backends/cpu.jl +++ b/src/backends/cpu.jl @@ -108,6 +108,16 @@ generate_overdubs(CPUCtx) MArray{__size(Dims), T}(undef) end +@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(DynamicSharedMemory), ::Type{T}, alloc, ::Val, previous...) 
where T + if alloc isa Function + N = alloc(__groupsize(ctx.metadata)) + else + N = alloc + end + @assert N isa Int + Vector{T}(undef, N) +end + ### # CPU implementation of scratch memory # - private memory for each workitem diff --git a/src/backends/cuda.jl b/src/backends/cuda.jl index bd2a20670..834163f80 100644 --- a/src/backends/cuda.jl +++ b/src/backends/cuda.jl @@ -203,7 +203,31 @@ end ### @inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id} ptr = CUDAnative._shmem(Val(Id), T, Val(prod(Dims))) - CUDAnative.CuDeviceArray(Dims, CUDAnative.DevicePtr{T, CUDAnative.AS.Shared}(ptr)) + CUDAnative.CuDeviceArray(Dims, ptr) +end + +@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(DynamicSharedMemory), ::Type{T}, alloc, ::Val{Id}, previous::Vararg{Any, N}) where {T, Id, N} + ptr = CUDAnative._shmem(Val(Id), T, Val(0)) + nthreads = __gpu_groupsize(ctx.metadata) + offset = 0 + # ah yes this is indeed a mapreduce + ntuple(Val(N)) do I + Base.@_inline_meta + Tvar, _alloc = @inbounds previous[I] + if _alloc isa Function + _size = _alloc(nthreads)::Int + else + _size = _alloc::Int + end + offset += sizeof(Tvar) * _size + end + if alloc isa Function + size = alloc(nthreads)::Int + else + size = alloc::Int + end + offset = offset::Int + CUDAnative.CuDeviceArray((size,), ptr + offset) end ### diff --git a/src/macros.jl b/src/macros.jl index e70df3915..6d3416adc 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -32,10 +32,12 @@ function __kernel(expr) # 2. CPU function with work-group loops inserted gpu_name = esc(gensym(Symbol(:gpu_, name))) cpu_name = esc(gensym(Symbol(:cpu_, name))) + alloc_name = esc(gensym(Symbol(:alloc_, name))) name = esc(name) gpu_decl = Expr(:call, gpu_name, arglist...) cpu_decl = Expr(:call, cpu_name, arglist...) + alloc_decl = Expr(:call, alloc_name, arglist...) # Without the deepcopy we might accidentially modify expr shared between CPU and GPU gpu_body = transform_gpu(deepcopy(body), args) @@ -44,6 +46,9 @@ function __kernel(expr) cpu_body = transform_cpu(deepcopy(body), args) cpu_function = Expr(:function, cpu_decl, cpu_body) + dynamic_allocator_body = transform_dynamic(body, args) + dynamic_allocator_function = Expr(:function, alloc_decl, dynamic_allocator_body) + # create constructor functions constructors = quote $name(dev::Device) = $name(dev, $DynamicSize(), $DynamicSize()) @@ -55,9 +60,11 @@ function __kernel(expr) function $name(::Device, ::S, ::NDRange) where {Device<:$GPU, S<:$_Size, NDRange<:$_Size} return $Kernel{Device, S, NDRange, typeof($gpu_name)}($gpu_name) end + KernelAbstractions.allocator(::typeof($gpu_name)) = $alloc_name + KernelAbstractions.allocator(::typeof($cpu_name)) = $alloc_name end - return Expr(:toplevel, cpu_function, gpu_function, constructors) + return Expr(:toplevel, cpu_function, gpu_function, dynamic_allocator_function, constructors) end # Transform function for GPU execution @@ -83,6 +90,7 @@ function split(stmts) # 2. 
Aggregate the index and allocation expressions seen at the sync points indicies = Any[] allocations = Any[] + dynamic_allocs = Any[] loops = Any[] current = Any[] @@ -102,6 +110,15 @@ function split(stmts) elseif callee === Symbol("@localmem") || callee === Symbol("@private") push!(allocations, stmt) continue + elseif callee === Symbol("@dynamic_localmem") + # args[2] LineNumberNode + Tvar = rhs.args[3] + size = rhs.args[4] + # add all previous dynamic allocations + append!(rhs.args, dynamic_allocs) + push!(allocations, stmt) + push!(dynamic_allocs, Expr(:tuple, Tvar, size)) + continue end end end @@ -161,3 +178,26 @@ function transform_cpu(stmts, args) push!(new_stmts, :(return nothing)) return Expr(:block, new_stmts...) end + +function transform_dynamic(stmts, args) + # 1. Find dynamic allocations + allocators = Expr[] + for stmt in stmts.args + if isexpr(stmt, :(=)) + rhs = stmt.args[2] + if isexpr(rhs, :macrocall) + callee = rhs.args[1] + if callee === Symbol("@dynamic_localmem") + # args[2] LineNumberNode + Tvar = rhs.args[3] + size = rhs.args[4] + push!(allocators, Expr(:tuple, Expr(:call, :sizeof, Tvar), size)) + continue + end + end + end + end + return Expr(:block, Expr(:tuple, allocators...)) +end + + diff --git a/test/localmem.jl b/test/localmem.jl new file mode 100644 index 000000000..19e471d7c --- /dev/null +++ b/test/localmem.jl @@ -0,0 +1,58 @@ +using KernelAbstractions +using Test +using CUDAapi +if has_cuda_gpu() + using CuArrays + CuArrays.allowscalar(false) +end + +@kernel function dynamic(A) + I = @index(Global, Linear) + i = @index(Local, Linear) + lmem = @dynamic_localmem eltype(A) (wkrgpsize)->2*wkrgpsize + lmem[2*i] = A[I] + @synchronize + A[I] = lmem[2*i] +end + +@kernel function dynamic2(A) + I = @index(Global, Linear) + i = @index(Local, Linear) + lmem = @dynamic_localmem eltype(A) (wkrgpsize)->2*wkrgpsize + lmem2 = @dynamic_localmem Int (wkrgpsize)->wkrgpsize + lmem2[i] = i + lmem[2*i] = A[I] + @synchronize + A[I] = lmem[2*lmem2[i]] +end + +@kernel function dynamic_mixed(A) + I = @index(Global, Linear) + i = @index(Local, Linear) + lmem = @dynamic_localmem eltype(A) (wkrgpsize)->2*wkrgpsize + lmem2 = @localmem Int groupsize() # Ok iff groupsize is static + lmem2[i] = i + lmem[2*i] = A[I] + @synchronize + A[I] = lmem[2*lmem2[i]] +end + + +function harness(backend, ArrayT) + A = ArrayT{Float64}(undef, 16, 16) + A .= 1.0 + wait(dynamic(backend, 16)(A, ndrange=size(A))) + + A = ArrayT{Float64}(undef, 16, 16) + wait(dynamic2(backend, 16)(A, ndrange=size(A))) + + A = ArrayT{Float64}(undef, 16, 16) + wait(dynamic_mixed(backend, 16)(A, ndrange=size(A))) +end + +@testset "kernels" begin + harness(CPU(), Array) + if has_cuda_gpu() + harness(CUDA(), CuArray) + end +end diff --git a/test/macros.jl b/test/macros.jl new file mode 100644 index 000000000..6eb0bdb35 --- /dev/null +++ b/test/macros.jl @@ -0,0 +1,26 @@ +using KernelAbstractions +using Test + +@kernel function nodynamic(A) + I = @index(Global, Linear) + A[I] = I +end + +@kernel function dynamic(A) + I = @index(Global, Linear) + i = @index(Local, Linear) + lmem = @dynamic_localmem eltype(A) (wkrgpsize)->2*wkrgpsize + lmem[2*i] = A[I] + @synchronize + A[I] = lmem[2*i] +end + +nodyn_kernel = nodynamic(CPU()) +dyn_kernel = dynamic(CPU()) + +@test KernelAbstractions.allocator(nodyn_kernel.f)(zeros(32, 32)) == () +let allocs = KernelAbstractions.allocator(dyn_kernel.f)(zeros(32, 32)) + bytes, alloc = allocs[1] + @test bytes == sizeof(Float64) + @test alloc(32) == 64 +end diff --git a/test/runtests.jl b/test/runtests.jl 
index 80790a34a..b2a123f0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,4 +5,12 @@ using Test include("test.jl") end +@testset "Macro" begin + include("macros.jl") +end + +@testset "Localmem" begin + include("localmem.jl") +end + include("examples.jl") From 4e214f40158b3919ffdaf5a0d247b354625a523f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 13 Feb 2020 09:18:36 -0500 Subject: [PATCH 2/2] WIP changes --- src/backends/cuda.jl | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/backends/cuda.jl b/src/backends/cuda.jl index 834163f80..b5bb84311 100644 --- a/src/backends/cuda.jl +++ b/src/backends/cuda.jl @@ -206,28 +206,21 @@ end CUDAnative.CuDeviceArray(Dims, ptr) end +@inline get_offset(threads) = 0 +@inline get_offset(threads, elem::Tuple{Int, Int}) = elem[1] * elem[2] +@inline get_offset(threads, elem::Tuple{Int, <:Function}) = elem[1] * elem[2](threads) +@inline get_offset(threads, elem, args...) = get_offset(threads, elem) + get_offset(threads, args...) + @inline function Cassette.overdub(ctx::CUDACtx, ::typeof(DynamicSharedMemory), ::Type{T}, alloc, ::Val{Id}, previous::Vararg{Any, N}) where {T, Id, N} ptr = CUDAnative._shmem(Val(Id), T, Val(0)) nthreads = __gpu_groupsize(ctx.metadata) - offset = 0 - # ah yes this is indeed a mapreduce - ntuple(Val(N)) do I - Base.@_inline_meta - Tvar, _alloc = @inbounds previous[I] - if _alloc isa Function - _size = _alloc(nthreads)::Int - else - _size = _alloc::Int - end - offset += sizeof(Tvar) * _size - end + offset = get_offset(nthreads, previous...) if alloc isa Function - size = alloc(nthreads)::Int + size = alloc(nthreads) else - size = alloc::Int + size = alloc end - offset = offset::Int - CUDAnative.CuDeviceArray((size,), ptr + offset) + CUDAnative.CuDeviceArray{T}((size,), ptr + offset) end ###
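
Notes on the dynamic shared-memory layout introduced above.

The `get_offset` helpers in the second patch compute where each dynamic allocation starts inside the single dynamic shared-memory segment: every earlier `@dynamic_localmem` is described by a `(sizeof(T), alloc)` tuple, where `alloc` is either a fixed element count or a function of the workgroup size. A standalone sketch of that computation in plain Julia, without CUDAnative (`offset_bytes` is an illustrative name, not part of the patch):

offset_bytes(threads) = 0
offset_bytes(threads, elem::Tuple{Int,Int}) = elem[1] * elem[2]
offset_bytes(threads, elem::Tuple{Int,<:Function}) = elem[1] * elem[2](threads)
offset_bytes(threads, elem, rest...) =
    offset_bytes(threads, elem) + offset_bytes(threads, rest...)

# Byte offset a third allocation would start at, after a Float64 buffer of
# 2*threads elements and an Int buffer of `threads` elements (the layout of
# the `dynamic2` test kernel), for a workgroup of 32 threads.
previous = ((sizeof(Float64), wkgrp -> 2wkgrp), (sizeof(Int), wkgrp -> wkgrp))
@assert offset_bytes(32, previous...) == sizeof(Float64) * 64 + sizeof(Int) * 32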
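The `@kernel` macro now also emits a per-kernel allocator function (`transform_dynamic` in src/macros.jl), reachable through `KernelAbstractions.allocator`, which returns one `(sizeof(T), alloc)` tuple per `@dynamic_localmem` in the kernel body; this is what test/macros.jl checks. The patches do not yet wire this into the launch path (the second patch is marked WIP), but a host-side launcher could use it to size the dynamic shared-memory request along these lines; `dynamic_shmem_bytes` is a hypothetical helper, not KernelAbstractions API:

using KernelAbstractions

# Total dynamic shared memory (in bytes) a kernel would need for a given
# workgroup size, computed from the generated allocator. `args` are the
# kernel's own arguments, since the element types (and hence the byte sizes)
# may depend on them.
function dynamic_shmem_bytes(kernel_fun, workgroupsize::Int, args...)
    total = 0
    for (bytes_per_elem, alloc) in KernelAbstractions.allocator(kernel_fun)(args...)
        nelem = alloc isa Function ? alloc(workgroupsize) : alloc
        total += bytes_per_elem * nelem
    end
    return total
end

For the `dynamic` kernel in test/macros.jl this gives `dynamic_shmem_bytes(dyn_kernel.f, 32, zeros(32, 32)) == sizeof(Float64) * 64`, matching the values asserted there.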
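On the CPU backend a dynamic allocation is materialized as an ordinary `Vector{T}` per workgroup; the only work is resolving the element count, which may be given as a literal or as a function of the workgroup size. A minimal standalone sketch of that resolution, with the group size passed explicitly in place of `__groupsize(ctx.metadata)`:

# `cpu_dynamic_localmem` is an illustrative stand-in for the CPU overdub of
# DynamicSharedMemory: resolve the element count, then allocate a plain Vector.
function cpu_dynamic_localmem(::Type{T}, alloc, groupsize::Int) where {T}
    N = alloc isa Function ? alloc(groupsize) : alloc
    @assert N isa Int
    return Vector{T}(undef, N)
end

buf = cpu_dynamic_localmem(Float64, wkgrp -> 2wkgrp, 16)
@assert length(buf) == 32   # 2 * groupsize elements, as in the test kernels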