Disable static arrays #800
Addresses #798 |
Thinking about it more, it seems like there are two options:
|
Haha! That's a thing we could do in |
"I think what should be done is to tree-flatten static fields, and then check if any leaves are arrays." I thought about that, but is it universally valid? Specifically, I could make a pytree that has an array as a member variable but also overrides the hash and equality operators. Holding an array inside a pytree that is marked as static is not necessarily the same as being unhashable or impossible to compare for equality, provided you override those methods. Or is that cheating? Such an object would be fully hashable and comparable, which is what aux data requires, but it would still own an array (this is actually what I did when I was playing around in andraz's Langevin PR, and it seemed to work). Or is this not truly valid auxiliary/static data?
I don't think that's valid. Both JAX and Equinox do a lot of tree flattening in all kinds of places -- I think static data needs to be robust to that. |
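The scenario discussed above can be sketched roughly as follows. This is only an illustration, not code from the PR: `HashableBox` is a hypothetical class, registered as a pytree with `jax.tree_util.register_pytree_node_class`, that owns an array while defining its own `__hash__` and `__eq__`. It passes a hashability check, yet flattening it still exposes the array as a leaf.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class HashableBox:
    """Hypothetical pytree that owns an array but is still hashable."""

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        # Hash on the shape only, so hashing never touches array contents.
        return hash(self.value.shape)

    def __eq__(self, other):
        return isinstance(other, HashableBox) and self.value.shape == other.value.shape

    def tree_flatten(self):
        # The array is a child (leaf); there is no aux data.
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


box = HashableBox(jnp.arange(3))
box_hash = hash(box)          # succeeds: the object is hashable
leaves = jtu.tree_leaves(box)  # flattening still surfaces the array
print(type(box_hash).__name__, len(leaves))
```

So a hash-based check would accept this object, while a leaf-based check would flag it.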
Hmmm, surprised I got it to work then (JAX usually fails quickly when attempting this sort of thing, but these aux data edge cases are surprisingly robust). I reverted it to the original approach of just checking whether there are any arrays in the static object.
Awesome! That was surprisingly painless. I would not be at all surprised if this breaks downstream parts of the ecosystem. (I'm probably equally guilty of accidentally using static arrays as anyone else...) Can you try running the Lineax/Optimistic/Diffrax tests locally and checking whether they succeed? |
I do love some breaking changes, but it seems like everything is passing locally. Looking at lineax, the only place I see static's are in |
equinox/_module.py
Outdated
# [Step 3.5] Prevent arrays from being marked as static
for field in dataclasses.fields(self):
    if field.metadata.get("static", False):
        if any(
            jtu.tree_map(
                is_array, jtu.tree_flatten(getattr(self, field.name))[0]
            )
        ):
            raise ValueError("JAX Arrays cannot be marked as static!")
Should we attempt to hash
instead, do you think? And reject things that are unhashable. Also, jtu.tree_flatten(...)[0] -> jtu.tree_leaves
Other than that, I've been mulling this over and I think I'm not comfortable with the extent to which this is a breaking change. I think this is probably happening in quite a few places in harmless ways, and I'd like to be sure we don't needlessly break people. Can we switch this over to a warnings.warn(..., stacklevel=...) call instead? That will still help everyone it needs to help, without breaking things overnight.
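A minimal sketch of what the suggested change might look like, combining `jtu.tree_leaves` with `warnings.warn(..., stacklevel=...)`. This is an assumption-laden illustration rather than the code that was merged: `warn_on_static_arrays` and `Layer` are hypothetical names, `is_array` is a simplified stand-in for Equinox's predicate, and the static marker is modelled with plain `dataclasses.field(metadata={"static": True})`.

```python
import dataclasses
import warnings

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np


def is_array(x):
    # Simplified stand-in for Equinox's `is_array` (assumption).
    return isinstance(x, (jax.Array, np.ndarray))


def warn_on_static_arrays(obj):
    """Warn for each static field that contains an array; return their names."""
    offending = []
    for field in dataclasses.fields(obj):
        if field.metadata.get("static", False):
            leaves = jtu.tree_leaves(getattr(obj, field.name))
            if any(is_array(leaf) for leaf in leaves):
                warnings.warn(
                    f"Field '{field.name}' is marked static but contains an array.",
                    stacklevel=2,  # attribute the warning to the caller
                )
                offending.append(field.name)
    return offending


@dataclasses.dataclass
class Layer:
    # Hypothetical container used only for this example.
    weight: object
    shape: tuple = dataclasses.field(metadata={"static": True})
    mask: object = dataclasses.field(default=None, metadata={"static": True})


good = Layer(jnp.ones(3), (3,))
bad = Layer(jnp.ones(3), (3,), mask=(jnp.zeros(3),))  # array hidden in a tuple
```

Calling `warn_on_static_arrays(bad)` emits a warning for `mask` instead of raising, so existing code keeps running while users are told about the problem.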
Ideally, if a static field just means aux data to the pytree, we would pass it to JAX and let JAX tell us whether it's acceptable. So hash is sort of a proxy; in my mind the best test here would be whatever JAX operation would actually break (whichever tree operation relies on this aux data being properly defined), since I don't think JAX checks it. Not sure what that operation is, though.
I originally had hash, but your earlier point seemed valid: having any arrays in the static pytree could interact strangely with leaves/flattening/raveling even if a hash is defined. That being said, if we are just raising a warning then I don't think it matters too much, and hash could be good (easy to check, and a more direct proxy for what might actually fail if they try to treat it as aux data).
FWIW I believe JAX has some holes here too, for similar reasons to us: they weren't careful enough to begin with, and now people are quietly depending on this treatment of things. (I don't recall the specifics though.)
No strong feelings on hashability vs arrays really, whatever you think best!
I changed it to a warning, and I ultimately decided to keep the leaves check (as opposed to hash). My reasoning: probably the most common (incorrect) use of static arrays is trying to make parameters static/not differentiated. But you can make a validly hashable module/set of parameters, and even if you mark it as static, you will still see gradients for the arrays. Hence, I went with checking the leaves for arrays over the hash.
Awesome stuff -- this LGTM. Merged :) |
Hello @lockwo, @patrick-kidger 👋 Nice change! As Equinox now prevents arrays from being marked static, what are the remaining use cases for
|
Probably the main one is so that e.g. |
Do we want users to use Equinox modules with non-lifted transformations? For the majority of cases, it is necessary to use lifted ones, so the simple modules like Linear or Conv are kind of exceptions. I am not sure it is best to provide several ways to do the same thing. |
I think this is a choice we now just have to live with for compatibility reasons. If I was to go back and redesign things then I'd probably have Modules auto-static anything that isn't a JAX array. (Maybe we still can do that? Not sure.) |
What do you mean by auto-static? Mark them on assignment? Infer by type hint? During flattening (I doubt that)?
During a Indeed not during flattening, as by that point values may have been replaced with arbitrary values (e.g. a |
I tried that at some point, but it's really hard to make it work as there are many edge cases. For example, adding a list attribute and then appending to that list during
def __init__(self):
    self.layers = []
    for i in range(3):
        self.layers.append(...)
The only way I found to handle that is to "freeze" the structure at the end of Another issue with that frozen structure is that it messes up the "natural" structure of the tree. The With time, I came to appreciate the idea that modules are just containers (namespaces actually), like tuples, lists or dicts, without any additional logic. It makes it much easier for me to think about them.
Interesting! Very interesting to hear how you've tried this. Both your points about |
Originally, by just checking the type, users could still footgun themselves with Static arrays if they hide them in pytrees (they still can if they use field). For example, a tuple of arrays will still try to do this and won't error, but is actually an error, since a tuple relies on the hashes of its members. As a solution, I figured it would be more direct to just see if it was hashable at all, but perhaps there are some edge cases where this might not be ideal(?). However, I could only do this for Static directly, not the field marker, since we don't access the __init__ of the user, so if they just type a tuple and mark it as static, then fill it with arrays, I'm not sure how we could catch that.
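The point about tuples relying on the hashes of their members can be illustrated with a small sketch. The helper `is_hashable` is a hypothetical name introduced here, not something from Equinox; it simply attempts `hash(x)` and reports whether that succeeds.

```python
import jax.numpy as jnp


def is_hashable(x):
    """Return True if `hash(x)` succeeds (hypothetical helper)."""
    try:
        hash(x)
    except TypeError:
        return False
    return True


plain_static = (3, "relu")            # ordinary hashable metadata
hidden_array = (jnp.ones(2), "relu")  # an array hidden inside a tuple

print(is_hashable(plain_static), is_hashable(hidden_array))
```

A type check on the outer object sees only a `tuple` in both cases, whereas hashing fails for the second because the tuple's hash is computed from its members and JAX arrays are unhashable.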