@@ -58,22 +58,105 @@ function transform_gpu!(def, constargs, force_inbounds)
5858 end
5959 end
6060 pushfirst! (def[:args ], :__ctx__ )
61- body = def[:body ]
61+ new_stmts = Expr[]
62+ body = MacroTools. flatten (def[:body ])
63+ stmts = body. args
64+ push! (new_stmts, Expr (:aliasscope ))
65+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
6266 if force_inbounds
63- body = quote
64- @inbounds $ (body)
65- end
67+ push! (new_stmts, Expr (:inbounds , true ))
6668 end
67- body = quote
68- if $ __validindex (__ctx__)
69- $ (body)
70- end
71- return nothing
69+ append! (new_stmts, split (emit_gpu, body. args))
70+ if force_inbounds
71+ push! (new_stmts, Expr (:inbounds , :pop ))
7272 end
73+ push! (new_stmts, Expr (:popaliasscope ))
74+ push! (new_stmts, :(return nothing ))
7375 def[:body ] = Expr (
7476 :let ,
7577 Expr (:block , let_constargs... ),
76- body ,
78+ Expr ( :block , new_stmts ... ) ,
7779 )
7880 return
7981end
82+
83+ struct WorkgroupLoop
84+ stmts:: Vector{Any}
85+ terminated_in_sync:: Bool
86+ end
87+
88+ is_sync (expr) = @capture (expr, @synchronize () | @synchronize (a_))
89+
90+ function is_scope_construct (expr:: Expr )
91+ return expr. head === :block # ||
92+ # expr.head === :let
93+ end
94+
95+ function find_sync (stmt)
96+ result = false
97+ postwalk (stmt) do expr
98+ result |= is_sync (expr)
99+ expr
100+ end
101+ return result
102+ end
103+
104+ # TODO proper handling of LineInfo
105+ function split (emit, stmts)
106+ # 1. Split the code into blocks separated by `@synchronize`
107+
108+ current = Any[]
109+ new_stmts = Any[]
110+ for stmt in stmts
111+ has_sync = find_sync (stmt)
112+ if has_sync
113+ loop = WorkgroupLoop (current, is_sync (stmt))
114+ push! (new_stmts, emit (loop))
115+ current = Any[]
116+ is_sync (stmt) && continue
117+
118+ # Recurse into scope constructs
119+ # TODO : This currently implements hard scoping
120+ # probably need to implemet soft scoping
121+ # by not deepcopying the environment.
122+ recurse (x) = x
123+ function recurse (expr:: Expr )
124+ expr = unblock (expr)
125+ if is_scope_construct (expr) && any (find_sync, expr. args)
126+ new_args = unblock (split (emit, expr. args))
127+ return Expr (expr. head, new_args... )
128+ else
129+ return Expr (expr. head, map (recurse, expr. args)... )
130+ end
131+ end
132+ push! (new_stmts, recurse (stmt))
133+ continue
134+ end
135+
136+ push! (current, stmt)
137+ end
138+
139+ # everything since the last `@synchronize`
140+ if ! isempty (current)
141+ loop = WorkgroupLoop (current, false )
142+ push! (new_stmts, emit (loop))
143+ end
144+ return new_stmts
145+ end
146+
147+ function emit_gpu (loop)
148+ stmts = Any[]
149+
150+ body = Expr (:block , loop. stmts... )
151+ loopexpr = quote
152+ if __active_lane__
153+ $ (unblock (body))
154+ end
155+ end
156+ push! (stmts, loopexpr)
157+ if loop. terminated_in_sync
158+ push! (stmts, :($ __synchronize ()))
159+ end
160+
161+ return unblock (Expr (:block , stmts... ))
162+ end
0 commit comments