-
Notifications
You must be signed in to change notification settings - Fork 300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements and fixes to gradient accumulation #993
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,12 +8,14 @@ | |
import jax | ||
import numpy as np | ||
from jax import numpy as jnp | ||
from jax.sharding import PartitionSpec | ||
|
||
from axlearn.common import utils | ||
from axlearn.common.config import ConfigOr, maybe_instantiate | ||
from axlearn.common.input_base import InputPartitionFn, partition_by_path_rank | ||
from axlearn.common.metrics import MetricAccumulator | ||
from axlearn.common.update_transformation import ForwardFn, ForwardOutputs | ||
from axlearn.common.utils import Nested, Tensor, input_partition_spec, with_sharding_constraint | ||
from axlearn.common.utils import Nested, Tensor | ||
|
||
|
||
def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int: | ||
|
@@ -57,39 +59,38 @@ def _make_scan_minibatch_inputs( | |
param_noise_key: Tensor, | ||
minibatch_size: int, | ||
minibatch_index: int, | ||
minibatch_partitioner: Optional[InputPartitionFn], | ||
) -> tuple[Nested[Tensor], Tensor, Tensor]: | ||
"""Creates minibatch inputs from inputs. | ||
|
||
This is a utility function that is only meant to be called from | ||
within a scan function body and is meant to slice the inputs | ||
into `minibatch_size` sized slices to run the ForwardFn on. | ||
|
||
Note that this only preserves the input sharding if the `input_partition_spec` | ||
returns the correct partition spec to shard the input slices with. | ||
|
||
Args: | ||
inputs: Same pytree as ForwardFn inputs. | ||
forward_key: The `forward_key` from the ForwardFn inputs | ||
param_noise_key: The `param_noise_key` from the ForwardFn inputs | ||
minibatch_size: Size of the minibatch. | ||
minibatch_index: Current scan minibatch index. | ||
minibatch_partitioner: Applies sharding constraints | ||
on each minibatch created. | ||
|
||
Returns: | ||
A tuple of minibatch inputs which of the same structure as `inputs` | ||
and new (carry) forward_key and param_noise_key. | ||
""" | ||
minibatch_input = with_sharding_constraint( | ||
jax.tree.map( | ||
lambda x: jax.lax.dynamic_slice_in_dim( | ||
x, | ||
start_index=minibatch_index * minibatch_size, | ||
slice_size=minibatch_size, | ||
axis=0, | ||
), | ||
inputs["input_batch"], | ||
minibatch_input = jax.tree.map( | ||
lambda x: jax.lax.dynamic_slice_in_dim( | ||
x, | ||
start_index=minibatch_index * minibatch_size, | ||
slice_size=minibatch_size, | ||
axis=0, | ||
), | ||
input_partition_spec(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To me, it seems rather a hack than a proper solution, that is, we want to have a different There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I missed the default case, added it. I think the below partition spec is good as a default, but the ability to change PartitionSpec might be good to have, what do you think?
|
||
inputs["input_batch"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suppose we have a global input batch of size 100 running on 10 chips (so a per chip size of 10) and we want to switch to doing 10 grad accumulation steps each with a global batch size of 10 (1 per chip per accumulation step). Suppose that the input is originally sharded evenly across the chips (first 10 on first chip, second 10 on second chip, etc). Then when we get the first slice of 10 for the first grad accumulation step, won't all these examples be on the same chip? Will that cause a problem? (E.g., if we worry XLA might not automatically reshard the examples across chips?) Maybe we should reshard the batch axis only? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can completely avoid the batch reshards using a reshape + transpose. I added it to the, PR let me know if it addresses your concerns. Using the same example as @apghml:
Rather than using first 10 batches available in the global batch array for the first iteration, we construct the minibatch using the first batch from every device that is minibatch 0 =>[0, 10, 20 ....], minibatch 1 => [1, 11, 21, ...]. This is achieved using the reshape and transpose. Essentially the logic here is to ensure each device uses local batches avoiding extra reshards. This should addresses the concerns around input batch reshards, let me know if there are still more concerns.
@kelvin-zou I can't think of a way to get size of a specific axis at runtime, but I do believe JAX should be able to give an informative error if the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation. Can you add a test that fails without this fix? |
||
) | ||
|
||
minibatch_input = minibatch_partitioner(minibatch_input) | ||
next_forward_key, forward_key = jax.random.split(forward_key) | ||
next_param_noise_key, param_noise_key = jax.random.split(param_noise_key) | ||
|
||
|
@@ -106,6 +107,7 @@ def with_minibatch_steps( | |
steps: int, | ||
metric_accumulator: ConfigOr[MetricAccumulator], | ||
grad_dtype: Optional[jnp.dtype] = None, | ||
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None, | ||
) -> Callable[[ForwardFn], ForwardFn]: | ||
"""Decorate a ForwardFn to accumulate gradients over minibatch steps. | ||
|
||
|
@@ -134,16 +136,32 @@ def with_minibatch_steps( | |
|
||
TODO(cemkoc): Investigate the slight difference in loss curves when decorated. | ||
|
||
A minibatch_partitioner is used to partition minibatch inputs to the original_func. | ||
Note that if minibatch_partitioner is None, the default minibatch partitioner is used which | ||
partitions the minibatch along (("data", "expert", "fsdp"), "seq"). Otherwise the | ||
minibatch_partitioner passed in is used. | ||
|
||
Args: | ||
steps: Number of gradient accumulation steps. | ||
metric_accumulator: A `MetricAccumulator` to accumulate minibatch summaries from the | ||
forward output. | ||
grad_dtype: Optional dtype to cast the grads back to after accumulating in fp32. | ||
minibatch_partitioner: If not None, contains config for a partitioner that applies | ||
additional sharding constraints on each minibatch created. | ||
|
||
Returns: | ||
Decorated ForwardFn. | ||
""" | ||
|
||
# Default partitioner for minibatches. | ||
if not minibatch_partitioner: | ||
minibatch_partitioner = partition_by_path_rank( | ||
path_rank_to_partition={ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we default this to the same sharding the input is already using along all non-batch axes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just confirming if I read it correctly, we want to default to input_partition_specs from Or the ask is to use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not exactly. I was envisioning that for all axes other than axis 0, we default to whatever sharding the input already has. For axis 0, ideally we could also keep whatever sharding the input already has too, although I'm not sure that would work with logical batching. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think preserving the sharding of the input would be perfect, logical batching already inserts the correct sharding constraint after squeezing out the padded batches |
||
(None, 1): PartitionSpec(("data", "expert", "fsdp")), | ||
(None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"), | ||
} | ||
) | ||
|
||
def decorator(fn: ForwardFn) -> ForwardFn: | ||
# We define a positional arg only version of the original function | ||
# that is passed because jax.value_and_grad does not accept | ||
|
@@ -171,13 +189,29 @@ def fwd_helper( | |
and second is the accumulated grads (if `compute_grad` is True) | ||
otherwise None. | ||
""" | ||
partitioner = maybe_instantiate(minibatch_partitioner) | ||
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps) | ||
|
||
# Create a sample minibatch for the carry buffer creation below | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain in more detail why this is needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I saw broadcasting errors coming from the scan body, (example below), JAX complained that the carry buffer shape and the output of minibatch step are incompatible. PS below error where acc=4 and full batch size is 32 The carry buffer initialization uses the full batch while creating the buffer, which does not match with the output of minibatch step since it would use the shapes of minibatch. The simple fix for this is to use a minibatch sample for creating carry buffer ensuring it's shapes are same as the minibatch step. Let me know if I missed something. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we know why this issue wasn't causing errors before? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The unit test uses a toy model which does not have any metric/output that relies on batch size which is why it does not catch this issue. I dug a bit deeper and found that for fuji models
|
||
( | ||
sample_minibatch_inputs, | ||
_, | ||
_, | ||
) = _make_scan_minibatch_inputs( | ||
inputs, | ||
forward_key=inputs["forward_key"], | ||
param_noise_key=inputs["param_noise_key"], | ||
minibatch_size=minibatch_size, | ||
minibatch_index=0, | ||
minibatch_partitioner=partitioner, | ||
) | ||
|
||
# Carry initialization for the lax.scan procedure. Since we are passing a | ||
# `MetricAccumulator` into carry and carry input/output shapes must match | ||
# we need initialize the `MetricAccumulator` summary with the right PyTree | ||
# structure. | ||
_, primal_output_shape = jax.eval_shape( | ||
original_func_positional_args, model_params, inputs | ||
original_func_positional_args, model_params, sample_minibatch_inputs | ||
) | ||
init_primal_out = jax.tree.map(jnp.zeros_like, primal_output_shape) | ||
init_accumulator = maybe_instantiate(metric_accumulator) | ||
|
@@ -213,6 +247,7 @@ def scan_body( | |
param_noise_key=param_noise_key, | ||
minibatch_size=minibatch_size, | ||
minibatch_index=minibatch_index, | ||
minibatch_partitioner=partitioner, | ||
) | ||
minibatch_args = (model_params, minibatch_inputs) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Echoing Kelvin's comment, could you explain concretely why we need this functionality? If it's just something that might be useful, maybe we can wait until we are certain that we will need it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the case where gradient accumulation is not enabled, the inputs to the graph are sharded as per the policy in input_partitioner. This ensures the batch dimension is sharded on data, expert and fsdp axes while sequence dimension is replicated on model axis.
Gradient accumulation wraps the train steps in a scan loop, while the input_partitioner shards the input batch to correctly at first. In the gradient accumulation wrapper the input batches are resharded/overridden by the function _make_scan_minibatch_inputs and sharded along all axes available which is probably unexpected and inefficient. Minibatches should follow the same PartitionSpec as input_batches.
The addition of the minibatch_partitioner allows the minibatches to use the same sharding/PartitionSpec as
input_partitioner
provides in the input batches in the case gradient accumulation is not used.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we just preserve the sharding the input already has, would that also address the concern about the input sharding being changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah preserving sharding of the input and not having a
sharding_constraint
for minibatches would address the concern as well.