Migrate from Legacy JAX APIs jax.tree_util to jax.tree#986
Migrate from Legacy JAX APIs jax.tree_util to jax.tree#986apivovarov wants to merge 1 commit intoapple:mainfrom
Conversation
axlearn/common/struct_test.py
Outdated
| def test_pytree_nodes(self): | ||
| p = _Point(x=1, y=2, meta={"abc": True}) | ||
| leaves = jax.tree_util.tree_leaves(p) | ||
| leaves = jax.jax.tree.leaves(p) |
There was a problem hiding this comment.
Hm, this looks like a typo?
| leaves = jax.jax.tree.leaves(p) | |
| leaves = jax.tree.leaves(p) |
Likewise below.
axlearn/vision/virtex.py
Outdated
|
|
||
| def _paths(x): | ||
| return jax.tree_util.tree_leaves(tree_paths(x)) | ||
| return jax.jax.tree.leaves(tree_paths(x)) |
There was a problem hiding this comment.
fixed all jax.jax typos and rerun pre-commit and pytype
|
I wonder how the CI passed with those typos. Do they fail locally for you? |
|
Apparently, jax.jax.jax.tree.leaves works file https://colab.research.google.com/drive/1ruOWXG6GXFSh1xdHyBVJVLyRwZTYq6GQ?usp=sharing |
|
Thanks @apivovarov , though this kind of cleanups would be better handled if you propose and let us fix it. We typically need to run many internal validation etc before merging the PR. The hairy part is from our internal repo which uses AxLearn as the core library. We will take this cleanup PR as a low priority, so does the other cleanups since they are not blocking anything at the moment. It would be great if aws can focus more on prioritizing trainium2 fixes. |
|
Hi @apivovarov do you intend to move forward with this PR? Thanks. |
Description
This PR migrates the axlearn codebase from Legacy JAX APIs (
jax.tree_util) to the recommendedjax.treemodule.The jax.tree API was introduced in JAX v0.4.25 and is now the preferred approach over
jax.tree_util. Upgrading tojax.treeensures better compatibility with future JAX versions and improves code maintainability.jax.tree doc
jax.tree_util doc
pre-commit
pytype
pytest