diff --git a/Project.toml b/Project.toml index 54790cd73c..969ef058da 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -version = "0.2.54" +version = "0.2.55" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -85,7 +85,7 @@ Preferences = "1.4" PythonCall = "0.9" Random = "1.10" Random123 = "1.7" -ReactantCore = "0.1.7" +ReactantCore = "0.1.8" Reactant_jll = "0.0.101" Scratch = "1.2" Sockets = "1.10" diff --git a/lib/ReactantCore/Project.toml b/lib/ReactantCore/Project.toml index 88227fc570..d994924f53 100644 --- a/lib/ReactantCore/Project.toml +++ b/lib/ReactantCore/Project.toml @@ -1,7 +1,7 @@ name = "ReactantCore" uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.1.7" +version = "0.1.8" [deps] ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43" diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 4e1bf73121..560210a227 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -11,6 +11,5 @@ end function ReactantCore.traced_while( cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing ) where {CFn,BFn} - @warn verify_arg_names return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names) end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 1800d5b3a8..54838f2753 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -349,15 +349,32 @@ function make_mlir_fn( append!(linear_results, linear_args[mutated_args]) end if !isnothing(verify_arg_names) && typeof.(linear_args) != typeof.(linear_results) - @assert length(linear_args) <= length(linear_results) - argis = first.(get_argidx.(linear_args)) - resis = Set(getindex.(get_residx.(linear_results), Ref(2))) + @assert length(linear_args) <= length(linear_results) "Expected to have missing traced arguments, but it seems like results are missing." + argis = map(get_argidx.(linear_args)) do (_, path) + path[2:end] + end + resis = Set(map(get_residx.(linear_results)) do path + path[2:end] + end) # this can be more efficient conflicts = setdiff(resis, argis) + if isempty(conflicts) + Core.println("linear_args: $linear_args") + Core.println("linear_results: $linear_results") + end @assert !isempty(conflicts) "Expected to have some conflicts, but none were found." + conflicts = collect(conflicts) + diagnostics = map(conflicts) do conflict + i = first(conflict) + remaining_path = conflict[2:end] + name = verify_arg_names.args[i] + "$name -> [$(join(remaining_path, ", "))]" + end + error("""Types do not match between function arguments and results. - The following arguments should be traced: $(join(verify_arg_names.args[collect(conflicts)], ", ")) + The following arguments should be traced: + $(join(diagnostics, "\n")) """) end