Allow splitting to any number of partitions #829

Closed
wants to merge 2 commits into from

32 changes: 28 additions & 4 deletions equinox/_filters.py
@@ -135,21 +135,45 @@ def filter(
 def partition(
     pytree: PyTree,
     filter_spec: PyTree[AxisSpec],
+    *filter_specs: PyTree[AxisSpec],
Contributor

I dislike breaking backwards compatibility! Could you instead make filter_specs a keyword argument? This is less ergonomic, since it forces you to actually spell out eqx.partition(..., filter_specs=()), but I think it is worth it.

Also, the order of the filters matters in your implementation, so it is more logical to pass them as a sequence rather than as variadic arguments. With variadic arguments it is not obvious that the order matters; with a separate sequence parameter it is much clearer.

Owner

FWIW, if we made this a keyword argument then we'd need to allow filter_spec itself to be omitted whenever the keyword is used.

Hmm, this backward compatibility break may sink this whole endeavour.

Author (@francois-rozet, Sep 6, 2024)

IMO forcing replace and is_leaf to be keyword arguments is actually an improvement, as it could prevent user errors. It does break backward compatibility slightly, but I don't think many (if any) codebases pass these arguments positionally. Note that jax.tree.map also forces is_leaf to be a keyword argument.

As for using a Sequence instead of variadic arguments, I disagree: map and jax.tree.map both take variadic arguments in a similar manner.
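To make the compatibility concern in this thread concrete, here is a minimal plain-Python sketch. `partition_sketch` is a hypothetical stand-in that mirrors only the proposed signature (it is not Equinox code) and simply reports how its arguments were bound.

```python
def partition_sketch(pytree, filter_spec, *filter_specs, replace=None, is_leaf=None):
    """Hypothetical stand-in mirroring only the proposed signature.

    It does no partitioning; it just reports how the arguments were bound.
    """
    return {
        "specs": (filter_spec, *filter_specs),
        "replace": replace,
        "is_leaf": is_leaf,
    }


# Under the old signature, the third positional argument was `replace`.
# Under the variadic signature, the same call binds it as a second filter spec.
old_style = partition_sketch([1, 2], True, 0)

# Passing `replace` by keyword is unambiguous under both signatures.
new_style = partition_sketch([1, 2], True, replace=0)
```

Any parameter that follows `*filter_specs` is keyword-only in Python, which is the same mechanism that makes `is_leaf` keyword-only in `jax.tree.map`.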

     replace: Any = None,
     is_leaf: Optional[Callable[[Any], bool]] = None,
 ) -> tuple[PyTree, PyTree]:
     """Splits a PyTree into two pieces. Equivalent to
     `filter(...), filter(..., inverse=True)`, but slightly more efficient.

+    More generally, provide $N$ filter specifications to split the tree into $N + 1$
+    non-overlapping partitions.

Comment on lines +145 to +147
Contributor

Could you extend the documentation slightly? Before I looked at the code, it was not obvious what partitioning with multiple filters does. My first assumption was that you would split things with some redundancy:

everything_that_is_array, everything_that_is_inexact_array, rest = eqx.partition(x, eqx.is_array, eqx.is_inexact_array)

I feel like this is more in line with the previous partitioning logic; at least, that is how my brain understands standard partitioning.

Actually, I really like this way to specify things, but it does indeed break backwards compatibility :(

I would probably implement it as a separate function, e.g. eqx.multi_partition. That way the syntax could stay wonderfully concise and compatibility would be kept intact. Anyhow, that is for Patrick to decide; I'm just passing by.

Owner

This also raises another good point: how do overlapping filters work? It's not really clear which group things end up in.

The current 'only split into 2' scenario necessitates making this explicit.

Author (@francois-rozet, Sep 6, 2024)

When a leaf satisfies several filters, it is contained in the partition corresponding to the first filter that is satisfied, as determined by the user-provided order. The last partition is dedicated to leaves that do not satisfy any filter.

This is aligned with eqx.combine, which assumes that each leaf is represented in only one of the partitions.
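The first-match rule described above can be sketched in plain Python. Everything below is an assumption made for illustration: `tree_map` is a tiny stand-in for `jax.tree_util.tree_map` that only handles lists, tuples and dicts, `combine_sketch` imitates the "first non-`None` leaf wins" behaviour of `eqx.combine`, and `multi_partition_sketch` is a hypothetical helper that accepts only callables as filters. None of it is the PR's actual implementation.

```python
def tree_map(fn, tree, *rest):
    """Minimal stand-in for jax.tree_util.tree_map (lists, tuples, dicts only)."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, *xs) for xs in zip(tree, *rest))
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    return fn(tree, *rest)  # anything else is treated as a leaf


def multi_partition_sketch(pytree, *filters, replace=None):
    """Hypothetical: N callable filters give N + 1 non-overlapping partitions."""
    # One tree of plain Python bools per filter.
    masks = [tree_map(lambda x, f=f: bool(f(x)), pytree) for f in filters]
    # A leaf goes to partition i only if filter i matches and no earlier one did.
    parts = [
        tree_map(
            lambda x, curr, *prev: x if curr and not any(prev) else replace,
            pytree,
            masks[i],
            *masks[:i],
        )
        for i in range(len(masks))
    ]
    # The last partition collects the leaves that no filter matched.
    rest = tree_map(lambda x, *hits: replace if any(hits) else x, pytree, *masks)
    return (*parts, rest)


def combine_sketch(*trees):
    """Imitates eqx.combine: take the first non-None leaf at each position."""
    return tree_map(lambda *xs: next((x for x in xs if x is not None), None), *trees)


def is_number(x):
    return isinstance(x, (int, float))


def is_float(x):
    return isinstance(x, float)


tree = {"w": 1.5, "n": 3, "name": "layer"}

# 1.5 satisfies both filters; it lands in whichever partition comes first.
numbers_a, floats_a, rest_a = multi_partition_sketch(tree, is_number, is_float)
floats_b, numbers_b, rest_b = multi_partition_sketch(tree, is_float, is_number)
```

Swapping the two filters moves `1.5` from the first partition to the other one, which is the order dependence discussed in this thread. In both cases every leaf appears in exactly one partition, so combining the pieces recovers the original tree.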

     !!! info

         See also [`equinox.combine`][] to reconstitute the PyTree again.
     """

-    filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
-    left = jtu.tree_map(lambda mask, x: x if mask else replace, filter_tree, pytree)
-    right = jtu.tree_map(lambda mask, x: replace if mask else x, filter_tree, pytree)
-    return left, right
+    filter_trees = [
+        jtu.tree_map(_make_filter_tree(is_leaf), spec, pytree)
+        for spec in (filter_spec, *filter_specs)
+    ]
+
+    partitions = []
+
+    for i in range(len(filter_trees)):
+        partition = jtu.tree_map(
+            lambda x, curr, *prev: x if curr and not any(prev) else replace,
Contributor

Are you sure it should be any here, not jnp.any? Could you also test this under jit, just to make sure it works correctly?

Author

I don't understand your concern here. The filters should return True or False values, never arrays. In fact, you should never use boolean arrays for control flow.
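The distinction in this exchange can be illustrated without JAX. In the sketch below, `FakeTracer` is a hypothetical stand-in for a traced array (it is not a JAX class); it refuses conversion to a Python bool, much as JAX raises an error when a traced array is used in Python control flow. The assumption is that filters decide by type, as `eqx.is_array` does, so the masks are plain Python bools.

```python
class FakeTracer:
    """Hypothetical stand-in for a traced JAX array; not a real JAX class."""

    def __bool__(self):
        # JAX raises a comparable error when a traced array is used as a bool.
        raise TypeError("traced value has no concrete truth value")


def is_tracer_like(x):
    # Type-based filter: it never looks at the value, so it returns a Python bool.
    return isinstance(x, FakeTracer)


leaf = FakeTracer()

# Masks built from type checks are ordinary bools, so the builtin `any` is fine.
masks = [is_tracer_like(leaf), False]
builtin_any_ok = any(masks)

# Branching on the value itself is what fails: `any` calls bool(leaf).
try:
    any([leaf])
    value_based_ok = True
except TypeError:
    value_based_ok = False
```

Under these assumptions the builtin `any` only ever sees Python bools, so nothing value-dependent is evaluated while tracing.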

+            pytree,
+            filter_trees[i],
+            *filter_trees[:i],
+        )
+
+        partitions.append(partition)
+
+    rest = jtu.tree_map(
+        lambda x, *prev: replace if any(prev) else x,
+        pytree,
+        *filter_trees,
+    )
+
+    return *partitions, rest
Comment on lines +153 to +176
Contributor

Especially considering how much new code is added into this 'simple' function, separating it into a distinct function makes sense to me.



def _combine(*args):
20 changes: 18 additions & 2 deletions tests/test_filters.py
@@ -143,8 +143,24 @@ def test_partition_and_combine(getkey):


 def test_partition_subtree():
-    a, b = eqx.partition([(1,), 2], [True, False])
-    eqx.combine(a, b)
+    pytree = [(1,), 2]
+
+    a, b = eqx.partition(pytree, [True, False])
+
+    assert eqx.combine(a, b) == pytree
+
+
+def test_partition_multi():
+    pytree = [(1,), 2, (3.0, "four")]
+
+    partitions = eqx.partition(
+        pytree,
+        [True, False, False],
+        [False, False, (False, True)],
+        lambda x: isinstance(x, float),
+    )
+
+    assert eqx.combine(*partitions) == pytree


def test_is_leaf():