Improvements and fixes to gradient accumulation #993

Open · wants to merge 3 commits into main · Changes from 1 commit
65 changes: 50 additions & 15 deletions axlearn/common/gradient_accumulation.py
@@ -8,12 +8,14 @@
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import PartitionSpec

from axlearn.common import utils
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.input_base import InputPartitionFn, partition_by_path_rank
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.update_transformation import ForwardFn, ForwardOutputs
-from axlearn.common.utils import Nested, Tensor, input_partition_spec, with_sharding_constraint
+from axlearn.common.utils import Nested, Tensor


def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
@@ -57,39 +59,38 @@ def _make_scan_minibatch_inputs(
param_noise_key: Tensor,
minibatch_size: int,
minibatch_index: int,
minibatch_partitioner: Optional[InputPartitionFn],
Contributor:

Echoing Kelvin's comment, could you explain concretely why we need this functionality? If it's just something that might be useful, maybe we can wait until we are certain that we will need it?

Contributor Author (@apoorvtintin, Feb 19, 2025):

In the case where gradient accumulation is not enabled, the inputs to the graph are sharded per the policy in input_partitioner. This ensures the batch dimension is sharded on the data, expert, and fsdp axes while the sequence dimension is replicated on the model axis.

Gradient accumulation wraps the train step in a scan loop. The input_partitioner shards the input batch correctly at first, but inside the gradient accumulation wrapper the input batches are resharded/overridden by the function _make_scan_minibatch_inputs and sharded along all available axes, which is probably unexpected and inefficient. Minibatches should follow the same PartitionSpec as input batches.

The addition of the minibatch_partitioner allows the minibatches to use the same sharding/PartitionSpec that input_partitioner provides for the input batches when gradient accumulation is not used.
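[Editor's note] For illustration, a minibatch partitioner with the same specs this PR installs as the default could be constructed as below; this is a sketch using names from this diff, not part of the change itself.

from axlearn.common.config import config_for_function
from axlearn.common.input_base import partition_by_path_rank
from axlearn.common.utils import PartitionSpec

# Batch dim sharded on data/expert/fsdp; sequence dim sharded on seq.
minibatch_partitioner = config_for_function(partition_by_path_rank).set(
    path_rank_to_partition={
        (None, 1): PartitionSpec(("data", "expert", "fsdp")),
        (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
    }
)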

Contributor:

If we just preserve the sharding the input already has, would that also address the concern about the input sharding being changed?

Contributor Author:

Yeah, preserving the sharding of the input and not applying a sharding constraint to the minibatches would address the concern as well.

) -> tuple[Nested[Tensor], Tensor, Tensor]:
"""Creates minibatch inputs from inputs.

This is a utility function that is only meant to be called from
within a scan function body and is meant to slice the inputs
into `minibatch_size` sized slices to run the ForwardFn on.

    Note that this only preserves the input sharding if the `minibatch_partitioner`
    applies the correct partition spec to shard the input slices with.

Args:
inputs: Same pytree as ForwardFn inputs.
        forward_key: The `forward_key` from the ForwardFn inputs.
        param_noise_key: The `param_noise_key` from the ForwardFn inputs.
minibatch_size: Size of the minibatch.
minibatch_index: Current scan minibatch index.
minibatch_partitioner: Applies sharding constraints
on each minibatch created.

Returns:
        A tuple of minibatch inputs which is of the same structure as `inputs`
and new (carry) forward_key and param_noise_key.
"""
-    minibatch_input = with_sharding_constraint(
-        jax.tree.map(
-            lambda x: jax.lax.dynamic_slice_in_dim(
-                x,
-                start_index=minibatch_index * minibatch_size,
-                slice_size=minibatch_size,
-                axis=0,
-            ),
-            inputs["input_batch"],
-        ),
-        input_partition_spec(),
-    )
+    minibatch_input = jax.tree.map(
+        lambda x: jax.lax.dynamic_slice_in_dim(
+            x,
+            start_index=minibatch_index * minibatch_size,
+            slice_size=minibatch_size,
+            axis=0,
+        ),
Contributor:

To me, this seems more like a hack than a proper solution. That is, if we want a different input_partition_spec() than the default one, then we need this?

Contributor Author:

Sorry, I missed the default case; added it.

I think the partition spec below is good as a default, but the ability to change the PartitionSpec might be good to have. What do you think?

(None, 1): PartitionSpec(("data", "expert", "fsdp")),
(None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"), 

inputs["input_batch"],
Contributor (@apghml, Feb 27, 2025):

Suppose we have a global input batch of size 100 running on 10 chips (so a per chip size of 10) and we want to switch to doing 10 grad accumulation steps each with a global batch size of 10 (1 per chip per accumulation step).

Suppose that the input is originally sharded evenly across the chips (first 10 on first chip, second 10 on second chip, etc). Then when we get the first slice of 10 for the first grad accumulation step, won't all these examples be on the same chip? Will that cause a problem? (E.g., if we worry XLA might not automatically reshard the examples across chips?)

Maybe we should reshard the batch axis only?

Contributor:

+1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size?

Contributor Author (@apoorvtintin, Mar 5, 2025):

We can completely avoid the batch reshards using a reshape + transpose. I added it to the PR; let me know if it addresses your concerns.

Using the same example as @apghml:

Suppose we have a global input batch of size 100 running on 10 chips (so a per chip size of 10) and we want to switch to doing 10 grad accumulation steps each with a global batch size of 10 (1 per chip per accumulation step).
Suppose that the input is originally sharded evenly across the chips (first 10 on first chip, second 10 on second chip, etc). Then when we get the first slice of 10 for the first grad accumulation step, won't all these examples be on the same chip? Will that cause a problem? (E.g., if we worry XLA might not automatically reshard the examples across chips?)

Rather than using the first 10 examples available in the global batch array for the first iteration, we construct the minibatch using the first example from every device: minibatch 0 => [0, 10, 20, ...], minibatch 1 => [1, 11, 21, ...]. This is achieved using the reshape and transpose.

Essentially, the logic here ensures each device uses local batches, avoiding extra reshards.
This also scales well across multiple nodes, since each node only runs a local reshape + transpose, and higher per-device batch sizes are supported as well.

This should address the concerns around input batch reshards; let me know if there are still more concerns.

+1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size?

@kelvin-zou I can't think of a way to get the size of a specific axis at runtime, but I do believe JAX should give an informative error if batch size % batch axis size != 0.
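[Editor's note] To make the reordering concrete, here is a toy sketch of the index arithmetic described above, with sizes taken from the example; this is illustrative, not the PR's code.

import jax.numpy as jnp

# 100 examples, 10 accumulation steps -> 10 minibatches of 10. With the
# batch laid out contiguously per chip (chip k holds examples [10k, 10k+9]),
# reshape + transpose makes minibatch i = example i from every chip.
global_batch = jnp.arange(100)
steps, minibatch_size = 10, 10
minibatches = global_batch.reshape(minibatch_size, steps).transpose()
assert minibatches.shape == (steps, minibatch_size)
assert minibatches[0].tolist() == [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
assert minibatches[1].tolist() == [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]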

Contributor:

Thanks for the explanation. Can you add a test that fails without this fix?

+    )

+    minibatch_input = minibatch_partitioner(minibatch_input)
next_forward_key, forward_key = jax.random.split(forward_key)
next_param_noise_key, param_noise_key = jax.random.split(param_noise_key)

@@ -106,6 +107,7 @@ def with_minibatch_steps(
steps: int,
metric_accumulator: ConfigOr[MetricAccumulator],
grad_dtype: Optional[jnp.dtype] = None,
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None,
) -> Callable[[ForwardFn], ForwardFn]:
"""Decorate a ForwardFn to accumulate gradients over minibatch steps.

@@ -134,16 +136,32 @@

TODO(cemkoc): Investigate the slight difference in loss curves when decorated.

    A minibatch_partitioner is used to partition the minibatch inputs to the original_func.
    If minibatch_partitioner is None, a default partitioner is used which partitions the
    minibatch along (("data", "expert", "fsdp"), "seq"); otherwise, the minibatch_partitioner
    passed in is used.

Args:
steps: Number of gradient accumulation steps.
metric_accumulator: A `MetricAccumulator` to accumulate minibatch summaries from the
forward output.
grad_dtype: Optional dtype to cast the grads back to after accumulating in fp32.
minibatch_partitioner: If not None, contains config for a partitioner that applies
additional sharding constraints on each minibatch created.

Returns:
Decorated ForwardFn.
"""

# Default partitioner for minibatches.
if not minibatch_partitioner:
minibatch_partitioner = partition_by_path_rank(
path_rank_to_partition={
Contributor:

Can we default this to the same sharding the input is already using along all non-batch axes?

Contributor Author (@apoorvtintin, Feb 19, 2025):

Just confirming I read this correctly: do we want to default to input_partition_specs from utils.py like before, and not what the input_partitioner sets?

Or is the ask to use partition_by_path_rank to replicate what input_partition_specs was doing?

Contributor:

Not exactly. I was envisioning that for all axes other than axis 0, we default to whatever sharding the input already has. For axis 0, ideally we could also keep whatever sharding the input already has too, although I'm not sure that would work with logical batching.

Contributor Author:

For axis 0, ideally we could also keep whatever sharding the input already has too, although I'm not sure that would work with logical batching

I think preserving the sharding of the input would be perfect; logical batching already inserts the correct sharding constraint after squeezing out the padded batches.

(None, 1): PartitionSpec(("data", "expert", "fsdp")),
(None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
}
)

def decorator(fn: ForwardFn) -> ForwardFn:
# We define a positional arg only version of the original function
# that is passed because jax.value_and_grad does not accept
@@ -171,13 +189,29 @@ def fwd_helper(
and second is the accumulated grads (if `compute_grad` is True)
otherwise None.
"""
partitioner = maybe_instantiate(minibatch_partitioner)
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps)

# Create a sample minibatch for the carry buffer creation below
Contributor:

Could you explain in more detail why this is needed?

Contributor:

+1

Contributor Author (@apoorvtintin, Mar 4, 2025):

I saw broadcasting errors coming from the scan body (example below); JAX complained that the carry buffer shape and the output of the minibatch step are incompatible.

P.S. Below is the error where acc=4 and the full batch size is 32:
TypeError: add got incompatible shapes for broadcasting: (32, 4096, 3072), (8, 4096, 3072).

The carry buffer initialization uses the full batch while creating the buffer, which does not match the output of the minibatch step, since that uses the minibatch shapes.

The simple fix is to use a sample minibatch when creating the carry buffer, ensuring its shapes match the output of the minibatch step.

Let me know if I missed something.
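[Editor's note] A minimal repro of the mismatch with toy shapes (not the trainer's real pytrees): lax.scan requires the carry entering and leaving the body to have identical shapes, so a carry built from full-batch shapes cannot absorb minibatch-shaped outputs.

import jax
import jax.numpy as jnp

full = jnp.zeros((32, 4))  # full-batch-shaped output (wrong carry init)
mini = jnp.zeros((8, 4))   # minibatch-shaped output per scan step (acc=4)

def body(carry, _):
    return carry + mini, None  # carry must match the per-step output shape

# Initializing the carry from a minibatch-shaped sample works:
carry, _ = jax.lax.scan(body, jnp.zeros_like(mini), xs=None, length=4)
# Initializing it from the full batch instead raises, e.g.:
# TypeError: add got incompatible shapes for broadcasting: (32, 4), (8, 4).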

Contributor:

Do we know why this issue wasn't causing errors before?

Contributor Author (@apoorvtintin, Mar 4, 2025):

The unit test uses a toy model which does not have any metric/output that relies on batch size, which is why it does not catch this issue. I dug a bit deeper and found that for fuji models, output_collection/module_outputs/decoder/transformer/layer3/output carries the batch dimension in its output - ref below.

path (GetAttrKey(name='output_collection'), GetAttrKey(name='module_outputs'), DictKey(key='decoder'), DictKey(key='transformer'), DictKey(key='layer3'), DictKey(key='output')) shape (32, 4096, 3072)

(
sample_minibatch_inputs,
_,
_,
) = _make_scan_minibatch_inputs(
inputs,
forward_key=inputs["forward_key"],
param_noise_key=inputs["param_noise_key"],
minibatch_size=minibatch_size,
minibatch_index=0,
minibatch_partitioner=partitioner,
)

            # Carry initialization for the lax.scan procedure. Since we are passing a
            # `MetricAccumulator` into the carry, and carry input/output shapes must match,
            # we need to initialize the `MetricAccumulator` summary with the right PyTree
            # structure.
_, primal_output_shape = jax.eval_shape(
-                original_func_positional_args, model_params, inputs
+                original_func_positional_args, model_params, sample_minibatch_inputs
)
init_primal_out = jax.tree.map(jnp.zeros_like, primal_output_shape)
init_accumulator = maybe_instantiate(metric_accumulator)
Expand Down Expand Up @@ -213,6 +247,7 @@ def scan_body(
param_noise_key=param_noise_key,
minibatch_size=minibatch_size,
minibatch_index=minibatch_index,
minibatch_partitioner=partitioner,
)
minibatch_args = (model_params, minibatch_inputs)

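[Editor's note] Before the tests below, a self-contained toy sketch of the scan-based accumulation pattern this file implements; illustrative only, since the real decorator additionally handles metric accumulation, RNG splitting, and minibatch sharding.

import jax
import jax.numpy as jnp

def accumulate_grads(loss_fn, params, batch, steps):
    """Averages gradients over `steps` contiguous minibatches via lax.scan."""
    minibatch_size = batch.shape[0] // steps

    def body(acc, idx):
        # Slice minibatch `idx` out of the full batch along axis 0.
        mb = jax.lax.dynamic_slice_in_dim(
            batch, idx * minibatch_size, minibatch_size, axis=0
        )
        grads = jax.grad(loss_fn)(params, mb)
        return jax.tree.map(jnp.add, acc, grads), None

    zero = jax.tree.map(jnp.zeros_like, params)
    acc, _ = jax.lax.scan(body, zero, jnp.arange(steps))
    return jax.tree.map(lambda g: g / steps, acc)

# Example: mean-squared loss of a linear model on an all-ones batch.
params = {"w": jnp.ones((4,))}
batch = jnp.ones((8, 4))
loss = lambda p, x: jnp.mean((x @ p["w"]) ** 2)
grads = accumulate_grads(loss, params, batch, steps=2)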
161 changes: 161 additions & 0 deletions axlearn/common/gradient_accumulation_test.py
@@ -1,14 +1,175 @@
# Copyright © 2024 Apple Inc.
"""Test module for gradient_accumulation.py"""
from typing import Callable

import chex
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from absl.testing import absltest, parameterized
from jax.experimental.pjit import pjit

from axlearn.common import gradient_accumulation, test_utils
from axlearn.common.config import config_for_function
from axlearn.common.input_base import partition_by_path_rank
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
from axlearn.common.module import new_output_collection
from axlearn.common.update_transformation import ForwardOutputs
from axlearn.common.utils import Nested, PartitionSpec, Tensor, tree_paths


class TestMinibatchPartitioner(test_utils.TestCase):
"""Test `with_minibatch_steps` decorator argument minibatch_partitioner."""

def create_dummy_inputs(self, steps):
# Multiply by accumulation steps
self.batch_size = 4 * steps
self.seq_len = 8
self.params = dict(
w=jnp.asarray([0.0, 2.0, 2.0, -3.0]),
b=jnp.asarray([0.0, -1.0, 0.0, 0.0]),
)

self.input_batch = {
"input_ids": jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32),
"target_labels": jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32),
"target_num_bytes": jnp.ones((self.batch_size,), dtype=jnp.int32),
}
forward_key, param_noise_key = jax.random.split(jax.random.PRNGKey(0), 2)
self.inputs = dict(
input_batch=self.input_batch,
forward_key=forward_key,
param_noise_key=param_noise_key,
)

def create_loss_fn(self, expected_minibatch_sharding):
"""Simple ForwardFn with a check for minibatch sharding."""

def _check_equal_sharding(input_batch: Nested[Tensor], expected: dict):
"""Checks if sharding for input_batch matches expected."""

def callback_sharding(
*,
input_batch: Nested[Tensor],
callback: Callable[[str, jax.sharding.Sharding], None],
):
"""Invokes callback with the sharding.

The callback is invoked with (path: str, sharding: Sharding).
"""

def check_sharding(path, value):
jax.debug.inspect_array_sharding(
value, callback=lambda sharding: callback(path, sharding)
)

                jax.tree.map(check_sharding, tree_paths(input_batch), input_batch)
return input_batch

callback = lambda path, sharding: self.assertEqual(expected[path], sharding.spec)

callback_sharding(
input_batch=input_batch,
callback=callback,
)

def loss_fn(*, model_params, inputs) -> ForwardOutputs:
"""Simple ForwardFn."""
_check_equal_sharding(
input_batch=inputs["input_batch"],
expected=expected_minibatch_sharding,
)
loss = -jax.nn.log_softmax(model_params["w"] + model_params["b"])[1]
output_collection = new_output_collection()
output_collection.state_updates["w"] = model_params["w"] + 1
output_collection.state_updates["loss"] = WeightedScalar(loss, 1)
return ForwardOutputs(loss=loss, aux={}, output_collection=output_collection)

return loss_fn

@pytest.mark.skipif(
jax.device_count() != 4 or jax.process_count() != 1,
reason=(
"Incorrect device & process count for mesh.\n"
"Use XLA_FLAGS=--xla_force_host_platform_device_count=4 to run locally."
),
)
@parameterized.named_parameters(
("one_step", 1), # no accumulation
("two_steps", 2),
("four_steps", 4),
)
def test_minibatch_partitioner_default(self, steps):
"""Tests grad accumulation with minibatch steps and default minibatch partitioner."""

# pylint: disable=too-many-function-args
with jax.sharding.Mesh(
devices=np.array(jax.devices()).reshape(1, 2, 1, 2)[..., None],
axis_names=("expert", "data", "fsdp", "seq", "model"),
):
self.create_dummy_inputs(steps)
loss_fn = self.create_loss_fn(
expected_minibatch_sharding={
"input_ids": PartitionSpec(("data"), "seq"),
"target_labels": PartitionSpec(("data"), "seq"),
"target_num_bytes": PartitionSpec(("data")),
},
)

loss_fn = gradient_accumulation.with_minibatch_steps(
steps=steps,
metric_accumulator=MetricAccumulator.default_config(),
minibatch_partitioner=None,
)(loss_fn)

pjit(loss_fn, in_shardings=None).lower(
model_params=self.params, inputs=self.inputs
).compile()

@pytest.mark.skipif(
jax.device_count() != 4 or jax.process_count() != 1,
reason=(
"Incorrect device & process count for mesh.\n"
"Use XLA_FLAGS=--xla_force_host_platform_device_count=4 to run locally."
),
)
@parameterized.named_parameters(
("one_step", 1), # no accumulation
("two_steps", 2),
("four_steps", 4),
)
def test_minibatch_partitioner_non_default(self, steps):
"""Tests grad accumulation with minibatch steps and a custom minibatch partitioner."""

with jax.sharding.Mesh(
devices=np.array(jax.devices()).reshape(2, 2)[..., None],
axis_names=("data", "seq", "model"),
):
self.create_dummy_inputs(steps)
loss_fn = self.create_loss_fn(
expected_minibatch_sharding={
"input_ids": PartitionSpec(("data", "seq")),
"target_labels": PartitionSpec(("data", "seq")),
"target_num_bytes": PartitionSpec(("data", "seq")),
},
)

loss_fn = gradient_accumulation.with_minibatch_steps(
steps=steps,
metric_accumulator=MetricAccumulator.default_config(),
minibatch_partitioner=config_for_function(partition_by_path_rank).set(
path_rank_to_partition={
                        # Shard batch dim on all available axes
(None, 1): PartitionSpec(("data", "seq")),
(None, 2): PartitionSpec(("data", "seq"), None),
}
),
)(loss_fn)

pjit(loss_fn, in_shardings=None).lower(
model_params=self.params, inputs=self.inputs
).compile()


class TestMinibatchSteps(test_utils.TestCase):
8 changes: 7 additions & 1 deletion axlearn/common/trainer_config_modifier.py
@@ -2,7 +2,7 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

-from typing import Dict, Sequence, Union
+from typing import Dict, Optional, Sequence, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
@@ -16,6 +16,7 @@
maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.input_base import InputPartitionFn
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
@@ -29,18 +30,22 @@ class Config(ConfigModifier.Config):
"""Configure GradientAccumulationModifier.

Attributes:
            grad_acc_steps: The number of steps to accumulate the gradients from mini-batches.
metric_accumulator: The metric accumulator to export the metrics.
            minibatch_partitioner: Constrains each minibatch to a PartitionSpec.
        """
"""

grad_acc_steps: Required[int] = REQUIRED
metric_accumulator: MetricAccumulator.Config = MetricAccumulator.default_config()
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._grad_acc_steps = cfg.grad_acc_steps
self._metric_accumulator = cfg.metric_accumulator
self._minibatch_partitioner = cfg.minibatch_partitioner

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the forward_fn_transformation to accumulate gradients for grad_acc_steps steps.
@@ -63,6 +68,7 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
).set(
steps=self._grad_acc_steps,
metric_accumulator=self._metric_accumulator,
minibatch_partitioner=self._minibatch_partitioner,
)
return cfg

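[Editor's note] Putting the pieces together, a hedged configuration sketch: the partitioner spec is illustrative, borrowed from the defaults above, and grad_acc_steps=4 matches the fuji.py change below.

from axlearn.common.config import config_for_function
from axlearn.common.input_base import partition_by_path_rank
from axlearn.common.trainer_config_modifier import GradientAccumulationModifier
from axlearn.common.utils import PartitionSpec

# Configure the modifier with an explicit minibatch partitioner.
grad_acc = GradientAccumulationModifier.default_config().set(
    grad_acc_steps=4,
    minibatch_partitioner=config_for_function(partition_by_path_rank).set(
        path_rank_to_partition={
            (None, 1): PartitionSpec(("data", "expert", "fsdp")),
            (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
        }
    ),
)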
3 changes: 3 additions & 0 deletions axlearn/experiments/text/gpt/fuji.py
@@ -352,6 +352,9 @@ def get_trainer_kwargs(
),
*trn2_config.module_modifications,
*trn2_config.partition_spec_modifications,
GradientAccumulationModifier.default_config().set(
grad_acc_steps=4,
),
],
),
),