Improvements and fixes to gradient accumulation #993
@@ -13,7 +13,7 @@
 from axlearn.common.config import ConfigOr, maybe_instantiate
 from axlearn.common.metrics import MetricAccumulator
 from axlearn.common.update_transformation import ForwardFn, ForwardOutputs
-from axlearn.common.utils import Nested, Tensor, input_partition_spec, with_sharding_constraint
+from axlearn.common.utils import Nested, Tensor


 def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
@@ -55,41 +55,29 @@ def _make_scan_minibatch_inputs(
     *,
     forward_key: Tensor,
     param_noise_key: Tensor,
-    minibatch_size: int,
     minibatch_index: int,
 ) -> tuple[Nested[Tensor], Tensor, Tensor]:
     """Creates minibatch inputs from inputs.

     This is a utility function that is only meant to be called from
-    within a scan function body and is meant to slice the inputs
-    into `minibatch_size` sized slices to run the ForwardFn on.
-
-    Note that this only preserves the input sharding if the `input_partition_spec`
-    returns the correct partition spec to shard the input slices with.
+    within a scan function body and is meant to return sliced minibatches
+    to run the ForwardFn on.

     Args:
         inputs: Same pytree as ForwardFn inputs.
         forward_key: The `forward_key` from the ForwardFn inputs
         param_noise_key: The `param_noise_key` from the ForwardFn inputs
-        minibatch_size: Size of the minibatch.
         minibatch_index: Current scan minibatch index.

     Returns:
         A tuple of minibatch inputs which of the same structure as `inputs`
         and new (carry) forward_key and param_noise_key.
     """
-    minibatch_input = with_sharding_constraint(
-        jax.tree.map(
-            lambda x: jax.lax.dynamic_slice_in_dim(
-                x,
-                start_index=minibatch_index * minibatch_size,
-                slice_size=minibatch_size,
-                axis=0,
-            ),
-            inputs["input_batch"],
-        ),
-        input_partition_spec(),
+    minibatch_input = jax.tree.map(
+        lambda x: x[minibatch_index],
+        inputs["input_batch"],
Review comment:
Suppose we have a global input batch of size 100 running on 10 chips (so a per-chip size of 10), and we want to switch to doing 10 grad accumulation steps, each with a global batch size of 10 (1 per chip per accumulation step). Suppose the input is originally sharded evenly across the chips (first 10 on the first chip, second 10 on the second chip, etc.). Then when we take the first slice of 10 for the first grad accumulation step, won't all of those examples be on the same chip? Will that cause a problem? (E.g., should we worry that XLA might not automatically reshard the examples across chips?) Maybe we should reshard the batch axis only?

Review comment:
+1 on the potential design problem here. Can you double-check and ensure that axis=0 is confirmed to be the batch dimension?

Author response:
We can completely avoid the batch reshards using a reshape + transpose. I added it to the PR; let me know if it addresses your concerns. Using the same example as @apghml:
Rather than using the first 10 examples of the global batch array for the first iteration, we construct the minibatch from the first example on every device, so minibatch 0 => [0, 10, 20, ...], minibatch 1 => [1, 11, 21, ...]. This is achieved with the reshape and transpose. Essentially, the logic is to make each device use its local examples, avoiding extra reshards. This should address the concerns around input batch reshards; let me know if there are still more concerns.
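A minimal sketch of the index pattern the reshape + transpose produces for the example above (the sizes and the use of NumPy here are illustrative, not part of the PR):

    import numpy as np

    # Hypothetical global batch of 100 examples and 10 accumulation steps,
    # matching the example above (10 examples per device).
    global_batch = np.arange(100)                   # stand-in "example indices" 0..99
    steps = 10
    minibatch_size = global_batch.shape[0] // steps

    # Mirror reshape_for_scan: [GBS] -> [minibatch_size, steps] -> [steps, minibatch_size].
    staggered = global_batch.reshape(minibatch_size, steps).transpose(1, 0)

    print(staggered[0])  # [ 0 10 20 30 40 50 60 70 80 90]  -> minibatch 0
    print(staggered[1])  # [ 1 11 21 31 41 51 61 71 81 91]  -> minibatch 1

Each minibatch row takes one example from every device's local shard, which is why no cross-device reshard is needed per accumulation step.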
Author response:
@kelvin-zou I can't think of a way to get the size of a specific axis at runtime, but I do believe JAX should be able to give an informative error if the …

Review comment:
Thanks for the explanation. Can you add a test that fails without this fix?
     )

     next_forward_key, forward_key = jax.random.split(forward_key)
     next_param_noise_key, param_noise_key = jax.random.split(param_noise_key)

@@ -172,12 +160,56 @@ def fwd_helper(
         otherwise None.
     """
     minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps)

+    def reshape_for_scan(x: Tensor):
+        """Helper function that adds a minibatch dimension while evenly dividing
+        batches across gradient accumulation iterations.
+
+        Input dimension is [GBS, seq], this first reshaped to [MBS, steps, seq],
Review comment:
Replace the acronyms with full names?
+        then transposed to [steps, MBS, seq] this ensures that batches picked
+        up from the global batch in a staggered pattern.
+
+        The main benefit is that this avoids extra communication incurred in reshard
+        for every minibatch.
+
+        Args:
+            x: Tensor to be reshaped.
+
+        Returns:
+            The reshaped tensor.
+        """
+        if x.shape[0] % minibatch_size != 0:
+            raise ValueError(
+                f"minibatch_size {minibatch_size} does not evenly divide "
+                f"global batch size of {x.shape[0]}"
+            )
+
+        x = x.reshape(minibatch_size, -1, *x.shape[1:])
+        # Set up transpose to swap the first two dimensions.
+        dims = list(range(x.ndim))
+        dims[0], dims[1] = dims[1], dims[0]
+        return x.transpose(dims)
Review comment on lines +188 to +191:
Could we replace these three lines with one line if we use …
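The comment is cut off in the original, so the intended one-liner is unknown; one possibility (assuming something like jnp.swapaxes was meant, which is my assumption) would be:

    import jax.numpy as jnp

    x = jnp.zeros((2, 4, 16))   # hypothetical [minibatch_size, steps, seq] tensor
    x = jnp.swapaxes(x, 0, 1)   # shape (4, 2, 16); same effect as the three-line transpose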
+
+    inputs["input_batch"] = jax.tree_map(reshape_for_scan, inputs["input_batch"])
+
+    # Create a sample minibatch for the carry buffer creation below
Review comment:
Could you explain in more detail why this is needed?

Review comment:
+1

Author response:
I saw broadcasting errors coming from the scan body (example below); JAX complained that the carry buffer shape and the output of the minibatch step are incompatible. PS: the error below is from a run with acc=4 and a full batch size of 32. The carry buffer initialization uses the full batch while creating the buffer, which does not match the output of the minibatch step, since that uses the minibatch shapes. The simple fix is to use a sample minibatch when creating the carry buffer, ensuring its shapes match those of the minibatch step. Let me know if I missed something.

Review comment:
Do we know why this issue wasn't causing errors before?

Author response:
The unit test uses a toy model which does not have any metric/output that relies on batch size, which is why it does not catch this issue. I dug a bit deeper and found that for fuji models …
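A minimal sketch of the shape mismatch being described, using a toy scan (the shapes and the body function are hypothetical, not taken from axlearn):

    import jax
    import jax.numpy as jnp

    # Hypothetical shapes: full batch of 32, 4 accumulation steps -> minibatch of 8.
    full_batch_carry = jnp.zeros((32, 4))
    minibatch_out = jnp.zeros((8, 4))

    def body(carry, _):
        # The body returns a minibatch-shaped carry...
        return minibatch_out, None

    # ...so initializing the carry with the full-batch shape makes lax.scan
    # reject the carry for having mismatched input/output types.
    try:
        jax.lax.scan(body, full_batch_carry, xs=jnp.arange(4))
    except Exception as e:
        print(type(e).__name__, e)

Initializing the carry from a sample minibatch (as the PR does via jax.eval_shape on sample_minibatch_inputs) keeps the carry and the per-step output shapes consistent.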
+    (
+        sample_minibatch_inputs,
+        _,
+        _,
+    ) = _make_scan_minibatch_inputs(
+        inputs,
+        forward_key=inputs["forward_key"],
+        param_noise_key=inputs["param_noise_key"],
+        minibatch_index=0,
+    )
+
     # Carry initialization for the lax.scan procedure. Since we are passing a
     # `MetricAccumulator` into carry and carry input/output shapes must match
     # we need initialize the `MetricAccumulator` summary with the right PyTree
     # structure.
     _, primal_output_shape = jax.eval_shape(
-        original_func_positional_args, model_params, inputs
+        original_func_positional_args, model_params, sample_minibatch_inputs
     )
     init_primal_out = jax.tree.map(jnp.zeros_like, primal_output_shape)
     init_accumulator = maybe_instantiate(metric_accumulator)
@@ -211,7 +243,6 @@ def scan_body(
         inputs,
         forward_key=forward_key,
         param_noise_key=param_noise_key,
-        minibatch_size=minibatch_size,
         minibatch_index=minibatch_index,
     )
     minibatch_args = (model_params, minibatch_inputs)
Review comment:
To me this seems more like a hack than a proper solution. That is, if we want a different input_partition_spec() than the default one, then we need this?

Author response:
Sorry, I missed the default case; added it. I think the below partition spec is good as a default, but the ability to change the PartitionSpec might be good to have. What do you think?
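A minimal sketch of what an overridable partition spec could look like (hypothetical; the helper name, the optional parameter, and the "data" axis default are assumptions, not part of the PR):

    from typing import Optional

    import jax
    from jax.sharding import PartitionSpec

    def constrain_minibatch(
        minibatch_input,
        *,
        partition_spec: Optional[PartitionSpec] = None,
    ):
        """Constrains the sliced minibatch to a caller-provided (or default) spec.

        Intended to be called inside a jit-compiled step with a mesh in context.
        """
        if partition_spec is None:
            # Hypothetical default: shard only the leading (batch) axis over "data".
            partition_spec = PartitionSpec("data")
        return jax.tree.map(
            lambda x: jax.lax.with_sharding_constraint(x, partition_spec),
            minibatch_input,
        )

Exposing the spec as an optional argument (or config field) keeps the current behavior as the default while letting callers with non-standard input shardings override it.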