Skip to content

Migrate from Legacy JAX APIs jax.tree_util to jax.tree#986

Closed
apivovarov wants to merge 1 commit intoapple:mainfrom
apivovarov:jax_tree
Closed

Migrate from Legacy JAX APIs jax.tree_util to jax.tree#986
apivovarov wants to merge 1 commit intoapple:mainfrom
apivovarov:jax_tree

Conversation

@apivovarov
Copy link
Contributor

Description

This PR migrates the axlearn codebase from Legacy JAX APIs (jax.tree_util) to the recommended jax.tree module.

The jax.tree API was introduced in JAX v0.4.25 and is now the preferred approach over jax.tree_util. Upgrading to jax.tree ensures better compatibility with future JAX versions and improves code maintainability.

jax.tree doc

jax.tree_util doc

pre-commit

$ pre-commit run -a      
Check Yaml...............................................................Passed
Fix End of Files.........................................................Passed
Trim Trailing Whitespace.................................................Passed
black....................................................................Passed
isort....................................................................Passed
pylint...................................................................Passed

pytype

$ pytype -j auto axlearn
...
Success: no errors found

pytest

pytest -v -n 96 -m "not (gs_login or tpu or high_cpu or fp64)" axlearn/common

========== 0 failed, 6220 passed, 10364 skipped in 734.23s (0:12:14) ==========

@apivovarov apivovarov requested review from a team, markblee and ruomingp as code owners February 12, 2025 22:41
def test_pytree_nodes(self):
p = _Point(x=1, y=2, meta={"abc": True})
leaves = jax.tree_util.tree_leaves(p)
leaves = jax.jax.tree.leaves(p)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this looks like a typo?

Suggested change
leaves = jax.jax.tree.leaves(p)
leaves = jax.tree.leaves(p)

Likewise below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


def _paths(x):
return jax.tree_util.tree_leaves(tree_paths(x))
return jax.jax.tree.leaves(tree_paths(x))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Contributor Author

@apivovarov apivovarov Feb 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed all jax.jax typos and rerun pre-commit and pytype

@markblee
Copy link
Contributor

I wonder how the CI passed with those typos. Do they fail locally for you?

@apivovarov
Copy link
Contributor Author

apivovarov commented Feb 14, 2025

Apparently, jax.jax.jax.tree.leaves works file

import jax
import jax.numpy as jnp
import sys

x = (jnp.array(0), jnp.array(1))
y = jax.jax.jax.tree.leaves(x)

print(y)
print(type(y))
print(jax.jax)
print(jax.jax.jax)
[Array(0, dtype=int32, weak_type=True), Array(1, dtype=int32, weak_type=True)]
<class 'list'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>

https://colab.research.google.com/drive/1ruOWXG6GXFSh1xdHyBVJVLyRwZTYq6GQ?usp=sharing

@kelvin-zou
Copy link
Contributor

Thanks @apivovarov , though this kind of cleanups would be better handled if you propose and let us fix it. We typically need to run many internal validation etc before merging the PR. The hairy part is from our internal repo which uses AxLearn as the core library.

We will take this cleanup PR as a low priority, so does the other cleanups since they are not blocking anything at the moment. It would be great if aws can focus more on prioritizing trainium2 fixes.

@changlan
Copy link
Contributor

Hi @apivovarov do you intend to move forward with this PR? Thanks.

@apivovarov apivovarov closed this Jul 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants