Skip to content

Commit 1e01733

Browse files
committed
fix: print path for aliased buffers
1 parent ae4294f commit 1e01733

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

src/Compiler.jl

+15-5
Original file line numberDiff line numberDiff line change
@@ -1705,7 +1705,9 @@ function codegen_flatten!(
17051705
# Important to mark donated after we have extracted the data
17061706
push!(
17071707
flatten_code,
1708-
:(donate_argument!(donated_args_mask, $carg_sym, $i, donated_buffers)),
1708+
:(donate_argument!(
1709+
donated_args_mask, $carg_sym, $i, donated_buffers, $(path)
1710+
)),
17091711
)
17101712
elseif runtime isa Val{:IFRT}
17111713
push!(flatten_code, :($carg_sym = $flatcode))
@@ -1775,7 +1777,9 @@ function codegen_flatten!(
17751777
# Important to mark donated after we have extracted the data
17761778
push!(
17771779
flatten_code,
1778-
:(donate_argument!(donated_args_mask, $carg_sym, $i, donated_buffers)),
1780+
:(donate_argument!(
1781+
donated_args_mask, $carg_sym, $i, donated_buffers, $(path)
1782+
)),
17791783
)
17801784
else
17811785
error("Unsupported runtime $runtime")
@@ -1790,11 +1794,11 @@ function codegen_flatten!(
17901794
return flatten_names, flatten_code, resharded_inputs
17911795
end
17921796

1793-
function donate_argument!(donated_args_mask, carg, i::Int, donated_buffers)
1797+
function donate_argument!(donated_args_mask, carg, i::Int, donated_buffers, path)
17941798
if donated_args_mask[i]
17951799
if carg.data in donated_buffers
17961800
error("Donated buffer $(carg.data) is already marked as donated. Can't donate \
1797-
the same buffer multiple times.")
1801+
the same buffer multiple times. The argument is present at $(path)")
17981802
end
17991803
push!(donated_buffers, carg.data)
18001804
Reactant.mark_donated!(carg)
@@ -2395,9 +2399,15 @@ function compile(f, args; sync=false, kwargs...)
23952399

23962400
fname = gensym(Symbol(Symbol(f), :_reactant))
23972401

2402+
donated_buffers_set = if XLA.runtime(client) isa Val{:PJRT}
2403+
:(Base.IdSet{NTuple{<:Any,XLA.PJRT.AsyncBuffer}}())
2404+
else
2405+
:(Base.IdSet{XLA.IFRT.AsyncArray}())
2406+
end
2407+
23982408
body = quote
23992409
global_mesh = $(global_mesh_expr)
2400-
donated_buffers = IdSet()
2410+
donated_buffers = $(donated_buffers_set)
24012411
donated_args_mask = thunk.donated_args_mask
24022412
$(flatten_code...)
24032413
$(xla_call_code)

0 commit comments

Comments
 (0)