Skip to content

Use Bijection for thunk body cache #1338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Use Bijection for thunk body cache #1338

wants to merge 1 commit into from

Conversation

mofeing
Copy link
Collaborator

@mofeing mofeing commented May 27, 2025

The __thunk_fwd_body_cache and __thunk_rev_body_cache internal variables can be unified as a Bijection.

@wsmoses The new Bijections.jl 0.2 release also has all I required from BijectiveDicts.jl for Tenet.jl, so I'm refactoring my package with it.

If we are able to correctly and performantly trace over a Bijection (more specifically, a Bijection{X, TracedRArray, Dict{X, TracedRArray}, IdDict{TracedRArray, X}} where X can be whatever, then I shouldn't need any Reactant tracing customization in Tenet.

EDIT: I've generalized some tracing functions to work with Bijection too.

github-actions[bot]

This comment was marked as outdated.

@mofeing mofeing requested a review from wsmoses June 2, 2025 21:31
github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

@mofeing
Copy link
Collaborator Author

mofeing commented Jun 2, 2025

mmm all these format comments seem to come from code unrelated to this PR...

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

Copy link
Collaborator Author

@mofeing mofeing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses I managed to get tracing working with general AbstractDict. Instead of tracing collect(dict) which is a Vector{<:Pair}, it does so by tracing the Pairs iteratively and enumerating them. Also, I had to refactor traced_getfield(::AbstractDict) to accept a Integer which is the iteration index.

When I introduced traced_getfield, it was to "fake" fields. It's behavior seems to have evolved but I have the feeling that for this case it doesn't fit perfectly. The reason is that we are now passing a iteration index (and before this PR it should be sth like traced_getindex).

Finally, I didn't manage to get it with inplace mutations (i.e. adding a new field), so adding or removing entries might be problematic.

For example, this works

julia> function combine_ab_to_c!(d)
         d[:a] = d[:a] + d[:b]
         return d
       end

julia> @jit combine_ab_to_c!(d)
Bijection{Symbol, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Dict{Symbol, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, IdDict{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Symbol}} with 2 entries:
  :a => ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([5.0])
  :b => ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([1.0])

but this doesn't work

julia> function combine_ab_to_c!(d)
         d[:c] = d[:a] + d[:b]
         return d
       end
combine_ab_to_c! (generic function with 1 method)

julia> @jit combine_ab_to_c!(d)
ERROR: ArgumentError: collection must be non-empty
Stacktrace:
 [1] first
   @ ./abstractarray.jl:473 [inlined]
 [2] traced_getfield
   @ ~/Developer/Reactant.jl/src/Compiler.jl:37 [inlined]
 [3] macro expansion
   @ ~/Developer/Reactant.jl/src/Compiler.jl:2914 [inlined]
 [4] (::Reactant.Compiler.Thunk{…})(args::Bijection{…})
   @ Reactant.Compiler ~/Developer/Reactant.jl/src/Compiler.jl:3473
 [5] top-level scope
   @ ~/Developer/Reactant.jl/src/Compiler.jl:2333
Some type information was truncated. Use `show(err)` to see complete types.

@mofeing
Copy link
Collaborator Author

mofeing commented Jun 13, 2025

As discussed in the last meeting, the tracing modifications to make Bijection and other AbstractDict traceable are being worked on in #1398

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant