Disable static arrays #800

Merged
merged 5 commits into from
Aug 21, 2024


lockwo
Contributor

@lockwo lockwo commented Aug 13, 2024

Originally, by just checking the type, users could still footgun themselves with static arrays if they hid them in pytrees (they still can if they use field). For example, a tuple of arrays will still get through without raising, even though it is an error: a tuple's hash relies on the hashes of its members.

As a solution, I figured it would be more direct to just see if it was hashable at all, but perhaps there are some edge cases where this might not be ideal(?). However, I could only do this for the Static directly not the field marker, since we don't access the init of the user so if they just type a tuple and mark it as static then fill it with arrays, I'm not sure how we could catch that.
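
To make the footgun concrete, here is a minimal sketch (not code from this PR; `is_hashable` is a hypothetical helper) of how a type check on the field value misses an array hidden in a tuple, while attempting to hash the value catches it:

```python
import jax
import jax.numpy as jnp


def is_hashable(obj) -> bool:
    """Return True if hash(obj) succeeds (hypothetical helper for illustration)."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


plain = (3, "relu")        # ordinary static metadata
hidden = (3, jnp.ones(2))  # an array hidden inside a tuple

# A direct type check only looks at the outer container, so it misses the array.
looks_like_array = isinstance(hidden, jax.Array)

# Hashing a tuple hashes its members, so the hidden array makes it fail.
plain_ok = is_hashable(plain)
hidden_ok = is_hashable(hidden)
```

As noted above, this only helps for values that are available at check time; it says nothing about containers that are filled with arrays later.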

@lockwo
Contributor Author

lockwo commented Aug 13, 2024

Addresses #798

@lockwo lockwo marked this pull request as draft August 13, 2024 00:10
@lockwo
Contributor Author

lockwo commented Aug 13, 2024

Thinking about it more, it seems like there are two options:

  1. just disable jax arrays from being typed as a static field (this doesn't require any changes to code base, just 2 LoC addition). Users can still footgun themselves with hashable containers of jax arrays, which break when they try to hash them, but are invisible at declaration time (also Static is still unchanged). This is probably always the case, given the nature of the code (defining a static field tuple then just filling it with arrays, I don't see a way to prevent the user from doing this)
  2. Explicitly define static field == aux data, in which case the Static class itself needs to be rethought a little, since it marks a list as a static field (which is not hashable).
  3. The eternal last option: do nothing and leave everything as is
  4. The secret fourth option: remove static fields entirely 👁️

@patrick-kidger
Owner

Haha!
I think what should be done is to tree-flatten static fields, and then check if any leaves are arrays.

That's a thing we could do in type(Module).__call__ I think.
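
A minimal sketch of that suggestion, assuming `isinstance(leaf, jax.Array)` as a stand-in for Equinox's own array check; `contains_array` is a hypothetical name, not something from the code base:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def contains_array(static_value) -> bool:
    """Flatten a would-be static value and report whether any leaf is a JAX array."""
    return any(isinstance(leaf, jax.Array) for leaf in jtu.tree_leaves(static_value))


nested_with_array = (3, {"w": jnp.ones(2)})  # array buried two levels deep
nested_without = (3, {"name": "relu"})       # only plain Python leaves
```

Because the check walks the flattened leaves, it finds arrays inside tuples, lists, dicts and any other registered pytree, not just at the top level.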

@lockwo
Contributor Author

lockwo commented Aug 13, 2024

"I think what should be done is to tree-flatten static fields, and then check if any leaves are arrays."

I thought about that, but is that universally valid? Specifically, I could make a pytree which has an array as a member variable but also overrides the hash and equality operators. Possessing an array in a pytree marked as static is not necessarily equivalent to being unhashable or non-comparable (if you override these methods), or is it cheating to do that? One could make an object that is fully hashable and comparable, which is what aux data requires, while it still owns an array (this is actually what I did when I was playing around in andraz Langevin PR, and it seemed to work). Or is this not truly valid auxiliary/static data?

@patrick-kidger
Owner

I don't think that's valid. Both JAX and Equinox do a lot of tree flattening in all kinds of places -- I think static data needs to be robust to that.
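
A sketch of the disagreement, under the assumption that the object in question is itself a registered pytree. `HashableParams` is a hypothetical class written for illustration, not the one from the Langevin PR:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class HashableParams:
    """Owns an array but defines its own hash/equality (illustrative only)."""

    def __init__(self, weight):
        self.weight = weight

    def __hash__(self):
        # Hash on the shape only, so hashing never touches the array values.
        return hash(self.weight.shape)

    def __eq__(self, other):
        return isinstance(other, HashableParams) and self.weight.shape == other.weight.shape

    def tree_flatten(self):
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        (weight,) = children
        return cls(weight)


params = HashableParams(jnp.ones((2, 2)))
hash_value = hash(params)         # succeeds: a hashability check would accept this
leaves = jtu.tree_leaves(params)  # but flattening still exposes the array
```

The object passes a hash test, yet any code that flattens it will see an array leaf, which is the robustness concern raised above.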

@lockwo
Contributor Author

lockwo commented Aug 13, 2024

Hmmm, surprised I got it to work then (usually JAX fails quickly when attempting these sorts of things, but these aux data edge cases are surprisingly robust). I reverted it to the original approach of just checking whether there are any arrays in the static object.

@lockwo lockwo marked this pull request as ready for review August 13, 2024 18:45
@patrick-kidger
Owner

patrick-kidger commented Aug 16, 2024

Awesome! That was surprisingly painless.

I would not be at all surprised if this breaks downstream parts of the ecosystem. (I'm probably equally guilty of accidentally using static arrays as anyone else...)

Can you try running the Lineax/Optimistix/Diffrax tests locally and checking whether they succeed?

@lockwo
Contributor Author

lockwo commented Aug 16, 2024

I do love some breaking changes, but it seems like everything is passing locally. Looking at Lineax, the only place I see statics is in packed_structures, which are just pytrees of jaxstruct objects, which should be fine. Optimistix just uses it in a filter cond when it's already filtered to be sure they are non-arrays. Diffrax just uses it to mark some things (such as Brownian classes and some floats), which are all readily static-able. Given that static arrays could often break unless special attention was paid to them, this isn't too surprising.

equinox/_ad.py (outdated, resolved)
Comment on lines 587 to 595
# [Step 3.5] Prevent arrays from being marked as static
for field in dataclasses.fields(self):
    if field.metadata.get("static", False):
        if any(
            jtu.tree_map(
                is_array, jtu.tree_flatten(getattr(self, field.name))[0]
            )
        ):
            raise ValueError("JAX Arrays cannot be marked as static!")
Owner

Should we attempt to hash instead, do you think? And reject things that are unhashable. Also, jtu.tree_flatten(...)[0] -> jtu.tree_leaves

Other than that, I've been mulling this over and I think I'm not comfortable with the extent to which this is a breaking change. I think this is probably happening in quite a few places in harmless ways, and I'd like to be sure we don't needlessly break people. Can we switch this over to a warnings.warn(..., stacklevel=...) call instead? That will still help everyone it needs to help, without breaking things overnight.
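
A sketch of what the warning-based variant could look like, using `jtu.tree_leaves` as suggested. The helper name and message are hypothetical, and `isinstance(leaf, jax.Array)` stands in for the `is_array` check in the snippet under review:

```python
import warnings

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def warn_if_static_array(field_name, value):
    """Warn (rather than raise) if `value` has any JAX array among its leaves.

    Hypothetical helper; the name and message are illustrative only.
    """
    # jtu.tree_leaves(value) is shorthand for jtu.tree_flatten(value)[0].
    leaves = jtu.tree_leaves(value)
    if any(isinstance(leaf, jax.Array) for leaf in leaves):
        warnings.warn(
            f"Field '{field_name}' is marked static but contains a JAX array.",
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_static_array("sizes", (2, 3))              # no warning
    warn_if_static_array("weights", (jnp.ones(2), 3))  # one warning
```

The right `stacklevel` depends on how deep the check sits below user code; 2 is only correct for this toy call depth.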

Contributor Author

Ideally, if static field just means aux data to pytree, we just pass that to jax and make jax tell us if it's good with it or not. So hash is sort of a proxy, in my mind the best test here would be whatever it is that jax does that would break (like whatever tree operation actually relies on this aux stuff being properly defined) since I don't think jax checks it. Not sure what that operation is tho.

I originally had hash, but your point before seemed valid, that having any arrays in the static pytree could have weird interactions with leaves/flattening/raveling even if there is a hash defined. That being said, if we are just raising a warning then I don't think it matters too much, and hash could be good (easy to check, and more direct of a proxy for what might actually fail if they try to treat it as aux data).

Owner

FWIW I believe JAX has some holes here too, for similar reasons to us: they weren't careful enough to begin with, and now people are quietly depending on this treatment of things. (I don't recall the specifics though.)

No strong feelings on hashability vs arrays really, whatever you think best!

Contributor Author

@lockwo lockwo Aug 19, 2024

I changed it to a warning, and I ultimately decided to leave it with leaves (as opposed to hash). My reasoning was that probably the most common (incorrect) use of static arrays is to try to make parameters static/not differentiated. You can make a validly hashable module/set of parameters, but even if you mark it as static, you will still see gradients for the arrays. Hence, I went with checking the leaves for arrays over the hash.

@patrick-kidger patrick-kidger merged commit 7c68c4a into patrick-kidger:main Aug 21, 2024
2 checks passed
@patrick-kidger
Owner

Awesome stuff -- this LGTM. Merged :)
Thank you for proactively thinking about this one!

@lockwo lockwo deleted the Owen/kill-static-arrays branch August 21, 2024 05:43
@francois-rozet

Hello @lockwo, @patrick-kidger 👋 Nice change!

As Equinox now prevents arrays from being marked static, what are the remaining use cases for static_field which are not handled by eqx.filter_*? Maybe it is time for the fourth option? 👀

  1. The secret fourth option: remove static fields entirely

@patrick-kidger
Owner

Probably the main one is so that e.g. eqx.nn.Linear and jax.jit work together: the static attributes corresponding to dimension sizes don't get promoted to tracers.
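
A rough illustration of that point in plain JAX, without Equinox. `TinyLinear` is a hypothetical stand-in, not how eqx.nn.Linear is implemented: the size lives in the pytree's aux data, so under jax.jit it stays an ordinary Python int that can be used for shapes.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class TinyLinear:
    """Hypothetical stand-in for a layer with one array and one static size."""

    def __init__(self, weight, out_features):
        self.weight = weight
        self.out_features = out_features

    def tree_flatten(self):
        # Children are traced by jax.jit; aux data is kept as plain Python.
        return (self.weight,), self.out_features

    @classmethod
    def tree_unflatten(cls, aux, children):
        (weight,) = children
        return cls(weight, aux)


seen = {}


@jax.jit
def apply(layer, x):
    # Recorded at trace time: the size is still a Python int, not a tracer,
    # so it can be used to build a shape.
    seen["out_features_type"] = type(layer.out_features)
    bias = jnp.zeros(layer.out_features)
    return layer.weight @ x + bias


layer = TinyLinear(jnp.ones((3, 2)), out_features=3)
out = apply(layer, jnp.ones(2))
```

Had `out_features` been returned as a child instead, jax.jit would have replaced it with a tracer and the `jnp.zeros` call would not have accepted it as a shape.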

@francois-rozet

Do we want users to use Equinox modules with non-lifted transformations? For the majority of cases, it is necessary to use lifted ones, so the simple modules like Linear or Conv are kind of exceptions. I am not sure it is best to provide several ways to do the same thing.

@patrick-kidger
Owner

I think this is a choice we now just have to live with for compatibility reasons.

If I was to go back and redesign things then I'd probably have Modules auto-static anything that isn't a JAX array. (Maybe we still can do that? Not sure.)

@francois-rozet

francois-rozet commented Aug 30, 2024

What do you mean by auto-static? Mark them on assignment? Infer by type hint? During flattening (I doubt that)?

@patrick-kidger
Owner

During a __setattr__ call that happens in __init__, run [eqx.is_array(x) for x in jax.tree_util.tree_leaves(item)], and then store statically/dynamically as appropriate.

Indeed not during flattening, as by that point values may have been replaced with arbitrary values (e.g. a vmap(..., in_axes=...) spec, which consists of just integers and None).
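
A minimal sketch of that classification step, outside of any `__setattr__` machinery. `split_attributes` is a hypothetical helper, `isinstance(x, jax.Array)` stands in for eqx.is_array, and treating an attribute as dynamic as soon as any leaf is an array is an assumption (the comment above does not say how mixed values would be handled):

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def split_attributes(attrs):
    """Split a dict of attribute values into (dynamic, static) dicts.

    Assumption: an attribute is dynamic if any of its leaves is a JAX array,
    and static otherwise.
    """
    dynamic, static = {}, {}
    for name, item in attrs.items():
        is_array_flags = [isinstance(x, jax.Array) for x in jtu.tree_leaves(item)]
        if any(is_array_flags):
            dynamic[name] = item
        else:
            static[name] = item
    return dynamic, static


dynamic, static = split_attributes(
    {
        "weight": jnp.ones((3, 2)),
        "in_features": 2,
        "activation": "relu",
        "layers": [jnp.ones(2), jnp.ones(3)],
    }
)
```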

@francois-rozet

francois-rozet commented Aug 30, 2024

I tried that at some point, but it's really hard to make it work as there are many edge cases. For example, adding a list attribute and then appending to that list during __init__.

def __init__(self):
    self.layers = []
    for i in range(3):
        self.layers.append(...)

The only way I found to handle that is to "freeze" the structure at the end of __init__ (that is keeping a copy of the tree with True or False at the leaves).
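
A sketch of why the timing matters, using a plain dict in place of a module's attributes (an assumption made to keep the example self-contained). A mask computed when the list is first assigned sees no leaves at all, while a mask computed after the appends has one boolean per array:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def array_mask(tree):
    """Return a tree of the same structure with True/False at the leaves."""
    return jtu.tree_map(lambda x: isinstance(x, jax.Array), tree)


attrs = {"depth": 3, "layers": []}

# Mask taken at assignment time: the list is still empty, so it has no leaves.
mask_at_assignment = array_mask(attrs)

for i in range(3):
    attrs["layers"].append(jnp.ones(i + 1))

# Mask "frozen" once construction has finished: one flag per appended array.
mask_at_end = array_mask(attrs)
```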

Another issue with that frozen structure is that it messes up the "natural" structure of the tree. The tree_flatten method cannot simply return the attributes, it must flatten everything to filter out the leaves that are static. Consequently, you cannot "traverse" the tree anymore (for example with tree_map). It also makes key paths almost unusable, while they could be super useful for inspection and saving weights on disk.

With time, I came to appreciate the idea that modules are just containers (namespaces actually), like tuples, lists or dicts, without any additional logic. It makes it much easier for me to think about them.

@patrick-kidger
Owner

Interesting! Very interesting to hear how you've tried this. Both your points about __init__ and tree_flatten are well-made. Maybe that doesn't work so well after all...
