Allow splitting to any number of partitions #829

Open
wants to merge 2 commits into
base: main

@francois-rozet francois-rozet commented Sep 2, 2024

Targets #824

Hi @patrick-kidger, in this PR I modify eqx.partition so that it accepts any number ($N$) of filter specifications and returns $N + 1$ partitions. To keep backward compatibility, I have kept the filter_spec argument and added a variadic positional argument *filter_specs. This, however, forces replace and is_leaf to become keyword-only arguments. I think that is OK.

Contributor

@knyazer knyazer left a comment

Thanks for the contribution! I left a bunch of comments all over the place, and I would appreciate hearing your opinion on them.

@@ -135,21 +135,45 @@ def filter(
def partition(
    pytree: PyTree,
    filter_spec: PyTree[AxisSpec],
    *filter_specs: PyTree[AxisSpec],
Contributor

I dislike breaking backwards compatibility! Could you instead make filter_specs a keyword argument? This will be less ergonomic, since it will force you to actually spell out eqx.partition(..., filter_specs=()), but it is worth it.

Also, the ordering of the filters matters in your implementation, so it is more 'logical' to pass them as a sequence rather than as variadic arguments: in the current implementation it is not clear that the order of the filters matters a lot, but if they are passed as a separate parameter, it becomes pretty clear.

Owner

FWIW if we made this a keyword then we'd need to allow not passing filter_spec itself in the case that it's used.

Hmm, this backward compatibility break may sink this whole endeavour.

Author

@francois-rozet francois-rozet Sep 6, 2024

IMO forcing replace and is_leaf to be keyword arguments is actually an improvement as it could prevent user errors. It does break backward compatibility slightly, but I don't think there are many (if any) codebases using these arguments as positional. Note that jax.tree.map forces is_leaf as a keyword argument.

Concerning a Sequence instead of variadic arguments, I disagree: the built-in map and jax.tree.map both use variadic arguments in a similar manner.
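
To make the compatibility question concrete, here is a minimal sketch. The names partition_old and partition_new are hypothetical stand-ins, not Equinox code, and their bodies only report how the arguments were bound. The sketch assumes that the existing signature accepts replace and is_leaf positionally, as the discussion above implies.

```python
# Hypothetical stand-ins for the existing and the proposed signatures.
# They do not partition anything; they only show argument binding.
def partition_old(pytree, filter_spec, replace=None, is_leaf=None):
    return {"specs": (filter_spec,), "replace": replace, "is_leaf": is_leaf}


def partition_new(pytree, filter_spec, *filter_specs, replace=None, is_leaf=None):
    return {
        "specs": (filter_spec, *filter_specs),
        "replace": replace,
        "is_leaf": is_leaf,
    }


def is_int(x):
    return isinstance(x, int)


# Old signature: the third positional argument binds to `replace`.
old = partition_old([1, 2.0], is_int, 0)

# Proposed signature: the same call silently treats 0 as a second filter spec.
new = partition_new([1, 2.0], is_int, 0)

# With the proposed signature, `replace` has to be passed by keyword.
new_kw = partition_new([1, 2.0], is_int, replace=0)
```

Under these assumptions, a call that passes replace positionally does not raise an error with the variadic signature; the value is reinterpreted as a filter specification, which is the kind of silent change the reviewers are weighing.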

Comment on lines +145 to +147
More generally, provide $N$ filter specifications to split the tree into $N + 1$
non-overlapping partitions.

Contributor

Could you expand the documentation slightly? Before I looked at the code, it was not obvious what partitioning with multiple filters does. My first assumption was that you would split things with some redundancy:

everything_that_is_array, everything_that_is_inexact_array, rest = eqx.filter(x, eqx.is_array, eqx.is_inexact_array)

I feel like this is more in line with the previous partitioning logic; at least that is how my brain understands the standard partitioning.

Actually, I really like this way to specify things, but it does indeed break backwards compatibility :(

I would probably implement it as a separate partition function, e.g. eqx.multi_partition? This way the syntax could be wonderfully concise, and compatibility will be kept intact. Anyhow, that is for Patrick to decide, I'm just passing by.
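
The two readings discussed in this thread can be contrasted with a small sketch. It uses a flat Python list in place of a pytree and None as the placeholder, and the function names split_independent and split_first_match are hypothetical; neither is an Equinox function.

```python
def split_independent(leaves, *filters):
    """Apply every filter on its own, so the groups may overlap."""
    groups = [[x if f(x) else None for x in leaves] for f in filters]
    rest = [None if any(f(x) for f in filters) else x for x in leaves]
    return (*groups, rest)


def split_first_match(leaves, *filters):
    """Send each leaf to the first filter it satisfies, so groups are disjoint."""
    groups = [[None] * len(leaves) for _ in range(len(filters) + 1)]
    for i, x in enumerate(leaves):
        # Index of the first matching filter, or of the trailing "rest" group.
        k = next((j for j, f in enumerate(filters) if f(x)), len(filters))
        groups[k][i] = x
    return tuple(groups)


def is_number(x):
    return isinstance(x, (int, float))


def is_float(x):
    return isinstance(x, float)


leaves = [1, 2.5, "a"]
overlapping = split_independent(leaves, is_number, is_float)
disjoint = split_first_match(leaves, is_number, is_float)
```

In the overlapping reading, 2.5 appears in both the first and the second group. In the first-match reading, the second group stays empty here because is_number already claims every float, which shows how much the order of the filters matters.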

Owner

This also speaks to another good point: how do overlapping filters work? It's not really clear which group things end up in.

The current 'only split into 2' scenario necessitates making this explicit.

Author

@francois-rozet francois-rozet Sep 6, 2024

When a leaf satisfies several filters, it is contained in the partition corresponding to the first filter that is satisfied, as determined by the user-provided order. The last partition is dedicated to leaves that do not satisfy any filter.

This is aligned with eqx.combine which assumes that each leaf is only represented in one of the partitions.
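
As an illustration of the rule described above, the following sketch mirrors the first-match logic on a flat list of leaves and checks that the pieces can be recombined. It is an assumption-laden simplification: partition_flat and combine_flat are hypothetical helpers written for this example, plain lists stand in for pytrees, and None plays the role of replace.

```python
def partition_flat(leaves, *filters, replace=None):
    # One list of Python booleans per filter, analogous to a filter tree.
    masks = [[bool(f(x)) for x in leaves] for f in filters]
    parts = []
    for i, curr in enumerate(masks):
        prev = masks[:i]
        # Keep a leaf only if this filter matches and no earlier filter did.
        parts.append([
            x if curr[j] and not any(m[j] for m in prev) else replace
            for j, x in enumerate(leaves)
        ])
    # The last partition holds the leaves that no filter matched.
    rest = [
        replace if any(m[j] for m in masks) else x
        for j, x in enumerate(leaves)
    ]
    return (*parts, rest)


def combine_flat(*parts):
    # Take the first non-None entry at each position.
    return [
        next((x for x in column if x is not None), None)
        for column in zip(*parts)
    ]


leaves = [1, 2.5, "a", 3]
ints, numbers, rest = partition_flat(
    leaves,
    lambda x: isinstance(x, int),
    lambda x: isinstance(x, (int, float)),
)
roundtrip = combine_flat(ints, numbers, rest)
```

Here the integers match both filters but land only in the first partition, the float lands in the second, and the string falls through to the last one. Because every leaf appears in exactly one partition, recombining them recovers the original list.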

Comment on lines +153 to +176
filter_trees = [
    jtu.tree_map(_make_filter_tree(is_leaf), spec, pytree)
    for spec in (filter_spec, *filter_specs)
]

partitions = []

for i in range(len(filter_trees)):
    partition = jtu.tree_map(
        lambda x, curr, *prev: x if curr and not any(prev) else replace,
        pytree,
        filter_trees[i],
        *filter_trees[:i],
    )

    partitions.append(partition)

rest = jtu.tree_map(
    lambda x, prev: replace if any(prev) else x,
    pytree,
    *filter_trees,
)

return *partitions, rest
Contributor

Especially considering how much new code is added into this 'simple' function, separating it into a distinct function makes sense to me.


for i in range(len(filter_trees)):
    partition = jtu.tree_map(
        lambda x, curr, *prev: x if curr and not any(prev) else replace,
Contributor

Are you sure it should be any here, not jnp.any? Could you also test this under jit, just to make sure it works correctly?

Author

I don't understand your concern here. The filters should return True or False values, never arrays. In fact, you should never use boolean arrays for control flow.

Owner

@patrick-kidger patrick-kidger left a comment

@knyazer has spotted some meaningful flaws, I think.

For that reason I'm afraid I'm inclined not to take this PR: neither the backward compatibility break nor the unclear semantics seem worth the trade-off for this feature.

3 participants