From 3114e549d7c30ac6e02436e197b0d61c44256fbd Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 25 Mar 2025 12:48:25 -0500 Subject: [PATCH 1/3] nicer error for type mutation in traced loop body --- lib/ReactantCore/src/ReactantCore.jl | 2 +- src/ControlFlow.jl | 5 +++-- src/Ops.jl | 4 +++- src/TracedUtils.jl | 13 +++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index e4e3bb3117..74bfc4f6a7 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -231,7 +231,7 @@ function trace_for(mod, expr; track_numbers) end $(ReactantCore).traced_while( - cond_fn, body_fn, args; track_numbers=$(track_numbers) + cond_fn, body_fn, args; track_numbers=$(track_numbers), verify_arg_names=$(QuoteNode(args_init)) ) end end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index cce7f1fe39..4e1bf73121 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -9,7 +9,8 @@ function ReactantCore.traced_call(f::Function, args...) end function ReactantCore.traced_while( - cond_fn::CFn, body_fn::BFn, args; track_numbers=Number + cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing ) where {CFn,BFn} - return Ops.while_loop(cond_fn, body_fn, args...; track_numbers) + @warn verify_arg_names + return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names) end diff --git a/src/Ops.jl b/src/Ops.jl index 9b3187a99e..290237cb6e 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1746,7 +1746,7 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead. end @noinline function while_loop( - cond_fn::CFn, body_fn::BFn, args...; track_numbers + cond_fn::CFn, body_fn::BFn, args...; track_numbers, verify_arg_names=nothing ) where {CFn,BFn} # TODO: detect and prevent mutation within the condition @@ -1780,6 +1780,7 @@ end do_transpose=false, ).f + @warn verify_arg_names body_fn_compiled = Reactant.TracedUtils.make_mlir_fn( body_fn, @@ -1790,6 +1791,7 @@ end return_dialect=:stablehlo, args_in_result=:none, do_transpose=false, + verify_arg_names ).f cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 9f8a54bb9a..937e2f6180 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -189,6 +189,7 @@ function make_mlir_fn( input_shardings=nothing, # This is not meant to be used by the user. output_shardings=nothing, # This is not meant to be used by the user. runtime=nothing, + verify_arg_names=nothing, ) if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction mlir_fn_res = make_mlir_fn( @@ -347,6 +348,18 @@ function make_mlir_fn( if args_in_result == :mutated append!(linear_results, linear_args[mutated_args]) end + if !isnothing(verify_arg_names) && typeof.(linear_args) != typeof.(linear_results) + @assert length(linear_args) <= length(linear_results) + argis = first.(get_argidx.(linear_args)) + resis = Set(getindex.(get_residx.(linear_results), Ref(2))) + # this can be more efficient + conflicts = setdiff(resis, argis) + @assert !isempty(conflicts) "Expected to have some conflicts, but none were found." + + error("""Types do not match between function arguments and results. + The following arguments should be traced: $(join(verify_arg_names.args[collect(conflicts)], ", ")) + """) + end out_tys = if do_transpose [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] From 36cde363834984cd65c44b18842c81c52f92a192 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 25 Mar 2025 16:13:30 -0500 Subject: [PATCH 2/3] fix and better reporting (including path) --- src/ControlFlow.jl | 1 - src/Ops.jl | 1 - src/TracedUtils.jl | 26 ++++++++++++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 4e1bf73121..560210a227 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -11,6 +11,5 @@ end function ReactantCore.traced_while( cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing ) where {CFn,BFn} - @warn verify_arg_names return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names) end diff --git a/src/Ops.jl b/src/Ops.jl index 290237cb6e..1d16b5ea13 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1780,7 +1780,6 @@ end do_transpose=false, ).f - @warn verify_arg_names body_fn_compiled = Reactant.TracedUtils.make_mlir_fn( body_fn, diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 937e2f6180..b818d493a3 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -349,15 +349,33 @@ function make_mlir_fn( append!(linear_results, linear_args[mutated_args]) end if !isnothing(verify_arg_names) && typeof.(linear_args) != typeof.(linear_results) - @assert length(linear_args) <= length(linear_results) - argis = first.(get_argidx.(linear_args)) - resis = Set(getindex.(get_residx.(linear_results), Ref(2))) + @assert length(linear_args) <= length(linear_results) "Expected to have missing traced arguments, but it seems like results are missing." + argis = map(get_argidx.(linear_args)) do (_, path) + path[2:end] + end + resis = Set(map(get_residx.(linear_results)) do path + path[2:end] + end) # this can be more efficient conflicts = setdiff(resis, argis) + if isempty(conflicts) + Core.println("linear_args: $linear_args") + Core.println("linear_results: $linear_results") + end @assert !isempty(conflicts) "Expected to have some conflicts, but none were found." + conflicts = collect(conflicts) + @warn "" conflicts + diagnostics = map(conflicts) do conflict + i = first(conflict) + remaining_path = conflict[2:end] + name = verify_arg_names.args[i] + "$name -> [$(join(remaining_path, ", "))]" + end + error("""Types do not match between function arguments and results. - The following arguments should be traced: $(join(verify_arg_names.args[collect(conflicts)], ", ")) + The following arguments should be traced: + $(join(diagnostics, "\n")) """) end From ab520a387f34bd0f4fb986bb5852978092b0f3c4 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 25 Mar 2025 16:18:42 -0500 Subject: [PATCH 3/3] bump --- Project.toml | 4 ++-- lib/ReactantCore/Project.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 54790cd73c..969ef058da 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -version = "0.2.54" +version = "0.2.55" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -85,7 +85,7 @@ Preferences = "1.4" PythonCall = "0.9" Random = "1.10" Random123 = "1.7" -ReactantCore = "0.1.7" +ReactantCore = "0.1.8" Reactant_jll = "0.0.101" Scratch = "1.2" Sockets = "1.10" diff --git a/lib/ReactantCore/Project.toml b/lib/ReactantCore/Project.toml index 88227fc570..d994924f53 100644 --- a/lib/ReactantCore/Project.toml +++ b/lib/ReactantCore/Project.toml @@ -1,7 +1,7 @@ name = "ReactantCore" uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.1.7" +version = "0.1.8" [deps] ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"