@@ -1705,7 +1705,9 @@ function codegen_flatten!(
1705
1705
# Important to mark donated after we have extracted the data
1706
1706
push! (
1707
1707
flatten_code,
1708
- :(donate_argument! (donated_args_mask, $ carg_sym, $ i, donated_buffers)),
1708
+ :(donate_argument! (
1709
+ donated_args_mask, $ carg_sym, $ i, donated_buffers, $ (path)
1710
+ )),
1709
1711
)
1710
1712
elseif runtime isa Val{:IFRT }
1711
1713
push! (flatten_code, :($ carg_sym = $ flatcode))
@@ -1775,7 +1777,9 @@ function codegen_flatten!(
1775
1777
# Important to mark donated after we have extracted the data
1776
1778
push! (
1777
1779
flatten_code,
1778
- :(donate_argument! (donated_args_mask, $ carg_sym, $ i, donated_buffers)),
1780
+ :(donate_argument! (
1781
+ donated_args_mask, $ carg_sym, $ i, donated_buffers, $ (path)
1782
+ )),
1779
1783
)
1780
1784
else
1781
1785
error (" Unsupported runtime $runtime " )
@@ -1790,11 +1794,11 @@ function codegen_flatten!(
1790
1794
return flatten_names, flatten_code, resharded_inputs
1791
1795
end
1792
1796
1793
- function donate_argument! (donated_args_mask, carg, i:: Int , donated_buffers)
1797
+ function donate_argument! (donated_args_mask, carg, i:: Int , donated_buffers, path )
1794
1798
if donated_args_mask[i]
1795
1799
if carg. data in donated_buffers
1796
1800
error (" Donated buffer $(carg. data) is already marked as donated. Can't donate \
1797
- the same buffer multiple times." )
1801
+ the same buffer multiple times. The argument is present at $(path) " )
1798
1802
end
1799
1803
push! (donated_buffers, carg. data)
1800
1804
Reactant. mark_donated! (carg)
@@ -2395,9 +2399,15 @@ function compile(f, args; sync=false, kwargs...)
2395
2399
2396
2400
fname = gensym (Symbol (Symbol (f), :_reactant ))
2397
2401
2402
+ donated_buffers_set = if XLA. runtime (client) isa Val{:PJRT }
2403
+ :(Base. IdSet {NTuple{<:Any,XLA.PJRT.AsyncBuffer}} ())
2404
+ else
2405
+ :(Base. IdSet {XLA.IFRT.AsyncArray} ())
2406
+ end
2407
+
2398
2408
body = quote
2399
2409
global_mesh = $ (global_mesh_expr)
2400
- donated_buffers = IdSet ( )
2410
+ donated_buffers = $ (donated_buffers_set )
2401
2411
donated_args_mask = thunk. donated_args_mask
2402
2412
$ (flatten_code... )
2403
2413
$ (xla_call_code)
0 commit comments