@@ -86,22 +86,24 @@ function transform_gpu!(def, constargs, force_inbounds)
8686 end
8787 end
8888 pushfirst! (def[:args ], :__ctx__ )
89- body = def[:body ]
89+ new_stmts = Expr[]
90+ body = MacroTools. flatten (def[:body ])
91+ stmts = body. args
92+ push! (new_stmts, Expr (:aliasscope ))
93+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
9094 if force_inbounds
91- body = quote
92- @inbounds $ (body)
93- end
95+ push! (new_stmts, Expr (:inbounds , true ))
9496 end
95- body = quote
96- if $ __validindex (__ctx__)
97- $ (body)
98- end
99- return nothing
97+ append! (new_stmts, split (emit_gpu, body. args))
98+ if force_inbounds
99+ push! (new_stmts, Expr (:inbounds , :pop ))
100100 end
101+ push! (new_stmts, Expr (:popaliasscope ))
102+ push! (new_stmts, :(return nothing ))
101103 def[:body ] = Expr (
102104 :let ,
103105 Expr (:block , let_constargs... ),
104- body ,
106+ Expr ( :block , new_stmts ... ) ,
105107 )
106108 return
107109end
@@ -127,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
127129 if force_inbounds
128130 push! (new_stmts, Expr (:inbounds , true ))
129131 end
130- append! (new_stmts, split (body. args))
132+ append! (new_stmts, split (emit_cpu, body. args))
131133 if force_inbounds
132134 push! (new_stmts, Expr (:inbounds , :pop ))
133135 end
167169
168170# TODO proper handling of LineInfo
169171function split (
172+ emit,
170173 stmts,
171174 indicies = Any[], private = Set {Symbol} (),
172175 )
@@ -197,7 +200,7 @@ function split(
197200 function recurse (expr:: Expr )
198201 expr = unblock (expr)
199202 if is_scope_construct (expr) && any (find_sync, expr. args)
200- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
203+ new_args = unblock (split (emit, expr. args, deepcopy (indicies), deepcopy (private)))
201204 return Expr (expr. head, new_args... )
202205 else
203206 return Expr (expr. head, map (recurse, expr. args)... )
@@ -246,7 +249,7 @@ function split(
246249 return new_stmts
247250end
248251
249- function emit (loop)
252+ function emit_cpu (loop)
250253 idx = gensym (:I )
251254 for stmt in loop. indicies
252255 # splice index into the i = @index(Cartesian, $idx)
@@ -300,3 +303,38 @@ function emit(loop)
300303
301304 return unblock (Expr (:block , stmts... ))
302305end
306+
# emit_gpu(loop): build the expression emitted for one `loop` record on the GPU path.
# Reads loop.allocations, loop.private_allocations, loop.stmts, loop.private and
# loop.indicies, and returns a single expression: unblock(Expr(:block, stmts...)).
# It is passed as the `emit` argument of `split` by transform_gpu!.
# NOTE(review): @capture / postwalk / unblock are presumably the MacroTools helpers
# (the file calls MacroTools.flatten elsewhere) -- confirm against the imports.
307+ function emit_gpu (loop)
308+ stmts = Any[]
# The entries of loop.allocations are copied into the output first, unchanged.
309+ append! (stmts, loop. allocations)
# Each private allocation must match `lhs = rhs`; it is re-emitted as a plain
# assignment. Anything else aborts with an error.
310+ for stmt in loop. private_allocations
311+ if @capture (stmt, lhs_ = rhs_)
312+ push! (stmts, :($ lhs = $ rhs))
313+ else
314+ error (" @private $stmt not an assignment" )
315+ end
316+ end
317+
# A region whose statements are empty, or consist only of LineNumberNodes,
# produces no guarded block and no __synchronize call at all.
318+ # don't emit empty loops
319+ if ! (isempty (loop. stmts) || all (s -> s isa LineNumberNode, loop. stmts))
320+ body = Expr (:block , loop. stmts... )
# Validation-only walk: every visited expression is returned unchanged; the walk
# exists to raise an error when an assignment's left-hand side is in loop.private.
321+ body = postwalk (body) do expr
322+ if @capture (expr, lhs_ = rhs_)
323+ if lhs in loop. private
324+ error (" Can't assign to variables marked private" )
325+ end
326+ end
327+ return expr
328+ end
# The index statements and the body are emitted inside `if __active_lane__`
# (transform_gpu! assigns __active_lane__ from __validindex(__ctx__)).
# The __synchronize() call is emitted after the `if`, not inside it.
# NOTE(review): presumably so that every lane reaches the synchronization point
# even when __active_lane__ is false -- confirm against __synchronize's definition.
329+ loopexpr = quote
330+ if __active_lane__
331+ $ (loop. indicies... )
332+ $ (unblock (body))
333+ end
334+ $ __synchronize ()
335+ end
336+ push! (stmts, loopexpr)
337+ end
338+
# Wrap everything collected above in one block expression and unblock it.
339+ return unblock (Expr (:block , stmts... ))
340+ end
0 commit comments