diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 3c57e2bb28d..209b035f5e3 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -94,6 +94,8 @@
     title: FSDP1 vs FSDP2
   - local: concept_guides/context_parallelism
     title: Context parallelism
+  - local: concept_guides/sequence_parallelism
+    title: Sequence parallelism
   - local: concept_guides/low_precision_training
     title: Low precision training methods
   - local: concept_guides/training_tpu
diff --git a/docs/source/concept_guides/context_parallelism.md b/docs/source/concept_guides/context_parallelism.md
index b628e18b8c4..44a6a4b7b7b 100644
--- a/docs/source/concept_guides/context_parallelism.md
+++ b/docs/source/concept_guides/context_parallelism.md
@@ -17,6 +17,8 @@ rendered properly in your Markdown viewer.
 
 This guide will cover basics of using context parallelism in 🤗`accelerate`, for the more curious readers, we will also cover some technicalities in the later sections.
 
+See also the closely related [Sequence Parallelism guide](./sequence_parallelism.md).
+
 ## Why context parallelism?
 
 With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
@@ -176,8 +178,8 @@ You can directly see this issue in the profiler output in the image below:
 
 ## Why only FSDP2?
 
-We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
-utilize its full potential.
+We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
+utilize its full potential.
 
 How it works is: we shard the model across the joint mesh of size `cp_size*dp_shard_size`, which maximizes the memory savings.
 This is a "free lunch" of sorts, as `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.
diff --git a/docs/source/concept_guides/sequence_parallelism.md b/docs/source/concept_guides/sequence_parallelism.md
new file mode 100644
index 00000000000..9eb7262f38c
--- /dev/null
+++ b/docs/source/concept_guides/sequence_parallelism.md
@@ -0,0 +1,219 @@
+
+
+# Sequence parallelism in 🤗`accelerate`
+
+This guide covers the basics of using sequence parallelism in 🤗`accelerate`.
+
+See also the closely related [Context Parallelism guide](./context_parallelism.md).
+
+## Why sequence parallelism?
+
+With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with the quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
+With a sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given a vanilla attention implementation. Granted, with the usage of `flash attention` or `SDPA`, which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
+
+Ulysses sequence parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute attention normally, using only a slice of the attention heads on each GPU. With this, and a few more tools, we can train models with sequences scaling to 15M+ tokens.
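+
+To make that concrete, here is a minimal, single-process sketch of the shape bookkeeping Ulysses performs around attention (illustrative only - the sizes are made up and this is not the actual DeepSpeed implementation, which performs real all-to-all collectives):
+
+```python
+import torch
+
+sp = 4                   # sequence parallel degree (number of participating GPUs)
+S, H, D = 8192, 32, 128  # full sequence length, attention heads, head dim
+assert H % sp == 0       # Ulysses requirement: heads divisible by the SP degree
+
+# Before attention: each rank holds a slice of the sequence with ALL heads
+q_seq_sharded = torch.randn(S // sp, H, D)
+
+# all-to-all #1 (conceptual): each rank now holds the FULL sequence but only
+# H/sp of the heads, so attention itself runs unmodified on that head slice
+q_head_sharded = torch.randn(S, H // sp, D)
+
+# attention runs on (S, H/sp, D); all-to-all #2 then restores the (S/sp, H, D)
+# layout so the rest of the model keeps seeing sequence-sharded activations
+print(tuple(q_seq_sharded.shape), "->", tuple(q_head_sharded.shape))
+```
+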
+To see how to augment Ulysses SP with TiledMLP, Liger-Kernel, activation checkpoint offload to CPU and a few other tricks, please refer to the paper: [Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences](https://arxiv.org/abs/2506.13996).
+
+## How is Ulysses SP different from FSDP CP
+
+In the document [Context Parallelism](./context_parallelism.md) you can learn about deploying another technology, Context Parallelism, which also slices along the sequence dimension, but uses Ring Attention instead of slicing along the head dimension.
+
+The following articles go into a very detailed explanation of the differences between the two technologies:
+- https://insujang.github.io/2024-01-11/tensor-parallelism-and-sequence-parallelism-detailed-analysis/
+- https://huggingface.co/blog/exploding-gradients/ulysses-ring-attention
+
+A quick summary adapted from one of the articles:
+- Ulysses SP has a relatively low communication overhead, but is limited by the number of attention heads and thus has certain requirements for network topology (the number of attention heads has to be divisible by the number of participating GPUs for a single replica). Its all-to-all communication is sensitive to latency, and it requires DeepSpeed.
+- FSDP CP's Ring-Attention P2P ring communication has no such divisibility requirements, but it has a higher communication volume.
+
+Finally, it should be possible to combine SP + CP as explained in the paper [USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719) to support an even longer sequence length, albeit this is not yet integrated into 🤗`accelerate`.
+
+
+## Supported sequence parallelism backends
+
+Currently the only sequence parallelism backend is `deepspeed`, which comes from the modernized Ulysses SP that is part of the [Arctic Long Sequence Training technology](https://arxiv.org/abs/2506.13996). There is also a [tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/) should you want to integrate it into your own code directly.
+
+## How to use sequence parallelism?
+
+```diff
+from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
+
++# Example: 4 GPUs with sp_size=4, dp_shard_size=1
++# Ensure: dp_replicate_size × dp_shard_size × sp_size = 1 × 1 × 4 = 4 GPUs
+parallelism_config = ParallelismConfig(
++    sp_backend="deepspeed",
++    sp_size=4,
++    dp_shard_size=1,  # Explicit: no data parallelism
++    sp_handler=DeepSpeedSequenceParallelConfig(
++        sp_seq_length_is_variable=True,
++        sp_attn_implementation="sdpa",
++    ),
+)
+
+accelerator = Accelerator(
+    ...,
+    parallelism_config=parallelism_config,
+)
+```
+
+As with any other feature in 🤗`accelerate`, you can also enable sequence parallelism by passing the corresponding flags to `accelerate launch`. In this case, it's no different:
+
+```bash
+accelerate launch --parallelism_config_sp_size 8 ...
+```
+
+> [!Tip]
+> You can also set `sp_size` and the other options via the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.
+
+> [!Tip]
+> Sequence parallelism overlays data parallelism on the same ranks, so it doesn't require additional GPUs; with the `deepspeed` backend, DeepSpeed ZeRO still drives data parallelism across those same ranks. You can also use the `ParallelismConfig` class to set the sizes programmatically.
+>
+> **Important**: You must ensure `dp_replicate_size × dp_shard_size × sp_size = num_processes`. For example, with 8 GPUs and `sp_size=8`, you need `dp_shard_size=1` (since 1 × 1 × 8 = 8). With 4 GPUs and `sp_size=2`, you could use `dp_shard_size=2` (since 1 × 2 × 2 = 4) for 2D parallelism.
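+
+For instance, the following sketch (sizes are illustrative, assuming an 8-GPU node) satisfies that constraint with a 2D layout of `dp_shard_size=2` and `sp_size=4`:
+
+```python
+from accelerate import Accelerator
+from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
+
+# 8 GPUs: dp_replicate_size × dp_shard_size × sp_size = 1 × 2 × 4 = 8 processes
+parallelism_config = ParallelismConfig(
+    dp_shard_size=2,
+    sp_size=4,
+    sp_backend="deepspeed",
+    sp_handler=DeepSpeedSequenceParallelConfig(
+        sp_seq_length_is_variable=True,
+        sp_attn_implementation="sdpa",
+    ),
+)
+accelerator = Accelerator(parallelism_config=parallelism_config)
+```
+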
+
+
+## ALST/Ulysses SP backend configuration
+
+ALST/UlyssesSP implements sequence parallelism using attention head parallelism, as explained in [this paper](https://arxiv.org/abs/2506.13996). For simplicity, we reuse the concept and setup of sequence parallelism, which, from the user's perspective, is the same: multiple GPUs are used to process a single batch.
+
+To give a sense of what ALST made possible: it allowed us to train in bf16 with 500K tokens on a single H100 GPU, 3.7M on a single node, and 15M on Llama-8B using just four nodes. This feature of HF Accelerate enables only 1 of the 3 ALST components, so the achievable sequence length will be smaller. You'd want TiledMLP, activation checkpoint offload to CPU, and a few other things enabled to get the full power of ALST. For details, please refer to [this tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).
+
+To configure the `deepspeed` backend:
+
+```python
+from accelerate import Accelerator
+from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig
+
+# Example: 4 GPUs with sp_size=4, dp_shard_size=1
+# Ensure: dp_replicate_size × dp_shard_size × sp_size = 1 × 1 × 4 = 4 GPUs
+parallelism_config = ParallelismConfig(
+    sp_backend="deepspeed",
+    sp_size=4,
+    dp_shard_size=1,  # Explicit: no data parallelism
+    sp_handler=DeepSpeedSequenceParallelConfig(
+        sp_seq_length=256,
+        sp_seq_length_is_variable=True,
+        sp_attn_implementation="sdpa",
+    ),
+)
+accelerator = Accelerator(
+    ...,
+    parallelism_config=parallelism_config,
+)
+```
+
+- `sp_backend`: set to `deepspeed` here.
+- `sp_size` is the degree of sequence parallelism - in the above example it's 4, therefore 4 GPUs will be used to process a single batch (with DeepSpeed ZeRO data parallelism running over those same GPUs).
+- `sp_seq_length` and `sp_seq_length_is_variable` deal with sequence lengths. If `sp_seq_length_is_variable=True`, the backend will work with a sequence length that may change between batches, in which case `sp_seq_length` can be set to anything divisible by the sequence parallel degree, or not set at all; on every `forward` the sequence variables will then be derived from the input. If `False`, then `sp_seq_length` needs to match the batch's sequence length dimension, which then has to be padded to always be the same. The default is `True`.
+- `sp_attn_implementation` is one of `sdpa`, `flash_attention_2` or `flash_attention_3`. This sequence parallel implementation uses `position_ids` instead of `attention_mask`, therefore `eager` can't work here until it supports working with `position_ids`. Also, please note that `sdpa` doesn't correctly handle multiple samples packed into one sequence; it will attend to the whole packed sequence as one sample. If the samples aren't packed, `sdpa` will work correctly. Therefore, Flash Attention should be the ideal choice as it always works.
+
+Instead of setting these values in the code, you can accomplish the same with environment variables - here they are, corresponding to the options above:
+- `PARALLELISM_CONFIG_SP_BACKEND`
+- `PARALLELISM_CONFIG_SP_SEQ_LENGTH`
+- `PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE`
+- `PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION`
+
+If not passed in code, `sp_size` can be set via the `--parallelism_config_sp_size` CLI argument, and the same applies to the other arguments.
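+
+For example, a sketch of a launch command (exact flags may vary between versions; `train.py` stands in for your training script, and the DeepSpeed flags assume you want ZeRO-3):
+
+```bash
+accelerate launch --use_deepspeed --zero_stage 3 \
+  --use_parallelism_config \
+  --parallelism_config_sp_size 4 \
+  --parallelism_config_sp_backend deepspeed \
+  --parallelism_config_sp_attn_implementation flash_attention_2 \
+  train.py
+```
+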
+You can also use the accelerate config file style configuration, e.g. for 2 GPUs:
+
+```yaml
+distributed_type: DEEPSPEED
+deepspeed_config:
+  deepspeed_config_file: path/to/ds_config.json
+machine_rank: 0
+num_machines: 1
+num_processes: 2
+parallelism_config:
+  parallelism_config_dp_replicate_size: 1
+  parallelism_config_dp_shard_size: 1  # Must satisfy: 1 × 1 × 2 = 2 (num_processes)
+  parallelism_config_sp_size: 2
+  parallelism_config_sp_backend: deepspeed
+  parallelism_config_sp_seq_length_is_variable: true
+  parallelism_config_sp_attn_implementation: sdpa
+```
+
+As mentioned earlier, Ulysses sequence parallelism is normally overlaid on data parallelism - the same ranks are fed unique data streams while also performing Ulysses sequence parallelism. But you could also create replicas like so:
+
+```python
+# Example: 4 GPUs with 2D parallelism (SP=2, DP=2)
+# Ensure: dp_replicate_size × dp_shard_size × sp_size = 2 × 1 × 2 = 4 GPUs
+parallelism_config = ParallelismConfig(
+    dp_replicate_size=2,
+    dp_shard_size=1,  # Explicit: no sharding within replicas
+    sp_size=2,
+    sp_backend="deepspeed",
+    sp_handler=DeepSpeedSequenceParallelConfig(...),
+)
+```
+
+Here we use 4 GPUs with 2 sequence parallelism replicas; DeepSpeed ZeRO is what drives the data parallelism here.
+
+Please note that a lot of magic is hidden inside [UlyssesSPDataLoaderAdapter](https://github.com/deepspeedai/DeepSpeed/blob/64c0052fa08438b4ecf4cae30af15091a92d2108/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L442). It's used behind the scenes, wrapping your original DataLoader object, but you should be aware of it should you run into any problems. It also automatically injects the correct `shift_labels` into the batch dictionary before the batch gets sharded across the participating ranks.
+
+Now the only remaining piece needed to start using ALST/UlyssesSP is to aggregate the loss across ranks using a differentiable `all_gather`, so that the gradients come out right.
+The following code does this, while also excluding tokens masked out with `-100`, to get the correct average:
+
+```python
+sp_size = parallelism_config.sp_size if parallelism_config is not None else 1
+if sp_size > 1:
+    sp_group = accelerator.torch_device_mesh["sp"].get_group()
+    sp_world_size = parallelism_config.sp_size
+
+# Normal training loop
+for iter, batch in enumerate(dl):
+    optimizer.zero_grad()
+
+    batch = move_to_device(batch, model.device)
+    outputs = model(**batch)
+
+    # only if not using liger-kernel
+    shift_labels = batch["shift_labels"]
+    loss = unwrapped_model.loss_function(
+        logits=outputs.logits,
+        labels=None,
+        shift_labels=shift_labels,
+        vocab_size=unwrapped_model.config.vocab_size,
+    )
+
+    if sp_size > 1:
+        # differentiable weighted per-shard-loss aggregation across ranks
+        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
+        # special handling for SFT, which has prompt tokens that aren't used in the loss computation
+        good_tokens = (shift_labels != -100).view(-1).sum()
+        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(
+            good_tokens, group=sp_group
+        )
+        # Skip ranks with zero valid tokens to avoid NaN contamination (NaN * 0 = NaN)
+        total_loss = sum(
+            losses_per_rank[rank] * good_tokens_per_rank[rank]
+            for rank in range(sp_world_size)
+            if good_tokens_per_rank[rank] > 0
+        )
+        total_good_tokens = sum(good_tokens_per_rank)
+        loss = total_loss / max(total_good_tokens, 1)
+
+    # accelerator.print only prints on the main process
+    accelerator.print(f"{iter}: {loss=}")
+    accelerator.log(dict(train_loss=loss, step=iter))
+
+    accelerator.backward(loss)
+    optimizer.step()
+```
+
+If you use [Liger Kernel](https://github.com/linkedin/Liger-Kernel), it already knows how to handle `shift_labels`, so you don't need to go through the manual loss calculation - just calling `model(**batch)` will already compute the `loss`, and do it in a very memory-efficient way. If you didn't know about Liger-Kernel, it's highly recommended, especially for long sequence lengths, since it liberates a lot of working GPU memory that can then be used for handling longer sequences. For example, it performs a fused logits-loss computation, never materializing the full logits tensor in memory.
+
+If you want to see what HF Accelerate does behind the scenes, please read [this full integration tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).
+
+For an example of an Accelerate training loop with ALST/UlyssesSP enabled see [examples/alst_ulysses_sequence_parallelism](https://github.com/huggingface/accelerate/blob/main/examples/alst_ulysses_sequence_parallelism).
+
+> [!Warning]
+> This API is quite new and still in its experimental stage. While we strive to provide a stable API, some small parts of the public API may change in the future.
+
+Since this is a DeepSpeed backend, the usual DeepSpeed configuration applies, so you can also combine sequence parallelism with optimizer state and/or weight offloading to liberate more GPU memory and enable an even longer sequence length - a config sketch is shown below. This technology has been tested to work with DeepSpeed ZeRO stages 2 and 3.
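+
+As a starting point, here is a minimal sketch of such a DeepSpeed config - it mirrors the example's `sp-alst.ds-config.json` with CPU offload of optimizer states and parameters added (treat it as an assumption to adapt, not a recommended production config):
+
+```json
+{
+  "bf16": { "enabled": true },
+  "zero_optimization": {
+    "stage": 3,
+    "offload_optimizer": { "device": "cpu" },
+    "offload_param": { "device": "cpu" }
+  },
+  "gradient_accumulation_steps": 1,
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "seq_parallel_communication_data_type": "bf16"
+}
+```
+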
+ diff --git a/examples/alst_ulysses_sequence_parallelism/README.md b/examples/alst_ulysses_sequence_parallelism/README.md new file mode 100644 index 00000000000..dc4ec590f85 --- /dev/null +++ b/examples/alst_ulysses_sequence_parallelism/README.md @@ -0,0 +1,19 @@ +# Deepspeed's ALST/Ulysses sequence parallelism + +This is an example of the use of Ulysses Sequence Parallelism, which uses attention head parallelism and is part of the Arctic Long Sequence Training project at [ArcticTraining](https://github.com/snowflakedb/ArcticTraining). [This paper](https://arxiv.org/abs/2506.13996) goes into the details of this protocol. + +For nuances of usage please refer to the main HF Accelerate tutorial on [Context Parallelism](https://huggingface.co/docs/accelerate/en/concept_guides/context_parallelism). + +You need to use at least `2` gpus to enable ALST/Ulysses sequence parallelism. + +To run the example with `4` gpus: + +```bash +bash ./sp-alst.sh +``` + +Change `4` to the desired sequence parallelism degree in these 2 files: +``` +sp-alst.accelerate-config.yml:num_processes: 4 +sp-alst.py: sp_size=4, +``` diff --git a/examples/alst_ulysses_sequence_parallelism/sp-alst.accelerate-config.yml b/examples/alst_ulysses_sequence_parallelism/sp-alst.accelerate-config.yml new file mode 100644 index 00000000000..36bcafcf070 --- /dev/null +++ b/examples/alst_ulysses_sequence_parallelism/sp-alst.accelerate-config.yml @@ -0,0 +1,12 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_config_file: sp-alst.ds-config.json + zero3_init_flag: false +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +use_cpu: false \ No newline at end of file diff --git a/examples/alst_ulysses_sequence_parallelism/sp-alst.ds-config.json b/examples/alst_ulysses_sequence_parallelism/sp-alst.ds-config.json new file mode 100644 index 00000000000..3f8b0103f8c --- /dev/null +++ b/examples/alst_ulysses_sequence_parallelism/sp-alst.ds-config.json @@ -0,0 +1,12 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3 + }, + "gradient_accumulation_steps": 1, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "seq_parallel_communication_data_type": "bf16" +} \ No newline at end of file diff --git a/examples/alst_ulysses_sequence_parallelism/sp-alst.py b/examples/alst_ulysses_sequence_parallelism/sp-alst.py new file mode 100644 index 00000000000..87c228ca063 --- /dev/null +++ b/examples/alst_ulysses_sequence_parallelism/sp-alst.py @@ -0,0 +1,155 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
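+
+"""
+ALST/Ulysses sequence parallelism example.
+
+Trains TinyLlama/TinyLlama-1.1B-Chat-v1.0 on a small slice of
+HuggingFaceH4/ultrachat_200k with `sp_size=4` using the `deepspeed`
+sequence parallelism backend.
+
+Run it with: `bash ./sp-alst.sh` (see README.md in this directory).
+"""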
+ +import torch +from deepspeed.runtime.utils import move_to_device +from transformers import AutoModelForCausalLM, AutoTokenizer + +from accelerate import Accelerator +from accelerate.utils import ParallelismConfig, set_seed +from accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig + + +set_seed(42) + +model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +# to run the example faster switch to the random model +# model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + +micro_batch_size = 1 + +parallelism_config = ParallelismConfig( + sp_backend="deepspeed", + sp_size=4, + sp_handler=DeepSpeedSequenceParallelConfig( + sp_seq_length=256, + sp_seq_length_is_variable=True, + sp_attn_implementation="sdpa", + ), +) + +accelerator = Accelerator( + parallelism_config=parallelism_config, + # log_with="wandb", # enable to log into wandb +) +accelerator.init_trackers( + project_name="ulysses-accelerate", + config={}, + init_kwargs={"wandb": dict(entity="yak", name="deepspeed")}, +) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name) + +# 2 quick rough datasets to demonstrate the workings +if 1: # real dataset + from datasets import load_dataset + + ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:12]") + + # this is a quick example, it should be made more efficient to be used in real application + def convert(ex): + texts = tokenizer.apply_chat_template(conversation=ex["messages"], tokenize=False) + tokenized_dict = tokenizer(texts, max_length=256, padding=True, truncation=True) + return tokenized_dict + + ds = ds.map(convert, batched=False, remove_columns=["prompt", "prompt_id", "messages"]) + + def collate_fn(batch): + input_ids = torch.tensor(batch[0]["input_ids"]).unsqueeze(0) + attention_mask = torch.tensor(batch[0]["attention_mask"]).unsqueeze(0) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) + return dict( + input_ids=input_ids, + position_ids=position_ids, + labels=input_ids, + attention_mask=attention_mask, + ) + + dl = torch.utils.data.DataLoader( + ds, batch_size=micro_batch_size, collate_fn=collate_fn, drop_last=True, shuffle=False + ) + +else: # fake dataset + samples = 16 + seqlen = 256 + input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100 + position_ids = torch.arange(seqlen * samples).view(-1, seqlen) + + ds = torch.utils.data.TensorDataset(input_ids, position_ids) + + def collate_fn(batch): + input_ids, position_ids = batch[0] + return dict( + input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), + labels=input_ids.unsqueeze(0), + ) + + dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + +rank = torch.distributed.get_rank() + +if rank == 0: + print(f"DL orig: {len(dl)} samples") + +model, optimizer, dl = accelerator.prepare(model, optimizer, dl) + +if rank == 0: + print(f"DL w/ adapter: {len(dl)} samples") + +sp_size = parallelism_config.sp_size if parallelism_config else 1 +if sp_size > 1: + sp_group = accelerator.torch_device_mesh["sp"].get_group() + sp_world_size = parallelism_config.sp_size + +unwrapped_model = accelerator.unwrap_model(model) + +# Normal training loop +for iter, batch in enumerate(dl): + optimizer.zero_grad() + + if rank == 0: + print(f"batch {iter}: seqlen: {len(batch['input_ids'][0])}") + batch = move_to_device(batch, model.device) + outputs = model(**batch) + + shift_labels = batch["shift_labels"] + 
loss = unwrapped_model.loss_function( + logits=outputs.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=unwrapped_model.config.vocab_size, + ) + + if sp_size > 1: + # differentiable weighted per-shard-loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group) + # special dealing with SFT that has prompt tokens that aren't used in loss computation + good_tokens = (shift_labels != -100).view(-1).sum() + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size)) + total_good_tokens = sum(good_tokens_per_rank) + loss = total_loss / max(total_good_tokens, 1) + + if rank == 0: + accelerator.print(f"{iter}: {loss=}") + accelerator.log(dict(train_loss=loss, step=iter)) + + accelerator.backward(loss) + optimizer.step() + +accelerator.end_training() diff --git a/examples/alst_ulysses_sequence_parallelism/sp-alst.sh b/examples/alst_ulysses_sequence_parallelism/sp-alst.sh new file mode 100755 index 00000000000..e7e43c6d2ce --- /dev/null +++ b/examples/alst_ulysses_sequence_parallelism/sp-alst.sh @@ -0,0 +1,8 @@ +export MASTER_ADDR=localhost +export MASTER_PORT=9998 +python -u -m accelerate.commands.launch \ + --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --config_file sp-alst.accelerate-config.yml \ + sp-alst.py diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index fadd03b3314..08b9b492386 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1616,6 +1616,10 @@ def _prepare_tp(self, *args): return args def _prepare_cp(self, *args): + if self.parallelism_config.sp_backend == "deepspeed": + # deepspeed handles cp in a different way, configured in _prepare_deepspeed + return args + from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental._attention import set_rotate_method @@ -2075,6 +2079,11 @@ def _prepare_deepspeed(self, *args): is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) tp_size = deepspeed_plugin.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) + + sp_backend = self.parallelism_config.sp_backend if self.parallelism_config else None + sp_size = self.parallelism_config.sp_size if self.parallelism_config else 1 + sp_handler = self.parallelism_config.sp_handler if self.parallelism_config else None + if tp_size > 1: if not compare_versions("deepspeed", ">=", "0.16.4"): raise ImportError( @@ -2142,11 +2151,14 @@ def _prepare_deepspeed(self, *args): "gradient_clipping": 1.0, "zero_optimization.stage3_gather_16bit_weights_on_model_save": False, } - # This is skipped when preparing just a model + # This block is skipped when preparing just a model and DL is absent from current call's args if batch_size_per_device is not None: config_kwargs["train_micro_batch_size_per_gpu"] = batch_size_per_device config_kwargs["train_batch_size"] = ( - batch_size_per_device * deepspeed_plugin.get_value("gradient_accumulation_steps") * self.num_processes + batch_size_per_device + * deepspeed_plugin.get_value("gradient_accumulation_steps") + * self.num_processes + // sp_size ) model = None @@ -2264,8 +2276,19 @@ def _prepare_deepspeed(self, *args): if not self.split_batches else scheduler.total_num_steps ) + 
deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs) self.deepspeed_config = deepspeed_plugin.deepspeed_config + + # note: batch_size derivation is all over the map, especiall in HF Trainer, so try to fix it at the last moment if needed + pc = self.parallelism_config + if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_size > 1: + self.deepspeed_config["train_batch_size"] = ( + self.deepspeed_config["train_micro_batch_size_per_gpu"] + * self.deepspeed_config["gradient_accumulation_steps"] + * pc.data_parallel_size + ) + kwargs = dict(model=model, config_params=self.deepspeed_config) if optimizer is not None: if isinstance(optimizer, (DummyOptim)): @@ -2293,6 +2316,54 @@ def _prepare_deepspeed(self, *args): # It should be done by the launcher but it does not work for multi-node runs os.environ["DEEPSPEED_USE_HPU"] = "true" + mpu = None + if sp_size > 1: + if sp_backend != "deepspeed": + raise ValueError( + f"In order to use the configured {sp_size=} with DeepSpeed, you need to configure sp_backend='deepspeed', yet you configured it to be {sp_backend=}." + ) + + ver_min_required = "0.18.2" + if not compare_versions("deepspeed", ">=", ver_min_required): + raise ImportError( + f"Deepspeed ALST/Ulysses requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`." + ) + + from deepspeed.runtime.sequence_parallel.ulysses_sp import ( + UlyssesSPAttentionHF, + UlyssesSPDataLoaderAdapter, + ) + + if not hasattr(model, "config"): + raise ValueError( + "UlyssesSPAttentionHF currently works with HF Transformers and expects the model object to have a config attribute but this model doesn't have one." + ) + + mpu = UlyssesSPAttentionHF.register_with_transformers( + model_name_or_path=model, + sequence_parallel_size=sp_size, + seq_length=sp_handler.sp_seq_length, + seq_length_is_variable=sp_handler.sp_seq_length_is_variable, + core_attn_implementation=sp_handler.sp_attn_implementation, + micro_batch_size=batch_size_per_device, + ) + kwargs["mpu"] = mpu + + for i in range(len(result)): + if isinstance(result[i], torch.utils.data.DataLoader): + if sp_size > 1: + # note that in case dataloader was prepared apart from model (for the external accelerator.prepare call) you'd need to call deepspeed_ulysses_dl_adapter after prepare(model) (see HF Trainer as the use-case) + sp_group = mpu.get_sequence_parallel_group() + sp_world_size = mpu.get_sequence_parallel_world_size() + sp_rank = mpu.get_sequence_parallel_rank() + result[i] = UlyssesSPDataLoaderAdapter( + result[i], + sp_rank=sp_rank, + sp_group=sp_group, + sp_world_size=sp_world_size, + device=self.device, # model.device, + ) + engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs) if compare_versions("deepspeed", ">=", "0.14.4") and self.state.dynamo_plugin.backend != DynamoBackend.NO: @@ -2323,6 +2394,7 @@ def _prepare_deepspeed(self, *args): type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES ): result[i] = scheduler + # pointing for deepspeed_engine_wrapped.backward() if self.deepspeed_engine_wrapped is None: self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine) @@ -2339,6 +2411,26 @@ def _prepare_deepspeed(self, *args): self._schedulers.append(scheduler) return tuple(result) + def deepspeed_ulysses_dl_adapter(self, dl, model): + """this is normally called as part of `prepare` but when dataloader was prepared apart from model (for the external accelerator.prepare call) this additional call needs to be made after prepare(model) (see HF Trainer as 
the use-case)""" + sp_size = self.parallelism_config.sp_size if self.parallelism_config else 1 + if sp_size == 1: + return dl + from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter + from deepspeed.utils import groups + + sp_group = groups._get_sequence_parallel_group() + sp_world_size = groups._get_sequence_parallel_world_size() + sp_rank = groups._get_sequence_parallel_rank() + dl = UlyssesSPDataLoaderAdapter( + dl, + sp_rank=sp_rank, + sp_group=sp_group, + sp_world_size=sp_world_size, + device=model.device, + ) + return dl + def _prepare_megatron_lm(self, *args): megatron_lm_plugin = self.state.megatron_lm_plugin micro_batch_size = None @@ -3910,9 +4002,10 @@ def get_state_dict(self, model, unwrap=True): tp_sharding = self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1 if zero3_sharding or tp_sharding: if model.zero_gather_16bit_weights_on_model_save(): - if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"): + ver_min_required = "0.16.4" + if tp_sharding and not compare_versions("deepspeed", ">=", ver_min_required): raise ImportError( - "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`." + f"Deepspeed TP requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`." ) state_dict = ( model._consolidated_16bit_state_dict() @@ -4009,7 +4102,7 @@ def maybe_context_parallel( - `context_parallel` is currently only supported together with FSDP2, and requires `parallelism_config.cp_size` > + `context_parallel` is currently supported with FSDP2 and requires `parallelism_config.cp_size` > 1. If either of these conditions are not met, this context manager will have no effect, though to enable fewer code changes it will not raise an Exception. @@ -4036,7 +4129,11 @@ def maybe_context_parallel( """ # We don't need to check FSDP2 as parallelism_config does that for us # Invariant: in this branch self._cp_context is set, as it was set by `self._prepare_cp` - if self.parallelism_config and self.parallelism_config.cp_enabled: + if ( + self.parallelism_config + and self.parallelism_config.cp_backend == "torch" + and self.parallelism_config.cp_enabled + ): with self._cp_context( buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers ): diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index 00db75fc22b..6da31c52312 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -771,6 +771,7 @@ def launch_command_parser(subparsers=None): "ParallelismConfig Arguments", "Arguments related to the ParallelismConfig used for distributed training.", ) + parallelism_config_args.add_argument( "--parallelism_config_dp_replicate_size", type=int, @@ -798,6 +799,15 @@ def launch_command_parser(subparsers=None): default=1, help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).", ) + + parallelism_config_args.add_argument( + "--parallelism_config_cp_backend", + type=str, + choices=["torch"], + default="torch", + help="Context Parallelism backend: torch (FSDP2) or deepspeed (ALST/Ulysses)", + ) + parallelism_config_args.add_argument( "--parallelism_config_cp_comm_strategy", type=str, @@ -805,6 +815,42 @@ def launch_command_parser(subparsers=None): help="The communication strategy for context parallel training. Defaults to 'allgather'. 
Other option is alltoall", ) + parallelism_config_args.add_argument( + "--parallelism_config_sp_size", + type=int, + default=1, + help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_backend", + type=str, + choices=["deepspeed"], + default="deepspeed", + help="Sequence Parallelism backend: deepspeed (ALST/Ulysses)", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_seq_length", + type=str, + default=None, + help="Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `parallelism_config_sp_seq_length_is_variable=True`", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_seq_length_is_variable", + type=bool, + default=True, + help="If `True` will work with a sequence length that may change between batches, in which case `parallelism_config_sp_seq_length` value can be set to anything divisible by sp size or remain unset. If `False` then `parallelism_config_sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`.", + ) + + parallelism_config_args.add_argument( + "--parallelism_config_sp_attn_implementation", + type=str, + default="sdpa", + help="Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`.", + ) + # Other arguments of the training scripts parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.") diff --git a/src/accelerate/parallelism_config.py b/src/accelerate/parallelism_config.py index d9ee15e3deb..0869d0aca23 100644 --- a/src/accelerate/parallelism_config.py +++ b/src/accelerate/parallelism_config.py @@ -1,5 +1,5 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,9 +15,14 @@ import os import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union - -from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig +from typing import TYPE_CHECKING, Literal, Optional, Union + +from accelerate.utils.dataclasses import ( + DeepSpeedSequenceParallelConfig, + DistributedType, + TorchContextParallelConfig, + TorchTensorParallelConfig, +) from accelerate.utils.versions import is_torch_version @@ -41,11 +46,17 @@ class ParallelismConfig: tp_size (`int`, defaults to `1`): The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be used. + tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`): + The handler for the tensor parallel group. cp_size (`int`, defaults to `1`): The size of the context parallel group. Currently not supported, but reserved for future use and enabled for downstream libraries. - tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`): - The handler for the tensor parallel group. + cp_backend (`str`, defaults to `torch`): + Which CP backend to use: `torch` (FSDP2) + sp_size (`int`, defaults to `1`): + The size of the sequence parallel group. 
+ sp_backend (`str`, defaults to `deepspeed`): + Which SP backend to use:`deepspeed` (ALST/Ulysses) You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size` together: @@ -60,10 +71,14 @@ class ParallelismConfig: dp_shard_size: Optional[int] = None tp_size: Optional[int] = None cp_size: Optional[int] = None + cp_backend: Literal["torch"] = None + sp_size: Optional[int] = None + sp_backend: Literal["deepspeed"] = None # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc) tp_handler: Union[None, TorchTensorParallelConfig] = None cp_handler: Union[None, TorchContextParallelConfig] = None + sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None device_mesh = None @@ -74,6 +89,9 @@ def __repr__(self): f"\tdp_shard_size={self.dp_shard_size},\n" f"\ttp_size={self.tp_size},\n" f"\tcp_size={self.cp_size},\n" + f"\tcp_backend={self.cp_backend},\n" + f"\tsp_size={self.sp_size},\n" + f"\tsp_backend={self.sp_backend},\n" f"\ttotal_size={self.total_size}\n" f"\ttp_handler={self.tp_handler},\n" f"\tcp_handler={self.cp_handler})\n" @@ -110,6 +128,8 @@ def non_dp_dim_names(self): dims += ["tp"] if self.cp_enabled: dims += ["cp"] + if self.sp_enabled: + dims += ["sp"] return dims @property @@ -146,12 +166,12 @@ def fsdp_dim_names(self): @property def total_size(self): """The total size of the parallelism configuration, which is the product of all sizes.""" - return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size + return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size * self.sp_size @property def non_data_parallel_size(self): """The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes.""" - return self.tp_size * self.cp_size + return self.tp_size * self.cp_size * self.sp_size @property def data_parallel_size(self): @@ -178,6 +198,11 @@ def cp_enabled(self): """True if context parallelism is enabled, i.e. `cp_size > 1`.""" return self.cp_size > 1 + @property + def sp_enabled(self): + """True if context parallelism is enabled, i.e. 
`sp_size > 1`.""" + return self.sp_size > 1 + @property def active_mesh_dims(self): """Names of all active mesh dimensions.""" @@ -234,7 +259,7 @@ def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]: mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims} # Apply canonical ordering - mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"] + mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"] sorted_items = sorted( mesh_dims.items(), key=lambda x: (mesh_order.index(x[0])), @@ -251,6 +276,12 @@ def __post_init__(self): self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1")) if self.cp_size is None: self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1")) + if self.cp_backend is None: + self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch") + if self.sp_size is None: + self.sp_size = int(os.environ.get("PARALLELISM_CONFIG_SP_SIZE", "1")) + if self.sp_backend is None: + self.sp_backend = os.environ.get("PARALLELISM_CONFIG_SP_BACKEND", "deepspeed") if self.tp_size > 1: if self.tp_handler is None: @@ -259,7 +290,18 @@ def __post_init__(self): if self.cp_size > 1: if self.cp_handler is None: self.cp_handler = TorchContextParallelConfig() + else: + cp_backends_config_map = dict( + torch=TorchContextParallelConfig, + ) + if not isinstance(self.cp_handler, cp_backends_config_map[self.cp_backend]): + raise ValueError( + f"ParallelismConfig's cp_backend={self.cp_backend} requires {cp_backends_config_map[self.cp_backend]}, but cp_handler was set to {type(self.cp_handler)}" + ) + if self.sp_size > 1: + if self.sp_handler is None: + self.sp_handler = DeepSpeedSequenceParallelConfig() if self.dp_replicate_size < 1: raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}") if self.dp_shard_size < 1: @@ -268,6 +310,15 @@ def __post_init__(self): raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}") if self.cp_size < 1: raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}") + valid_cp_backends = ["torch"] + if self.cp_backend not in valid_cp_backends: + raise ValueError(f"cp_backend must be one of {valid_cp_backends}, but got {self.cp_backend}") + + if self.sp_size < 1: + raise ValueError(f"sp_size must be at least 1, but got {self.sp_size}") + valid_sp_backends = ["deepspeed"] + if self.sp_backend not in valid_sp_backends: + raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}") if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1: raise ValueError( @@ -280,6 +331,7 @@ def __post_init__(self): "dp_shard": self.dp_shard_size, "tp": self.tp_size, "cp": self.cp_size, + "sp": self.sp_size, } def _set_size(self, parallelism: str, size: int): @@ -301,12 +353,16 @@ def _validate_accelerator(self, accelerator: "Accelerator"): raise ValueError( f"ParallelismConfig total_size ({self.total_size}) does not match " f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ " - f"dp_shard_size/tp_size/cp_size." + f"dp_shard_size/tp_size/cp_size/sp_size." 
) - if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device): + if self.total_size > 1 and not ( + accelerator.is_fsdp2 + or accelerator.multi_device + or accelerator.distributed_type == DistributedType.DEEPSPEED + ): raise ValueError( - f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}." + f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}." ) for parallelism, size in self._sizes.items(): diff --git a/src/accelerate/state.py b/src/accelerate/state.py index bd3e10104cb..4da3c91f3c2 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -991,7 +991,7 @@ def __init__( if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true": if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None: raise ValueError( - "`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism, as we also shard the model across the device mesh to save more memory" + "`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism with `cp_backend=torch`, as we also shard the model across the device mesh to save more memory" ) if ( self.parallelism_config is not None diff --git a/src/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py b/src/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py new file mode 100644 index 00000000000..eb910a45dfb --- /dev/null +++ b/src/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py @@ -0,0 +1,129 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Test script for verifying ALST/Ulysses SP works +""" + +import torch +from deepspeed.runtime.utils import move_to_device +from transformers import AutoModelForCausalLM, AutoTokenizer + +from accelerate import Accelerator +from accelerate.utils import ParallelismConfig, set_seed +from accelerate.utils.dataclasses import DeepSpeedSequenceParallelConfig + + +set_seed(42) + +world_size = 2 +model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + +micro_batch_size = 1 + +parallelism_config = ParallelismConfig( + sp_backend="deepspeed", + sp_size=world_size, + # dp_shard_size=1, # set if dp is wanted as well + sp_handler=DeepSpeedSequenceParallelConfig( + sp_seq_length=256, + sp_seq_length_is_variable=True, + sp_attn_implementation="sdpa", + ), +) + +accelerator = Accelerator( + parallelism_config=parallelism_config, +) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name) + +samples = 4 +seqlen = 32 +input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100 +position_ids = torch.arange(seqlen * samples).view(-1, seqlen) + +ds = torch.utils.data.TensorDataset(input_ids, position_ids) + + +def collate_fn(batch): + input_ids, position_ids = batch[0] + return dict( + input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), + labels=input_ids.unsqueeze(0), + ) + + +dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn) + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + +rank = torch.distributed.get_rank() + +if rank == 0: + print(f"DL orig: {len(dl)} samples") + +model, optimizer, dl = accelerator.prepare(model, optimizer, dl) + +if rank == 0: + print(f"DL w/ adapter: {len(dl)} samples") + +sp_size = parallelism_config.sp_size if parallelism_config else 1 +if sp_size > 1: + sp_group = accelerator.torch_device_mesh["sp"].get_group() + sp_world_size = parallelism_config.sp_size + +unwrapped_model = accelerator.unwrap_model(model) + +# Normal training loop +for iter, batch in enumerate(dl): + optimizer.zero_grad() + + if rank == 0: + print(f"batch {iter}: seqlen: {len(batch['input_ids'][0])}") + batch = move_to_device(batch, model.device) + outputs = model(**batch) + + shift_labels = batch["shift_labels"] + loss = unwrapped_model.loss_function( + logits=outputs.logits, + labels=None, + shift_labels=shift_labels, + vocab_size=unwrapped_model.config.vocab_size, + ) + + if sp_size > 1: + # differentiable weighted per-shard-loss aggregation across ranks + losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group) + # special dealing with SFT that has prompt tokens that aren't used in loss computation + good_tokens = (shift_labels != -100).view(-1).sum() + good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group) + total_loss = sum( + losses_per_rank[rank] * good_tokens_per_rank[rank] + for rank in range(sp_world_size) + if good_tokens_per_rank[rank] > 0 + ) + total_good_tokens = sum(good_tokens_per_rank) + loss = total_loss / max(total_good_tokens, 1) + + if rank == 0: + accelerator.print(f"{iter}: {loss=}") + accelerator.log(dict(train_loss=loss, step=iter)) + + accelerator.backward(loss) + optimizer.step() + +accelerator.end_training() diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 7bf4f6e070c..8979b3eadf7 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -42,6 +42,7 @@ DataLoaderConfiguration, 
DDPCommunicationHookType, DeepSpeedPlugin, + DeepSpeedSequenceParallelConfig, DistributedDataParallelKwargs, DistributedType, DynamoBackend, diff --git a/src/accelerate/utils/constants.py b/src/accelerate/utils/constants.py index 2e2be5434b2..10a381309f9 100644 --- a/src/accelerate/utils/constants.py +++ b/src/accelerate/utils/constants.py @@ -54,6 +54,7 @@ BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0" BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0" +BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2" STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index d5055d3bc0b..fb01d35d14c 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1227,6 +1227,7 @@ def __post_init__(self): if self.hf_ds_config is None: self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") + if ( isinstance(self.hf_ds_config, dict) or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none") @@ -2178,9 +2179,10 @@ class TorchContextParallelConfig: def __post_init__(self): if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION): raise ValueError( - f"Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. " + f"FSDP2-based Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. " "Please upgrade your PyTorch version." ) + if self.cp_comm_strategy is None: self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather") if self.cp_comm_strategy not in ["allgather", "alltoall"]: @@ -2189,6 +2191,56 @@ def __post_init__(self): ) +@dataclass +class DeepSpeedSequenceParallelConfig: + sp_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": "Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `sp_seq_length_is_variable=True` and leave this field unset" + }, + ) + sp_seq_length_is_variable: Optional[bool] = field( + default=None, + metadata={ + "help": "If `True` will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by cp size or remain unset. If `False` then `sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`." + }, + ) + sp_attn_implementation: Optional[str] = field( + default=None, + metadata={ + "help": "Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`." 
+ }, + ) + + def __post_init__(self): + # sp_seq_length_is_variable and sp_seq_length are interconnected + if self.sp_seq_length_is_variable is None: + self.sp_seq_length_is_variable = ( + os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE", "true").lower() == "true" + ) + + if not self.sp_seq_length_is_variable and self.sp_seq_length is None: + if "PARALLELISM_CONFIG_SP_SEQ_LENGTH" not in os.environ: + raise ValueError( + "when `sp_seq_length_is_variable` is `False` `sp_seq_length` must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_SP_SEQ_LENGTH" + ) + else: + self.sp_seq_length = os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH") + self.sp_seq_length = None if self.sp_seq_length == "None" else int(self.sp_seq_length) + + if self.sp_attn_implementation is None: + self.sp_attn_implementation = os.environ.get("PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION", None) + + if self.sp_attn_implementation is not None and self.sp_attn_implementation not in [ + "flash_attention_2", + "flash_attention_3", + "sdpa", + ]: + raise ValueError( + f"Invalid sp_attn_implementation: {self.sp_attn_implementation}. Must be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'." + ) + + @dataclass class TorchTensorParallelConfig: """ diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py index 020c0e820e5..6182bd40d6f 100644 --- a/src/accelerate/utils/launch.py +++ b/src/accelerate/utils/launch.py @@ -350,18 +350,35 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]: if args.enable_cpu_affinity: current_env["ACCELERATE_CPU_AFFINITY"] = "1" - if not args.use_parallelism_config: - return current_env + if args.use_parallelism_config: + current_env = prepare_extend_env_parallelism_config(args, current_env) + + return current_env + + +def prepare_extend_env_parallelism_config( + args: argparse.Namespace, current_env: dict +) -> tuple[list[str], dict[str, str]]: + """ + Extends `current_env` with context parallelism env vars if any have been set + """ prefix = "PARALLELISM_CONFIG_" - if args.use_parallelism_config: - current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true" - current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size) - current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size) - current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size) - current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size) - if args.parallelism_config_cp_size > 1: - current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy) + + current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true" + current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size) + current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size) + current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size) + current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size) + current_env[prefix + "CP_BACKEND"] = str(args.parallelism_config_cp_backend) + current_env[prefix + "SP_SIZE"] = str(args.parallelism_config_sp_size) + current_env[prefix + "SP_BACKEND"] = str(args.parallelism_config_sp_backend) + if args.parallelism_config_cp_size > 1: + current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy) + if args.parallelism_config_sp_size > 1: + current_env[prefix + "SP_SEQ_LENGTH"] = str(args.parallelism_config_sp_seq_length) + current_env[prefix + 
"SP_SEQ_LENGTH_IS_VARIABLE"] = str(args.parallelism_config_sp_seq_length_is_variable) + current_env[prefix + "SP_ATTN_IMPLEMENTATION"] = str(args.parallelism_config_sp_attn_implementation) return current_env @@ -521,6 +538,10 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict current_env["ACCELERATE_CPU_AFFINITY"] = "1" if args.deepspeed_moe_layer_cls_names is not None: current_env["ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES"] = str(args.deepspeed_moe_layer_cls_names) + + if args.use_parallelism_config: + current_env = prepare_extend_env_parallelism_config(args, current_env) + return cmd, current_env diff --git a/tests/deepspeed/test_alst_ulysses_sp.py b/tests/deepspeed/test_alst_ulysses_sp.py new file mode 100644 index 00000000000..08f6eb9d322 --- /dev/null +++ b/tests/deepspeed/test_alst_ulysses_sp.py @@ -0,0 +1,49 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parameterized import parameterized + +from accelerate.test_utils.testing import ( + TempDirTestCase, + execute_subprocess_async, + path_in_accelerate_package, + require_deepspeed, + require_multi_device, +) +from accelerate.utils import patch_environment + + +@require_deepspeed +@require_multi_device +class DeepSpeedALSTUlyssesSPTest(TempDirTestCase): + test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps") + + @parameterized.expand([2, 3]) + def test_deepspeed_alst_ulysses_sp(self, stage): + self.test_file_path = self.test_scripts_folder / "test_ds_alst_ulysses_sp.py" + world_size = 2 + cmd = [ + "accelerate", + "launch", + f"--num_processes={world_size}", + "--num_machines=1", + "--machine_rank=0", + "--mixed_precision=bf16", + "--use_deepspeed", + f"--zero_stage={stage}", + self.test_file_path, + f"--output_dir={self.tmpdir}", + ] + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 52c9f4aca71..a6c463d4d18 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -20,10 +20,11 @@ from accelerate.utils import patch_environment from accelerate.utils.constants import ( BETA_CP_AVAILABLE_PYTORCH_VERSION, + BETA_SP_AVAILABLE_DEEPSPEED_VERSION, BETA_TP_AVAILABLE_PYTORCH_VERSION, BETA_TP_AVAILABLE_TRANSFORMERS_VERSION, ) -from accelerate.utils.imports import is_transformers_available +from accelerate.utils.imports import is_deepspeed_available, is_transformers_available from accelerate.utils.versions import compare_versions, is_torch_version @@ -32,6 +33,15 @@ def _should_skip_cp_test(cp_size): return cp_size > 1 and not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION) +def _should_skip_sp_test(sp_size): + """Check if SP test should be skipped based on sp_size and deepspeed version.""" + if sp_size <= 1: + return False + if not is_deepspeed_available(): + return True + return not compare_versions("deepspeed", ">=", BETA_SP_AVAILABLE_DEEPSPEED_VERSION) + + def 
_should_skip_tp_test(tp_size): """Check if TP test should be skipped based on tp_size, torch version, and transformers availability.""" if tp_size <= 1: @@ -212,8 +222,8 @@ def test_from_env( for key, value in new_env.items(): assert getattr(config, key.split("PARALLELISM_CONFIG_")[-1].lower()) == value - def test_cp_handler(self): - """Test CP handler with various configurations.""" + def test_cp_torch_handler(self): + """Test CP Torch/FSDP2 handler with various configurations.""" # Any cp_size > 1 requires torch >= BETA_CP_AVAILABLE_PYTORCH_VERSION, we use placeholder for this check as this test doesn't depend on a specific size if _should_skip_cp_test(2): @@ -246,5 +256,22 @@ def test_cp_handler(self): with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"): pc = ParallelismConfig(cp_size=2) + def test_sp_deepspeed_handler(self): + """Test SP DeepSpeed/ALST/UlyssesSP handler with various configurations.""" + + # Any sp_size > 1 requires torch >= BETA_SP_AVAILABLE_PYTORCH_VERSION, we use placeholder for this check as this test doesn't depend on a specific size + if _should_skip_sp_test(2): + pytest.skip(f"tests with `sp_size>1` require deepspeed >= {BETA_SP_AVAILABLE_DEEPSPEED_VERSION}") + + from accelerate.utils import DeepSpeedSequenceParallelConfig + + sp_handler = DeepSpeedSequenceParallelConfig() + pc = ParallelismConfig(sp_backend="deepspeed", sp_size=2, sp_handler=sp_handler) + assert pc.sp_handler is not None, "SP handler should be set" + assert pc.sp_handler.sp_seq_length_is_variable is True, "by default we set to expect a variable seqlen" + + with pytest.raises(ValueError, match="Invalid sp_attn_implementation"): + DeepSpeedSequenceParallelConfig(sp_attn_implementation="foobar") + def test_tp_handler(self): assert True, "Tensor parallelism handler doesn't hold any logic yet"