Allow splitting to any number of partitions #829
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,21 +135,45 @@ def filter( | |
def partition( | ||
pytree: PyTree, | ||
filter_spec: PyTree[AxisSpec], | ||
*filter_specs: PyTree[AxisSpec], | ||
replace: Any = None, | ||
is_leaf: Optional[Callable[[Any], bool]] = None, | ||
) -> tuple[PyTree, PyTree]: | ||
"""Splits a PyTree into two pieces. Equivalent to | ||
`filter(...), filter(..., inverse=True)`, but slightly more efficient. | ||
|
||
More generally, provide $N$ filter specifications to split the tree into $N + 1$ | ||
non-overlapping partitions. | ||
|
||
Comment on lines +145 to +147

Could you slightly extend the documentation? Before I looked at the code, it was not obvious what multiple partitioning does. My first assumption was that you would split things with some redundancy:

I feel like this is more in line with the previous partitioning logic, at least that is how my brain understands the standard partitioning. Actually, I really like this way of specifying things, but it does indeed break backwards compatibility :( I would probably implement it as a separate partition function, e.g.

This also speaks to another good point: how do overlapping filters work? It's not really clear which group things end up in. The current 'only split into 2' scenario necessitates making this explicit.

When a leaf satisfies several filters, it is contained in the partition corresponding to the first filter that is satisfied, as determined by the user-provided order. The last partition is dedicated to leaves that do not satisfy any filter. This is aligned with
||
!!! info | ||
|
||
See also [`equinox.combine`][] to reconstitute the PyTree again. | ||
""" | ||
|
||
filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree) | ||
left = jtu.tree_map(lambda mask, x: x if mask else replace, filter_tree, pytree) | ||
right = jtu.tree_map(lambda mask, x: replace if mask else x, filter_tree, pytree) | ||
return left, right | ||
filter_trees = [ | ||
jtu.tree_map(_make_filter_tree(is_leaf), spec, pytree) | ||
for spec in (filter_spec, *filter_specs) | ||
] | ||
|
||
partitions = [] | ||
|
||
for i in range(len(filter_trees)): | ||
partition = jtu.tree_map( | ||
lambda x, curr, *prev: x if curr and not any(prev) else replace, | ||
Are you sure it should be

I don't understand your concern here. The filters should return
||
pytree, | ||
filter_trees[i], | ||
*filter_trees[:i], | ||
) | ||
|
||
partitions.append(partition) | ||
|
||
rest = jtu.tree_map( | ||
lambda x, *prev: replace if any(prev) else x, | ||
pytree, | ||
*filter_trees, | ||
) | ||
|
||
return *partitions, rest | ||
Comment on lines +153 to +176

Especially considering how much new code is added into this 'simple' function, separating it into a distinct function makes sense to me.
||
|
||
|
||
def _combine(*args): | ||
|
I dislike breaking backwards compatibility! Could you instead make `filter_specs` a keyword argument? This will be less ergonomic, since it will force you to actually spell `eqx.partition(..., filter_specs=())`, but it is worth it.

Also, the ordering of the filters matters in your implementation, so it is more 'logical' to have them as a sequence rather than as a variadic. In the current implementation it is not clear that the order of the filters matters a lot, but if it is another parameter, then it's pretty clear.
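The point about ordering can be made concrete with a small sketch. This is not the code from this PR: it assumes a flat dict of leaves instead of a PyTree and plain predicates instead of filter specs, and `partition_first_match` is a hypothetical name. It only mirrors the first-match-wins rule described in the review thread.

```python
from typing import Any, Callable, Dict, Tuple

Leaf = Any
Predicate = Callable[[Leaf], bool]


def partition_first_match(
    leaves: Dict[str, Leaf],
    *predicates: Predicate,
    replace: Any = None,
) -> Tuple[Dict[str, Leaf], ...]:
    """Split `leaves` into len(predicates) + 1 dicts that share the same keys.

    A leaf is kept in the partition of the first predicate it satisfies;
    leaves matching no predicate are kept in the final "rest" partition.
    Every other slot holds `replace`.
    """
    partitions = [dict.fromkeys(leaves, replace) for _ in range(len(predicates) + 1)]
    for key, leaf in leaves.items():
        # Index of the first matching predicate, or the last slot if none match.
        index = next(
            (i for i, pred in enumerate(predicates) if pred(leaf)),
            len(predicates),
        )
        partitions[index][key] = leaf
    return tuple(partitions)


def is_int(x: Leaf) -> bool:
    return isinstance(x, int)


def is_positive(x: Leaf) -> bool:
    return isinstance(x, (int, float)) and x > 0


leaves = {"a": 1, "b": 2.5, "c": -3, "d": "text"}

# "a" satisfies both predicates, so swapping them moves it between partitions.
ints_first = partition_first_match(leaves, is_int, is_positive)
positive_first = partition_first_match(leaves, is_positive, is_int)
```

With `is_int` first, `"a"` and `"c"` land in the first partition and `"b"` in the second. With `is_positive` first, `"a"` and `"b"` land in the first partition and `"c"` in the second. In both cases `"d"` ends up in the rest.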
FWIW if we made this a keyword then we'd need to allow not passing `filter_spec` itself in the case that it's used.

Hmm, this backward compatibility break may sink this whole endeavour.
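One way the keyword-only alternative could handle an optional `filter_spec` is sketched below. Everything here is an assumption for illustration: `resolve_filter_specs` and the `_MISSING` sentinel are hypothetical and not part of Equinox or of this PR. A sentinel is used instead of `None` on the assumption that `None` may itself be a meaningful value for a spec.

```python
from typing import Any, Optional, Sequence, Tuple

# Hypothetical sentinel marking "argument not supplied".
_MISSING: Any = object()


def resolve_filter_specs(
    filter_spec: Any = _MISSING,
    *,
    filter_specs: Optional[Sequence[Any]] = None,
) -> Tuple[Any, ...]:
    """Return the ordered tuple of specs from exactly one of the two arguments."""
    if filter_specs is None:
        if filter_spec is _MISSING:
            raise TypeError("pass either `filter_spec` or `filter_specs`")
        # Classic two-way call: a single spec.
        return (filter_spec,)
    if filter_spec is not _MISSING:
        raise TypeError("pass only one of `filter_spec` and `filter_specs`")
    # Multi-way call: the sequence order is the matching order.
    return tuple(filter_specs)


single = resolve_filter_specs("spec_a")
several = resolve_filter_specs(filter_specs=["spec_a", "spec_b"])
```

Under this scheme existing positional calls keep working, at the cost of a more awkward signature in which the second positional parameter becomes optional.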
IMO forcing `replace` and `is_leaf` to be keyword arguments is actually an improvement as it could prevent user errors. It does break backward compatibility slightly, but I don't think there are many (if any) codebases using these arguments as positional. Note that `jax.tree.map` forces `is_leaf` as a keyword argument.

Concerning the `Sequence` instead of variadic, I disagree. `map` or even `jax.tree.map` use variadic arguments in a similar manner.
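To see what the compatibility break looks like in practice, here is a minimal sketch using stub functions that copy only the two parameter lists from the diff. The bodies are placeholders, and `partition_old` and `partition_new` are hypothetical names. It assumes a caller who passes `replace` positionally.

```python
import inspect


def partition_old(pytree, filter_spec, replace=None, is_leaf=None):
    """Stub with the pre-change parameter list (real body omitted)."""
    return replace


def partition_new(pytree, filter_spec, *filter_specs, replace=None, is_leaf=None):
    """Stub with the parameter list proposed in the diff (real body omitted)."""
    return filter_specs, replace


# Old signature: the third positional argument binds to `replace`.
old_result = partition_old("tree", "spec", 0)

# New signature: the same call binds 0 as an extra filter spec instead,
# and `replace` keeps its default without any error being raised.
new_result = partition_new("tree", "spec", 0)

# Parameters declared after `*filter_specs` become keyword-only.
new_kinds = {
    name: param.kind.name
    for name, param in inspect.signature(partition_new).parameters.items()
}
```

The stubs suggest that such a call would not fail loudly under the new signature: the value meant for `replace` would be reinterpreted as a filter spec.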