Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 6 additions & 24 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices)
end

struct WorkgroupLoop
indices::Vector{Any}
stmts::Vector{Any}
allocations::Vector{Any}
private_allocations::Vector{Any}
private::Set{Symbol}
terminated_in_sync::Bool
end

Expand All @@ -111,26 +108,18 @@ function find_sync(stmt)
end

# TODO proper handling of LineInfo
function split(
stmts,
indices = Any[], private = Set{Symbol}(),
)
function split(stmts)
# 1. Split the code into blocks separated by `@synchronize`
# 2. Aggregate `@index` expressions
# 3. Hoist allocations
# 4. Hoist uniforms

current = Any[]
allocations = Any[]
private_allocations = Any[]
new_stmts = Any[]
for stmt in stmts
has_sync = find_sync(stmt)
if has_sync
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
loop = WorkgroupLoop(current, allocations, is_sync(stmt))
push!(new_stmts, emit(loop))
allocations = Any[]
private_allocations = Any[]
current = Any[]
is_sync(stmt) && continue

Expand All @@ -142,7 +131,7 @@ function split(
function recurse(expr::Expr)
expr = unblock(expr)
if is_scope_construct(expr) && any(find_sync, expr.args)
new_args = unblock(split(expr.args, deepcopy(indices), deepcopy(private)))
new_args = unblock(split(expr.args))
return Expr(expr.head, new_args...)
else
return Expr(expr.head, map(recurse, expr.args)...)
Expand All @@ -156,14 +145,10 @@ function split(
push!(allocations, stmt)
continue
elseif @capture(stmt, @private lhs_ = rhs_)
push!(private, lhs)
push!(private_allocations, :($lhs = $rhs))
push!(allocations, :($lhs = $rhs))
continue
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
if @capture(rhs, @index(args__))
push!(indices, stmt)
continue
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
if @capture(rhs, @localmem(args__) | @uniform(args__))
push!(allocations, stmt)
continue
elseif @capture(rhs, @private(T_, dims_))
Expand All @@ -175,7 +160,6 @@ function split(
end
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
push!(allocations, :($lhs = $alloc))
push!(private, lhs)
continue
end
end
Expand All @@ -185,7 +169,7 @@ function split(

# everything since the last `@synchronize`
if !isempty(current)
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), false)
loop = WorkgroupLoop(current, allocations, false)
push!(new_stmts, emit(loop))
end
return new_stmts
Expand All @@ -197,9 +181,7 @@ function emit(loop)
body = Expr(:block, loop.stmts...)
loopexpr = quote
$(loop.allocations...)
$(loop.private_allocations...)
if __active_lane__
$(loop.indices...)
$(unblock(body))
end
end
Expand Down
Loading