remove sharding constraints from gradient accumulation
apoorvtintin committed Feb 26, 2025
1 parent 186e082 commit 8c16718
Showing 3 changed files with 1 addition and 192 deletions.
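For context, the decorated `ForwardFn` below accumulates gradients over minibatch steps with `jax.lax.scan`. The following is a conceptual sketch of that pattern only, not the axlearn implementation; the helper name `accumulated_value_and_grad` is invented for illustration, and the commit itself only removes the per-minibatch sharding constraints from this machinery.

```python
import jax
import jax.numpy as jnp


def accumulated_value_and_grad(loss_fn, params, batch, steps: int):
    """Averages loss and grads over `steps` equally sized minibatches.

    Assumes the leading batch dimension of every array in `batch` is divisible
    by `steps`.
    """
    # Split the global batch into `steps` minibatches along the leading axis.
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape(steps, x.shape[0] // steps, *x.shape[1:]), batch
    )
    grad_fn = jax.value_and_grad(loss_fn)

    def scan_body(carry, minibatch):
        loss_acc, grad_acc = carry
        loss, grads = grad_fn(params, minibatch)
        # Accumulate the running mean of loss and grads.
        carry = (
            loss_acc + loss / steps,
            jax.tree_util.tree_map(lambda a, g: a + g / steps, grad_acc, grads),
        )
        return carry, None

    init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss, grads), _ = jax.lax.scan(scan_body, init, minibatches)
    return loss, grads
```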
26 changes: 0 additions & 26 deletions axlearn/common/gradient_accumulation.py
@@ -8,11 +8,9 @@
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import PartitionSpec

from axlearn.common import utils
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.input_base import InputPartitionFn, partition_by_path_rank
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.update_transformation import ForwardFn, ForwardOutputs
from axlearn.common.utils import Nested, Tensor
@@ -59,7 +57,6 @@ def _make_scan_minibatch_inputs(
param_noise_key: Tensor,
minibatch_size: int,
minibatch_index: int,
minibatch_partitioner: Optional[InputPartitionFn],
) -> tuple[Nested[Tensor], Tensor, Tensor]:
"""Creates minibatch inputs from inputs.
@@ -73,8 +70,6 @@
param_noise_key: The `param_noise_key` from the ForwardFn inputs
minibatch_size: Size of the minibatch.
minibatch_index: Current scan minibatch index.
minibatch_partitioner: Applies sharding constraints
on each minibatch created.
Returns:
A tuple of minibatch inputs which of the same structure as `inputs`
@@ -90,7 +85,6 @@
inputs["input_batch"],
)

minibatch_input = minibatch_partitioner(minibatch_input)
next_forward_key, forward_key = jax.random.split(forward_key)
next_param_noise_key, param_noise_key = jax.random.split(param_noise_key)

@@ -107,7 +101,6 @@ def with_minibatch_steps(
steps: int,
metric_accumulator: ConfigOr[MetricAccumulator],
grad_dtype: Optional[jnp.dtype] = None,
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None,
) -> Callable[[ForwardFn], ForwardFn]:
"""Decorate a ForwardFn to accumulate gradients over minibatch steps.
@@ -136,32 +129,16 @@
TODO(cemkoc): Investigate the slight difference in loss curves when decorated.
A minibatch_partitioner is used to partition minibatch inputs to the original_func.
Note that if minibatch_partitioner is None, the default minibatch partitioner is used which
partitions the minibatch along (("data", "expert", "fsdp"), "seq"). Otherwise the
minibatch_partitioner passed in is used.
Args:
steps: Number of gradient accumulation steps.
metric_accumulator: A `MetricAccumulator` to accumulate minibatch summaries from the
forward output.
grad_dtype: Optional dtype to cast the grads back to after accumulating in fp32.
minibatch_partitioner: If not None, contains config for a partitioner that applies
additional sharding constraints on each minibatch created.
Returns:
Decorated ForwardFn.
"""

# Default partitioner for minibatches.
if not minibatch_partitioner:
minibatch_partitioner = partition_by_path_rank(
path_rank_to_partition={
(None, 1): PartitionSpec(("data", "expert", "fsdp")),
(None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
}
)

def decorator(fn: ForwardFn) -> ForwardFn:
# We define a positional arg only version of the original function
# that is passed because jax.value_and_grad does not accept
@@ -189,7 +166,6 @@ def fwd_helper(
and second is the accumulated grads (if `compute_grad` is True)
otherwise None.
"""
partitioner = maybe_instantiate(minibatch_partitioner)
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps)

# Create a sample minibatch for the carry buffer creation below
@@ -203,7 +179,6 @@
param_noise_key=inputs["param_noise_key"],
minibatch_size=minibatch_size,
minibatch_index=0,
minibatch_partitioner=partitioner,
)

# Carry initialization for the lax.scan procedure. Since we are passing a
@@ -247,7 +222,6 @@ def scan_body(
param_noise_key=param_noise_key,
minibatch_size=minibatch_size,
minibatch_index=minibatch_index,
minibatch_partitioner=partitioner,
)
minibatch_args = (model_params, minibatch_inputs)

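After this change, `with_minibatch_steps` no longer accepts a `minibatch_partitioner`; the remaining knobs visible in the diff are `steps`, `metric_accumulator`, and `grad_dtype`. Below is a minimal usage sketch, with a toy `loss_fn` borrowed from the removed test further down; it is illustrative only and not part of this commit.

```python
import jax
import jax.numpy as jnp

from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.module import new_output_collection
from axlearn.common.update_transformation import ForwardOutputs


def loss_fn(*, model_params, inputs) -> ForwardOutputs:
    """Toy ForwardFn, mirroring the one used in the removed tests."""
    del inputs  # Unused in this toy example.
    loss = -jax.nn.log_softmax(model_params["w"] + model_params["b"])[1]
    return ForwardOutputs(loss=loss, aux={}, output_collection=new_output_collection())


# Only steps, metric_accumulator, and grad_dtype remain configurable.
accumulated_loss_fn = with_minibatch_steps(
    steps=4,
    metric_accumulator=MetricAccumulator.default_config(),
    grad_dtype=jnp.float32,
)(loss_fn)
```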
160 changes: 0 additions & 160 deletions axlearn/common/gradient_accumulation_test.py
@@ -1,175 +1,15 @@
# Copyright © 2024 Apple Inc.
"""Test module for gradient_accumulation.py"""
from typing import Callable

import chex
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from absl.testing import absltest, parameterized
from jax.experimental.pjit import pjit

from axlearn.common import gradient_accumulation, test_utils
from axlearn.common.config import config_for_function
from axlearn.common.input_base import partition_by_path_rank
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
from axlearn.common.module import new_output_collection
from axlearn.common.update_transformation import ForwardOutputs
from axlearn.common.utils import Nested, PartitionSpec, Tensor, tree_paths


class TestMinibatchPartitioner(test_utils.TestCase):
"""Test `with_minibatch_steps` decorator argument minibatch_partitioner."""

def create_dummy_inputs(self, steps):
# Multiply by accumulation steps
self.batch_size = 4 * steps
self.seq_len = 8
self.params = dict(
w=jnp.asarray([0.0, 2.0, 2.0, -3.0]),
b=jnp.asarray([0.0, -1.0, 0.0, 0.0]),
)

self.input_batch = {
"input_ids": jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32),
"target_labels": jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32),
"target_num_bytes": jnp.ones((self.batch_size,), dtype=jnp.int32),
}
forward_key, param_noise_key = jax.random.split(jax.random.PRNGKey(0), 2)
self.inputs = dict(
input_batch=self.input_batch,
forward_key=forward_key,
param_noise_key=param_noise_key,
)

def create_loss_fn(self, expected_minibatch_sharding):
"""Simple ForwardFn with a check for minibatch sharding."""

def _check_equal_sharding(input_batch: Nested[Tensor], expected: dict):
"""Checks if sharding for input_batch matches expected."""

def callback_sharding(
*,
input_batch: Nested[Tensor],
callback: Callable[[str, jax.sharding.Sharding], None],
):
"""Invokes callback with the sharding.
The callback is invoked with (path: str, sharding: Sharding).
"""

def check_sharding(path, value):
jax.debug.inspect_array_sharding(
value, callback=lambda sharding: callback(path, sharding)
)

jax.tree_map(check_sharding, tree_paths(input_batch), input_batch)
return input_batch

callback = lambda path, sharding: self.assertEqual(expected[path], sharding.spec)

callback_sharding(
input_batch=input_batch,
callback=callback,
)

def loss_fn(*, model_params, inputs) -> ForwardOutputs:
"""Simple ForwardFn."""
_check_equal_sharding(
input_batch=inputs["input_batch"],
expected=expected_minibatch_sharding,
)
loss = -jax.nn.log_softmax(model_params["w"] + model_params["b"])[1]
output_collection = new_output_collection()
output_collection.state_updates["w"] = model_params["w"] + 1
output_collection.state_updates["loss"] = WeightedScalar(loss, 1)
return ForwardOutputs(loss=loss, aux={}, output_collection=output_collection)

return loss_fn

@pytest.mark.skipif(
jax.device_count() != 4 or jax.process_count() != 1,
reason=(
"Incorrect device & process count for mesh.\n"
"Use XLA_FLAGS=--xla_force_host_platform_device_count=4 to run locally."
),
)
@parameterized.named_parameters(
("one_step", 1), # no accumulation
("two_steps", 2),
("four_steps", 4),
)
def test_minibatch_partitioner_default(self, steps):
"""Tests grad accumulation with minibatch steps and default minibatch partitioner."""

# pylint: disable=too-many-function-args
with jax.sharding.Mesh(
devices=np.array(jax.devices()).reshape(1, 2, 1, 2)[..., None],
axis_names=("expert", "data", "fsdp", "seq", "model"),
):
self.create_dummy_inputs(steps)
loss_fn = self.create_loss_fn(
expected_minibatch_sharding={
"input_ids": PartitionSpec(("data"), "seq"),
"target_labels": PartitionSpec(("data"), "seq"),
"target_num_bytes": PartitionSpec(("data")),
},
)

loss_fn = gradient_accumulation.with_minibatch_steps(
steps=steps,
metric_accumulator=MetricAccumulator.default_config(),
minibatch_partitioner=None,
)(loss_fn)

pjit(loss_fn, in_shardings=None).lower(
model_params=self.params, inputs=self.inputs
).compile()

@pytest.mark.skipif(
jax.device_count() != 4 or jax.process_count() != 1,
reason=(
"Incorrect device & process count for mesh.\n"
"Use XLA_FLAGS=--xla_force_host_platform_device_count=4 to run locally."
),
)
@parameterized.named_parameters(
("one_step", 1), # no accumulation
("two_steps", 2),
("four_steps", 4),
)
def test_minibatch_partitioner_non_default(self, steps):
"""Tests grad accumulation with minibatch steps and a custom minibatch partitioner."""

with jax.sharding.Mesh(
devices=np.array(jax.devices()).reshape(2, 2)[..., None],
axis_names=("data", "seq", "model"),
):
self.create_dummy_inputs(steps)
loss_fn = self.create_loss_fn(
expected_minibatch_sharding={
"input_ids": PartitionSpec(("data", "seq")),
"target_labels": PartitionSpec(("data", "seq")),
"target_num_bytes": PartitionSpec(("data", "seq")),
},
)

loss_fn = gradient_accumulation.with_minibatch_steps(
steps=steps,
metric_accumulator=MetricAccumulator.default_config(),
minibatch_partitioner=config_for_function(partition_by_path_rank).set(
path_rank_to_partition={
# Shard batch dim on all available axis
(None, 1): PartitionSpec(("data", "seq")),
(None, 2): PartitionSpec(("data", "seq"), None),
}
),
)(loss_fn)

pjit(loss_fn, in_shardings=None).lower(
model_params=self.params, inputs=self.inputs
).compile()


class TestMinibatchSteps(test_utils.TestCase):
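The deleted tests above exercised the default and custom `minibatch_partitioner` paths. With that hook gone, one possible workaround, shown here as a hypothetical sketch rather than something introduced by this commit, is for callers to constrain minibatch sharding inside their own `ForwardFn` via `jax.lax.with_sharding_constraint`, using the specs that used to be the decorator's defaults.

```python
import jax
from jax.sharding import PartitionSpec

from axlearn.common.utils import Nested, Tensor


def constrain_input_batch(input_batch: Nested[Tensor]) -> Nested[Tensor]:
    """Reapplies the old default constraints to a minibatch.

    Must run while tracing under jit within a Mesh whose axis names include
    "data", "expert", "fsdp", and "seq".
    """

    def constrain(x: Tensor) -> Tensor:
        if x.ndim == 1:
            return jax.lax.with_sharding_constraint(
                x, PartitionSpec(("data", "expert", "fsdp"))
            )
        if x.ndim == 2:
            return jax.lax.with_sharding_constraint(
                x, PartitionSpec(("data", "expert", "fsdp"), "seq")
            )
        return x

    return jax.tree_util.tree_map(constrain, input_batch)
```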
7 changes: 1 addition & 6 deletions axlearn/common/trainer_config_modifier.py
@@ -2,7 +2,7 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

from typing import Dict, Optional, Sequence, Union
from typing import Dict, Sequence, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
@@ -16,7 +16,6 @@
maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.input_base import InputPartitionFn
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
@@ -33,19 +32,16 @@ class Config(ConfigModifier.Config):
grad_acc_steps: The number of steps to accumulate the gradients from mini-batches.
metric_accumulator: The metric accumulator to export the metrics.
minibatch_partitioner: Constraints the minibatch to a PartitionSpec.
"""

grad_acc_steps: Required[int] = REQUIRED
metric_accumulator: MetricAccumulator.Config = MetricAccumulator.default_config()
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._grad_acc_steps = cfg.grad_acc_steps
self._metric_accumulator = cfg.metric_accumulator
self._minibatch_partitioner = cfg.minibatch_partitioner

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the forward_fn_transformation to accumulate gradients for grad_acc_steps steps.
@@ -68,7 +64,6 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
).set(
steps=self._grad_acc_steps,
metric_accumulator=self._metric_accumulator,
minibatch_partitioner=self._minibatch_partitioner,
)
return cfg

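With `minibatch_partitioner` dropped from the config, only `grad_acc_steps` and `metric_accumulator` remain. A configuration sketch follows; the class name `GradientAccumulationModifier` is assumed here, since the enclosing class is not shown in this diff.

```python
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer_config_modifier import GradientAccumulationModifier

# minibatch_partitioner is no longer a config field after this commit.
grad_acc_modifier = GradientAccumulationModifier.default_config().set(
    grad_acc_steps=4,
    metric_accumulator=MetricAccumulator.default_config(),
)
```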