@@ -81,11 +81,8 @@ function transform_gpu!(def, constargs, force_inbounds)
8181end
8282
8383struct WorkgroupLoop
84- indicies:: Vector{Any}
8584 stmts:: Vector{Any}
8685 allocations:: Vector{Any}
87- private_allocations:: Vector{Any}
88- private:: Set{Symbol}
8986 terminated_in_sync:: Bool
9087end
9188
@@ -106,26 +103,18 @@ function find_sync(stmt)
106103end
107104
108105# TODO proper handling of LineInfo
109- function split (
110- stmts,
111- indicies = Any[], private = Set {Symbol} (),
112- )
106+ function split (stmts)
113107 # 1. Split the code into blocks separated by `@synchronize`
114- # 2. Aggregate `@index` expressions
115- # 3. Hoist allocations
116- # 4. Hoist uniforms
117108
118109 current = Any[]
119110 allocations = Any[]
120- private_allocations = Any[]
121111 new_stmts = Any[]
122112 for stmt in stmts
123113 has_sync = find_sync (stmt)
124114 if has_sync
125- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , is_sync (stmt))
115+ loop = WorkgroupLoop (current, allocations, is_sync (stmt))
126116 push! (new_stmts, emit (loop))
127117 allocations = Any[]
128- private_allocations = Any[]
129118 current = Any[]
130119 is_sync (stmt) && continue
131120
@@ -137,7 +126,7 @@ function split(
137126 function recurse (expr:: Expr )
138127 expr = unblock (expr)
139128 if is_scope_construct (expr) && any (find_sync, expr. args)
140- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private) ))
129+ new_args = unblock (split (expr. args))
141130 return Expr (expr. head, new_args... )
142131 else
143132 return Expr (expr. head, map (recurse, expr. args)... )
@@ -151,8 +140,7 @@ function split(
151140 push! (allocations, stmt)
152141 continue
153142 elseif @capture (stmt, @private lhs_ = rhs_)
154- push! (private, lhs)
155- push! (private_allocations, :($ lhs = $ rhs))
143+ push! (allocations, :($ lhs = $ rhs))
156144 continue
157145 elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
158146 if @capture (rhs, @index (args__))
@@ -170,7 +158,6 @@ function split(
170158 end
171159 alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
172160 push! (allocations, :($ lhs = $ alloc))
173- push! (private, lhs)
174161 continue
175162 end
176163 end
@@ -180,7 +167,7 @@ function split(
180167
181168 # everything since the last `@synchronize`
182169 if ! isempty (current)
183- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , false )
170+ loop = WorkgroupLoop (current, allocations, false )
184171 push! (new_stmts, emit (loop))
185172 end
186173 return new_stmts
@@ -192,9 +179,7 @@ function emit(loop)
192179 body = Expr (:block , loop. stmts... )
193180 loopexpr = quote
194181 $ (loop. allocations... )
195- $ (loop. private_allocations... )
196182 if __active_lane__
197- $ (loop. indicies... )
198183 $ (unblock (body))
199184 end
200185 end
0 commit comments