Migrate from legacy JAX APIs (jax.tree_util) to jax.tree #986

Open · wants to merge 1 commit into base: main
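The change is mechanical: every legacy jax.tree_util.tree_* call is replaced with its jax.tree.* alias. For reference, a minimal sketch of the mapping (the pytree below is illustrative; it assumes a recent JAX release where the jax.tree namespace is available — the legacy jax.tree_util spellings continue to work):

import jax
import jax.numpy as jnp

# Illustrative pytree; any nested dict/list/tuple of arrays behaves the same way.
tree = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

leaves = jax.tree.leaves(tree)               # was: jax.tree_util.tree_leaves(tree)
flat, treedef = jax.tree.flatten(tree)       # was: jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree.unflatten(treedef, flat)  # was: jax.tree_util.tree_unflatten(treedef, flat)
structure = jax.tree.structure(tree)         # was: jax.tree_util.tree_structure(tree)
total = jax.tree.reduce(lambda acc, x: acc + x.sum(), tree, initializer=0.0)  # was: tree_reduce

# The aliases are drop-in replacements: same treedefs, same leaf order.
assert structure == treedef
assert jax.tree_util.tree_structure(tree) == structure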
4 changes: 1 addition & 3 deletions axlearn/common/array_serialization.py
@@ -378,9 +378,7 @@ async def _run_serializer():

asyncio.run(_run_serializer())

- self._add_futures(
-     jax.tree_util.tree_flatten(commit_futures)[0] + (additional_futures or [])
- )
+ self._add_futures(jax.tree.flatten(commit_futures)[0] + (additional_futures or []))

# Used in wait_until_finished to check on process != 0, if the checkpoint
# has finished writing.
2 changes: 1 addition & 1 deletion axlearn/common/attention_test.py
@@ -5161,7 +5161,7 @@ def has_prebuilt_layers(path):
lambda path, spec: spec if has_prebuilt_layers(path) else None, param_specs
)
if prebuilt_layers:
- self.assertNotEmpty(jax.tree_util.tree_leaves(prebuilt_specs))
+ self.assertNotEmpty(jax.tree.leaves(prebuilt_specs))
initialized_state = layer.initialize_parameters_recursively(
prng_key=jax.random.PRNGKey(123), prebuilt=prebuilt_specs
)
4 changes: 1 addition & 3 deletions axlearn/common/checkpointer.py
@@ -572,9 +572,7 @@ def restore_from_dir(
else:
raise RuntimeError(f"Unknown index entry '{value}'")

- restored_state = jax.tree_util.tree_unflatten(
-     jax.tree_util.tree_structure(state), state_leaves
- )
+ restored_state = jax.tree.unflatten(jax.tree.structure(state), state_leaves)
multihost_utils.sync_global_devices(ckpt_dir)
return restored_state

4 changes: 1 addition & 3 deletions axlearn/common/checkpointer_test.py
@@ -425,9 +425,7 @@ def test_custom_dict(self, checkpointer_cls, custom_dict_type):
step, restored_state = ckpt.restore(step=None, state=state0)
self.assertEqual(100, step)
self.assertEqual(type(restored_state), custom_dict_type)
- self.assertIn(
-     custom_dict_type.__name__, str(jax.tree_util.tree_structure(restored_state))
- )
+ self.assertIn(custom_dict_type.__name__, str(jax.tree.structure(restored_state)))
self.assertNestedEqual(state0, restored_state)
ckpt.stop()

2 changes: 1 addition & 1 deletion axlearn/common/inference.py
@@ -108,7 +108,7 @@ def __call__(self, input_batch: NestedTensor) -> Output:
isinstance(x, jax.Array) and len(x.devices()) == 1
) or isinstance(x, np.ndarray)
all_host_local_inputs = all(
- is_host_local_input_check(t) for t in jax.tree_util.tree_leaves(input_batch)
+ is_host_local_input_check(t) for t in jax.tree.leaves(input_batch)
)

if all_host_local_inputs:
2 changes: 1 addition & 1 deletion axlearn/common/inference_output.py
@@ -199,7 +199,7 @@ def write(self, *, input_batch: NestedTensor, output_batch: NestedTensor):
output_batch: A NestedTensor whose leaves must be tensors of shape [batch_size, ...].
"""
local_data = dict(input=input_batch, output=output_batch)
- local_batch_size = jax.tree_util.tree_leaves(local_data)[0].shape[0]
+ local_batch_size = jax.tree.leaves(local_data)[0].shape[0]

for i in range(local_batch_size):
example = jax.tree.map(lambda x, index=i: x[index], local_data)
2 changes: 1 addition & 1 deletion axlearn/common/inference_pipeline.py
@@ -138,7 +138,7 @@ def run(self, **kwargs):
self.summary_writer(step=batch_index, values=output.summaries)

if (batch_index + 1) % 10 == 0:
- global_batch_size = len(jax.tree_util.tree_leaves(global_input_batch)[0])
+ global_batch_size = len(jax.tree.leaves(global_input_batch)[0])
logging.info(
"Processed %d batches and %d examples",
batch_index + 1,
4 changes: 1 addition & 3 deletions axlearn/common/learner.py
@@ -427,9 +427,7 @@ def _learner_tree(self, params: Nested[Any]) -> Nested[str]:
tree_paths(params),
)
# Check that all params is covered.
- if not jax.tree_util.tree_reduce(
-     lambda x, y: x and (y != ""), learner_name_tree, initializer=True
- ):
+ if not jax.tree.reduce(lambda x, y: x and (y != ""), learner_name_tree, initializer=True):
raise ValueError("Composite learner rules do not update all model params.")
return learner_name_tree

14 changes: 7 additions & 7 deletions axlearn/common/learner_test.py
@@ -165,8 +165,8 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The structure of updated params and Adam mu states are same.
self.assertNestedEqual(
- jax.tree_util.tree_structure(updated_model_params),
- jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
+ jax.tree.structure(updated_model_params),
+ jax.tree.structure(learner_state["optimizer"][1].mu),
)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
@@ -983,14 +983,14 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
# The structure of updated params and optimizer states are same.
opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
self.assertNestedEqual(
- jax.tree_util.tree_structure(updated_model_params),
- jax.tree_util.tree_structure(
+ jax.tree.structure(updated_model_params),
+ jax.tree.structure(
learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
),
)
self.assertNestedEqual(
- jax.tree_util.tree_structure(updated_model_params),
- jax.tree_util.tree_structure(
+ jax.tree.structure(updated_model_params),
+ jax.tree.structure(
learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
),
)
@@ -1156,7 +1156,7 @@ def loss_fn(model_params, inputs):
summaries={},
module_outputs={},
)
- result = jax.tree_util.tree_reduce(lambda x, y: x.sum() + y.sum(), model_params)
+ result = jax.tree.reduce(lambda x, y: x.sum() + y.sum(), model_params)
return ForwardOutputs(loss=result, aux={}, output_collection=output_collection)

grads = jax.tree_map(lambda p: jnp.ones_like(p.value), params)
12 changes: 6 additions & 6 deletions axlearn/common/metrics_test.py
@@ -51,8 +51,8 @@ def test_metric_accumulator(self):
)

chex.assert_trees_all_equal_structs(result, expected)
- result = jax.tree_util.tree_leaves(result)
- expected = jax.tree_util.tree_leaves(expected)
+ result = jax.tree.leaves(result)
+ expected = jax.tree.leaves(expected)
chex.assert_trees_all_close(result, expected)

def test_flatten_unflatten_metric_accumulator(self):
@@ -75,10 +75,10 @@ def test_flatten_unflatten_metric_accumulator(self):
for s in summaries_copy:
acc.update(s)

- flat, tree = jax.tree_util.tree_flatten(acc)
- unflattened = jax.tree_util.tree_unflatten(tree, flat)
- expected = jax.tree_util.tree_leaves(acc.summaries())
- result = jax.tree_util.tree_leaves(unflattened.summaries())
+ flat, tree = jax.tree.flatten(acc)
+ unflattened = jax.tree.unflatten(tree, flat)
+ expected = jax.tree.leaves(acc.summaries())
+ result = jax.tree.leaves(unflattened.summaries())
chex.assert_trees_all_close(result, expected)

@parameterized.parameters(
2 changes: 1 addition & 1 deletion axlearn/common/mixture_of_experts.py
@@ -991,7 +991,7 @@ def convert_fn(source_parameters: Nested[Tensor]) -> Nested[Tensor]:
) from e
# The target layer is a RepeatedTransformerLayer.
target_parameters = {"repeat": VDict({"layer": {}})}
- num_stages = jax.tree_util.tree_leaves(stage_parameter_specs)[0].shape[0]
+ num_stages = jax.tree.leaves(stage_parameter_specs)[0].shape[0]
# The target stage is expected to be a StackedTransformerLayer.
num_layers_per_stage = len(stage_parameter_specs)
for layer_i in range(num_layers_per_stage):
2 changes: 1 addition & 1 deletion axlearn/common/module.py
@@ -202,7 +202,7 @@ def propagate_repeated_output_collections(
# if a repeated layer outputs a scalar summary value, it will have shape [N].
# Below we split the stacked values and output them separately under scope
# "{child_name_prefix}{i}" so that scalar summaries can be handled correctly.
- summary_values = jax.tree_util.tree_leaves(repeated_output_collection.summaries)
+ summary_values = jax.tree.leaves(repeated_output_collection.summaries)
if summary_values:
first_summary_value = summary_values[0]
assert first_summary_value.shape, "Stacked summaries should have a leading stack dimension."
2 changes: 1 addition & 1 deletion axlearn/common/optimizers.py
@@ -1351,7 +1351,7 @@ def _is_valid_step(
return is_valid, new_drop_stats

# Check if every gradient is finite.
- flat_updates = jax.tree_util.tree_flatten(updates)[0]
+ flat_updates = jax.tree.flatten(updates)[0]
is_finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
g_norm = optax.global_norm(updates)
if drop_norm is not None:
2 changes: 1 addition & 1 deletion axlearn/common/pipeline.py
@@ -766,7 +766,7 @@ def _run(
cfg: Pipeline.Config = self.config
self.vlog(1, "carry=%s xs=%s", shapes(carry), shapes(xs))

- carry_leaves = jax.tree_util.tree_leaves(carry)
+ carry_leaves = jax.tree.leaves(carry)
if not carry_leaves:
raise ValueError("Expected at least one input leaf.")
if carry_leaves[0].ndim < 2:
2 changes: 1 addition & 1 deletion axlearn/common/repeat.py
@@ -190,7 +190,7 @@ def _run(self, fn, carry=None, *, xs=None):
with child_context("layer", output_collection=layer_output_collection) as layer_context:
# Note, actual `num_layers` might be smaller than `cfg.num_layers` depending on
# the invocation context.
- num_layers = jax.tree_util.tree_reduce(
+ num_layers = jax.tree.reduce(
lambda num, x: min(num, x.shape[0]),
tree=(layer_context.state, xs),
initializer=cfg.num_layers,
8 changes: 3 additions & 5 deletions axlearn/common/rnn_test.py
@@ -185,11 +185,9 @@ def test_repeat_forward_vs_layerwise(self, norm_cfg, hidden_dim, num_layers):
final_states_list.append(output_collections.module_outputs["final_states"])

# Stack the tree leaves.
- tree_leaves = [jax.tree_util.tree_flatten(t)[0] for t in final_states_list]
- tree_def = jax.tree_util.tree_structure(final_states_list[0])
- final_states = jax.tree_util.tree_unflatten(
-     tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)]
- )
+ tree_leaves = [jax.tree.flatten(t)[0] for t in final_states_list]
+ tree_def = jax.tree.structure(final_states_list[0])
+ final_states = jax.tree.unflatten(tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)])
self.assertEqual(shapes(final_states), shapes(init_states))

forward_outputs, forward_collections = F(
6 changes: 2 additions & 4 deletions axlearn/common/state_builder_test.py
@@ -673,9 +673,7 @@ def _run_builder(
**extra_converter_config_kwargs,
):
source_state = _mock_state(source_cfg, seed=0)
- initial_trainer_state_tree_structure = jax.tree_util.tree_structure(
-     source_state.trainer_state
- )
+ initial_trainer_state_tree_structure = jax.tree.structure(source_state.trainer_state)

builder = (
builder_cls.default_config()
@@ -689,7 +687,7 @@ def _run_builder(
source_model = source_state.trainer_state.model

converted_state = builder(deepcopy(source_state))
- assert initial_trainer_state_tree_structure == jax.tree_util.tree_structure(
+ assert initial_trainer_state_tree_structure == jax.tree.structure(
converted_state.trainer_state
)
converted_model = converted_state.trainer_state.model
6 changes: 3 additions & 3 deletions axlearn/common/struct_test.py
@@ -58,7 +58,7 @@ class SlotsPoint:

def test_pytree_nodes(self):
p = _Point(x=1, y=2, meta={"abc": True})
- leaves = jax.tree_util.tree_leaves(p)
+ leaves = jax.tree.leaves(p)
self.assertEqual(leaves, [1, 2])
new_p = jax.tree.map(lambda x: x + x, p)
self.assertEqual(new_p, _Point(x=2, y=4, meta={"abc": True}))
@@ -104,7 +104,7 @@ def test_chex_tree_leaves_compatibility(self):
)
# tree_flatten_with_path is not preserved because Chex does not support this so the
# fallback jax implementation with numbered keys gets used.
- flattened.append(jax.tree_util.tree_leaves(instance))
+ flattened.append(jax.tree.leaves(instance))
chex.assert_trees_all_equal(*flattened)

def test_constructor_order(self):
@@ -133,7 +133,7 @@ class C:
field_b: int
field_a: int

- result = jax.tree_util.tree_leaves(C(field_b=1, field_a=2))
+ result = jax.tree.leaves(C(field_b=1, field_a=2))
expected = (1, 2)
self.assertSequenceEqual(result, expected)

12 changes: 5 additions & 7 deletions axlearn/common/test_utils.py
@@ -196,8 +196,8 @@ def _compute_layer_outputs(
# Optionally, test that trees also have the same structure.
if require_same_tree_structure:
# Prune empty subtrees so we don't require empty dicts for layers with no params.
- ref_structure = jax.tree_util.tree_structure(prune_empty(params_from_ref))
- test_structure = jax.tree_util.tree_structure(prune_empty(layer_params))
+ ref_structure = jax.tree.structure(prune_empty(params_from_ref))
+ test_structure = jax.tree.structure(prune_empty(layer_params))
self.assertEqual(
ref_structure, test_structure, msg=f"\nRef: {ref_structure}\nTest: {test_structure}"
)
@@ -428,8 +428,8 @@ def replace_keys(v, mapping):
params_with_nones = jax.tree_map(
partial(replace_keys, mapping={k: None for k in delegates}), params, is_leaf=is_leaf
)
- _, treedef = jax.tree_util.tree_flatten(params_with_nones)
- inits_with_nones = jax.tree_util.tree_unflatten(treedef, param_init_specs)
+ _, treedef = jax.tree.flatten(params_with_nones)
+ inits_with_nones = jax.tree.unflatten(treedef, param_init_specs)

# Replace the Nones with a delegate.
return jax.tree_map(partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf)
@@ -563,9 +563,7 @@ def patched_register_per_param_settings(
model_params = model.initialize_parameters_recursively(jax.random.PRNGKey(0))

model_specs = model.create_parameter_specs_recursively()
- model_specs = complete_partition_spec_tree(
-     jax.tree_util.tree_structure(model_params), model_specs
- )
+ model_specs = complete_partition_spec_tree(jax.tree.structure(model_params), model_specs)
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
4 changes: 2 additions & 2 deletions axlearn/common/trainer.py
@@ -604,7 +604,7 @@ def _opt_params(self, model_params: NestedTensor) -> NestedOptParam:
"""Returns a tree of OptParam for Learner.{init,update}."""
# self._model_param_specs can be incomplete. Complete it first.
specs = utils.complete_partition_spec_tree(
- jax.tree_util.tree_structure(model_params), self._model_param_specs
+ jax.tree.structure(model_params), self._model_param_specs
)
return jax.tree.map(
lambda param, spec: OptParam(
@@ -852,7 +852,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
# Log trainer state tree.
if not self.step and jax.process_index() == 0:
with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f:
- f.write(str(jax.tree_util.tree_structure(self._trainer_state)))
+ f.write(str(jax.tree.structure(self._trainer_state)))

with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
f.write(model_analysis)
4 changes: 1 addition & 3 deletions axlearn/common/trainer_test.py
@@ -579,9 +579,7 @@ def test_compile_train_step(self, *, platform, mesh_shape):
trainer: SpmdTrainer = cfg.instantiate(parent=None)
compiled_without_args = trainer.compile_train_step()
# pylint: disable=protected-access
- input_batch = jax.tree_util.tree_map(
-     jnp.array, next(trainer.input.batches(trainer._input_iter))
- )
+ input_batch = jax.tree.map(jnp.array, next(trainer.input.batches(trainer._input_iter)))
# pylint: enable=protected-access
compiled_with_input_batch = trainer.compile_train_step(input_batch=input_batch)
# In a single-host environment, both compiled functions should match.
22 changes: 10 additions & 12 deletions axlearn/common/utils.py
@@ -441,7 +441,7 @@ def vectorized_tree_map(fn, tree, *rest):

def vectorized_fn(*nodes):
if isinstance(nodes[0], VDict):
- if not jax.tree_util.tree_leaves(nodes[0]):
+ if not jax.tree.leaves(nodes[0]):
# This can happen when all VDict values are None and cause issues with jax.vmap.
return nodes[0]
nodes = [dict(**node) for node in nodes]
@@ -469,7 +469,7 @@ def fn(value: Union[Tensor, VDict]) -> NestedTensor:
if not isinstance(value, VDict):
return value

- leaves = jax.tree_util.tree_leaves(value)
+ leaves = jax.tree.leaves(value)
if not leaves:
# An empty VDict.
return value
@@ -653,7 +653,7 @@ def complete_partition_spec_tree(
prefix of treedef.
"""
proxy = object()
- dummy = jax.tree_util.tree_unflatten(treedef, [object()] * treedef.num_leaves)
+ dummy = jax.tree.unflatten(treedef, [object()] * treedef.num_leaves)
axes = []

def replace_none_with_proxy(tree):
@@ -672,17 +672,17 @@ def replace_none_with_proxy(tree):
partition_spec_tree_with_proxy = replace_none_with_proxy(partition_spec_tree)

def add_leaves(i, x):
- axes.extend([i] * len(jax.tree_util.tree_flatten(x)[0]))
+ axes.extend([i] * len(jax.tree.flatten(x)[0]))

try:
jax.tree.map(add_leaves, partition_spec_tree_with_proxy, dummy)
except ValueError as err:
logging.info("[complete_partition_spec_tree] ValueError: %s", err)
logging.info(
"[complete_partition_spec_tree] partition_spec_tree_with_proxy=%s",
- jax.tree_util.tree_structure(partition_spec_tree_with_proxy),
+ jax.tree.structure(partition_spec_tree_with_proxy),
)
logging.info("[complete_partition_spec_tree] dummy=%s", jax.tree_util.tree_structure(dummy))
logging.info("[complete_partition_spec_tree] dummy=%s", jax.tree.structure(dummy))
for path, value in flatten_items(partition_spec_tree_with_proxy):
logging.info(
"[complete_partition_spec_tree] partition_spec_tree_with_proxy leaf: %s=%s",
@@ -701,7 +701,7 @@ def add_leaves(i, x):
assert (
len(axes) == treedef.num_leaves
), f"({len(axes)} vs. {treedef.num_leaves}) {axes} {treedef}"
- return jax.tree_util.tree_unflatten(treedef, axes)
+ return jax.tree.unflatten(treedef, axes)


def input_partition_spec() -> PartitionSpec:
@@ -801,9 +801,7 @@ def host_to_global_device_array(
"""
mesh = thread_resources.env.physical_mesh
partition_spec = data_partition_type_to_spec(partition)
- partition_specs = complete_partition_spec_tree(
-     jax.tree_util.tree_structure(host_arrays), partition_spec
- )
+ partition_specs = complete_partition_spec_tree(jax.tree.structure(host_arrays), partition_spec)
process_count = jax.process_count()

def make_gda(x, partition_spec):
@@ -1031,7 +1029,7 @@ def cast(x: Union[Tensor, TensorSpec]) -> Union[Tensor, TensorSpec]:

def count_model_params(tree: NestedTensor) -> int:
"""Count the number of parameters in a model."""
- return sum(x.size for x in jax.tree_util.tree_leaves(tree))
+ return sum(x.size for x in jax.tree.leaves(tree))


def check_param_shape_alignment(
@@ -1095,7 +1093,7 @@ def check_jax_type(
pretty_named_args.update({f"kwargs[{key}]": kwargs[key] for key in kwargs})

for name, arg in pretty_named_args.items():
- values, _ = jax.tree_util.tree_flatten(arg)
+ values, _ = jax.tree.flatten(arg)
for value in values:
if not isinstance(value, (type(None), jax.Array, int, float)):
if msg is None: