Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
439fa7e
Feat: initial impl
S1ro1 Sep 16, 2025
e238f7b
improve
sfc-gh-sbekman Oct 22, 2025
876fa2d
s/flavour/backend/
sfc-gh-sbekman Oct 22, 2025
b016317
style + ver
sfc-gh-sbekman Oct 22, 2025
a63d094
better check
sfc-gh-sbekman Oct 22, 2025
73d3dbc
check
sfc-gh-sbekman Oct 22, 2025
8b4a4f4
docs + example
sfc-gh-sbekman Oct 22, 2025
76288a8
add tests
sfc-gh-sbekman Oct 22, 2025
209eab7
add tests
sfc-gh-sbekman Oct 22, 2025
c013677
cleanup
sfc-gh-sbekman Oct 22, 2025
685453c
cleanup
sfc-gh-sbekman Oct 22, 2025
a396904
Apply suggestions from code review
stas00 Oct 23, 2025
fb80e02
add experimental notice
sfc-gh-sbekman Oct 23, 2025
21a4a2d
style
sfc-gh-sbekman Oct 23, 2025
05b6ac1
Merge branch 'alst-integration' of https://github.com/stas00/accelera…
sfc-gh-sbekman Oct 23, 2025
60f7493
new deepspeed version
sfc-gh-sbekman Oct 23, 2025
453fb55
additional checks + tests
sfc-gh-sbekman Oct 23, 2025
8677f23
more docs
sfc-gh-sbekman Oct 23, 2025
8ee9b03
more docs
sfc-gh-sbekman Oct 23, 2025
2e577d3
working now
sfc-gh-sbekman Oct 28, 2025
5c51897
style
sfc-gh-sbekman Oct 28, 2025
8317241
update docs
sfc-gh-sbekman Oct 28, 2025
94f558b
more robust config parsing
sfc-gh-sbekman Oct 28, 2025
a2388cd
fix
sfc-gh-sbekman Oct 28, 2025
e6e243f
Apply suggestions from code review
stas00 Nov 4, 2025
2330dcd
check backend, integrate ulysses API improvement
sfc-gh-sbekman Nov 5, 2025
9dbcf91
style
sfc-gh-sbekman Nov 5, 2025
e79034f
fix default to match the doc
sfc-gh-sbekman Nov 5, 2025
61873c6
Apply suggestions from code review
stas00 Nov 5, 2025
756bd9f
fix
sfc-gh-sbekman Nov 5, 2025
56df621
deepspeed=0.18.2 is out
sfc-gh-sbekman Nov 5, 2025
380747c
Apply suggestions from code review
stas00 Nov 10, 2025
38c84fa
s/cp/sp
sfc-gh-sbekman Nov 14, 2025
5c2f34e
fixes
sfc-gh-sbekman Nov 14, 2025
190494b
Apply suggestions from code review
stas00 Nov 14, 2025
285e24f
Update src/accelerate/parallelism_config.py
stas00 Nov 14, 2025
a7d2e5d
suggestion
sfc-gh-sbekman Nov 14, 2025
d4ee156
Update docs/source/concept_guides/sequence_parallelism.md
stas00 Nov 17, 2025
99b321a
Update sequence_parallelism.md
stas00 Nov 17, 2025
b769fc8
fix
sfc-gh-sbekman Nov 17, 2025
04b4dc3
fix
sfc-gh-sbekman Nov 17, 2025
a4005e7
fix
sfc-gh-sbekman Nov 18, 2025
10978a0
Apply suggestion from @kashif
kashif Nov 20, 2025
4115257
Apply suggestion from @kashif
kashif Nov 20, 2025
891c702
Apply suggestion from @kashif
kashif Nov 20, 2025
71f61d6
Apply suggestion from @kashif
kashif Nov 20, 2025
1d5bc22
Apply suggestion from @kashif
kashif Nov 20, 2025
7282db8
Apply suggestion from @kashif
kashif Nov 20, 2025
84638eb
Apply suggestion from @kashif
kashif Nov 20, 2025
d0d8860
Apply suggestion from @kashif
kashif Nov 20, 2025
1e19b82
Apply suggestion from @kashif
kashif Nov 20, 2025
5d099cf
Apply suggestion from @kashif
kashif Nov 20, 2025
c3b2ce7
Apply suggestion from @kashif
kashif Nov 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 119 additions & 8 deletions docs/source/concept_guides/context_parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ With sequence length of 128k, the memory requirement of the attention matrix is

Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.

## Supported backends

Multiple backends are currently supported

1. `torch`: PyTorch/FSDP2,which implements several of Ring Attention context parallel protocols [tutorial](https://docs.pytorch.org/tutorials/unstable/context_parallel.html) and [api](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel).
2. `deepspeed`: DeepSpeed/ALST/UlyssesSP, which implements sequence parallelism using attention head parallelism: [tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/) and [paper](https://arxiv.org/abs/2506.13996)

## How to use context parallelism?

```diff
Expand All @@ -44,8 +51,21 @@ accelerator = Accelerator(
)
```

As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
In this case, it's no different:
By default the `torch` backend is selected, but you can select the deepspeed backend via:

```python
parallelism_config = ParallelismConfig(
backend="deepspeed",
cp_size=4,
cp_handler=DeepSpeedContextParallelConfig(
seq_length=256,
attn_implementation="sdpa"
),
)
```
See the following sections for nuances of each backend.

As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`. In this case, it's no different:

```bash
accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...
Expand All @@ -58,12 +78,14 @@ accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-st
> Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2.
> You can simply combine them by setting your parallelism sizes to the desired values, e.g. `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically.

## Torch/FSDP2 backend

> [!Warning]
> Context parallelism is tightly coupled with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). Meaning, context parallelism only works if you use `FullyShardedDataParallelPlugin` or `--use-fsdp` with version set to 2 to your
> program. If no `FSDP2` is used, error will be raised.

> [!Warning]
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.
> `torch`-backend Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.

After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
You can use it as follows:
Expand Down Expand Up @@ -91,11 +113,100 @@ This can scale your context size to 1M+ sequence length potentially. Below, we s
</p>

> [!Tip]
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
> ```bash
> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000
> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/torch_native_parallelism/nd_parallel.py --cp-size=8 --sequence-length=128000
> ```

## DeepSpeed/ALST/UlyssesSP backend

ALST/UlyssesSP implements a sequence parallelism using attention head parallelism as explained in [this paper](https://arxiv.org/abs/2506.13996) - for simplicity we re-use the concept and the setup of context parallelism, which from the user's end of view is the same - multiple gpus are used to process a single batch.

To give a sense of what ALST made possible - it allowed us to train in bf16 with 500K tokens on a single H100 GPU, 3.7M on a single node, and 15M on Llama-8B using just four nodes. This feature of HF Accelerate enables only 1 of the 3 ALST components so the achievable sequence length will be smaller. You'd want TiledMLP, Activation checkpoint offload to CPU and a few other things enabled to get the full power of ALST, for details please refer to [this tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

To configure the `deepspeed` backend:

```python
parallelism_config = ParallelismConfig(
backend="deepspeed",
cp_size=4,
cp_handler=DeepSpeedContextParallelConfig(
seq_length=256,
seq_length_is_variable=True,
attn_implementation="sdpa",
),
)
accelerator = Accelerator(
...,
parallelism_config=parallelism_config,
)
```

- `cp_size` is the degree of the sequence parallelism - in the above example it's 4, therefore 4 gpus will be used to process a single batch.
- `seq_length` and `seq_length_is_variable` are used to deal with sequence lengths. If `seq_length_is_variable=True` the backend will work with a sequence length that may change between batches, in which case `seq_length` value can be set to anything divisible by the context parallel degree or not set at all. In this case on every `forward` the sequence variables will be derived from input. If `False` then `seq_length` needs to match the batch's sequence length dimension, which then will have to be padded to be always the same. The default is `True`.
- `attn_implementation` is one of `sdpa`, `flash_attention_2` or ``flash_attention_3`. This sequence parallel implementation uses `position_ids` instead of `attention_mask` therefore `eager` can't work here until it'd support working with `position_ids`. Also please note that `sdpa` doesn't handle correctly combined into one multiple-samples, it'd attend to the whole sample as one. If the samples aren't combined `sdpa` will work correctly. Therefore Flash Attention should be the ideal choise as it always works.

Instead of setting these values in `DeepSpeedContextParallelConfig` object, you can also use the environment variables to accomplish the same - here they are correspondingly to the end of the list above.
- `PARALLELISM_CONFIG_CP_SEQ_LENGTH`
- `PARALLELISM_CONFIG_CP_SEQ_LENGTH_IS_VARIABLE`
- `PARALLELISM_CONFIG_CP_ATTN_IMPLEMENTATION`

If not passed in the code `cp_size` can be set via `--parallelism_config_cp_size` CLI argument.

Please note that a lot of magic is hidden inside [UlyssesSPDataLoaderAdapter](https://github.com/deepspeedai/DeepSpeed/blob/64c0052fa08438b4ecf4cae30af15091a92d2108/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L442). It's used behind the scenes, wrapping your original DataLoader object, but you should be aware of it should you run into any problems. It also automatically injects the correct `shift_labels` into the batch dictionary, before the batch gets sharded across the participating ranks.

Now the only remaining piece to start using ALST/UlyssesSP is to aggregate the loss across ranks using a differentiable `all_gather` to get the grads right. The following code does it, while also exlcuding any masked out with `-100` tokens, to get the correct average:

```python
cp_size = parallelism_config.cp_size if parallelism_config else 1
if cp_size > 1:
sp_group = accelerator.torch_device_mesh["cp"].get_group()
sp_world_size = parallelism_config.cp_size

# Normal training loop
for iter, batch in enumerate(dl):
optimizer.zero_grad()

batch = move_to_device(batch, model.device)
outputs = model(**batch)

# only if not using liger-kernel
shift_labels = batch["shift_labels"]
loss = unwrapped_model.loss_function(
logits=outputs.logits,
labels=None,
shift_labels=shift_labels,
vocab_size=unwrapped_model.config.vocab_size,
)

if cp_size > 1:
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special dealing with SFT that has prompt tokens that aren't used in loss computation
good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(
good_tokens, group=sp_group
)
total_loss = sum(
losses_per_rank[rank] * good_tokens_per_rank[rank]
for rank in range(sp_world_size)
)
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)

if rank == 0: accelerator.print(f"{iter}: {loss=}")
accelerator.log(dict(train_loss=loss, step=iter))

accelerator.backward(loss)
optimizer.step()
```

If you use [Liger Kernel](https://github.com/linkedin/Liger-Kernel) it already knows how to handle `shift_labels` so you don't need to go through manual loss calculation, just calling `model(**batch)` will already get the `loss` calculated and done in a very memory-efficient way. If you didn't know about Liger-Kernel - it's highly recommended to be used especially for long sequence length since it liberates a lot of working memory that can be used for handling longer sequences.

If you want to see what HF Accelerate did behind the scenes please read [this full integration tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

For an example of an Accelerate training loop with enabled ALST/UlyssesSP see [examples/alst_ulysses_sequence_parallelism](https://github.com/huggingface/accelerate/blob/main/examples/alst_ulysses_sequence_parallelism).


## Accelerate's interface

Expand Down Expand Up @@ -174,10 +285,10 @@ You can directly see this issue in the profiler output in the image below:
</p>


## Why only FSDP2?
## Why FSDP1 is not supported

We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
utilize its full potential.
We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
utilize its full potential.
How it works is: we shard the model across the joint mesh of size `cp_size*dp_shard_size`, which maximizes the memory savings.
This is a "free lunch" of sorts, as `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.

Expand Down
19 changes: 19 additions & 0 deletions examples/alst_ulysses_sequence_parallelism/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Deepspeed's ALST/Ulysses sequence parallelism

This is an example of the use of Ulysses Sequence Parallelism, which uses attention head parallelism and which is part of the Arctic Long Sequence Training project at [ArcticTraining](https://github.com/snowflakedb/ArcticTraining). [This paper](https://arxiv.org/abs/2506.13996) goes into the details of this protocol.

For nuances of usage please refer to the main HF Accelerate tutorial on [Context Parallelism](https://huggingface.co/docs/accelerate/en/concept_guides/context_parallelism).

You need to use at least `2` gpus to enable ALST/Ulysses sequence parallelism.

To run the example with `4` gpus:

```bash
bash ./cp-alst.sh
```

Change `4` to the desired sequence parallelism degree in these 2 files:
```
cp-alst.accelerate-config.yml:num_processes: 4
cp-alst.py: cp_size=4,
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_config_file: cp-alst.ds-config.json
zero3_init_flag: false
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
use_cpu: false
12 changes: 12 additions & 0 deletions examples/alst_ulysses_sequence_parallelism/cp-alst.ds-config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3
},
"gradient_accumulation_steps": 1,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"seq_parallel_communication_data_type": "bf16"
}
155 changes: 155 additions & 0 deletions examples/alst_ulysses_sequence_parallelism/cp-alst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from deepspeed.runtime.utils import move_to_device
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator
from accelerate.utils import ParallelismConfig, set_seed
from accelerate.utils.dataclasses import DeepSpeedContextParallelConfig


set_seed(42)

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# to run the example faster switch to the random model
# model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"

micro_batch_size = 1

parallelism_config = ParallelismConfig(
backend="deepspeed",
cp_size=4,
# dp_shard_size=1, # set if dp is wanted as well
cp_handler=DeepSpeedContextParallelConfig(
seq_length=256,
seq_length_is_variable=True,
attn_implementation="sdpa",
),
)

accelerator = Accelerator(
parallelism_config=parallelism_config,
# log_with="wandb", # enable to log into wandb
)
accelerator.init_trackers(
project_name="ulysses-accelerate",
config={},
init_kwargs={"wandb": dict(entity="yak", name="deepspeed")},
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 2 quick rough datasets to demonstrate the workings
if 1: # real dataset
from datasets import load_dataset

ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:12]")

# this is a quick example, it should be made more efficient to be used in real application
def convert(ex):
texts = tokenizer.apply_chat_template(conversation=ex["messages"], tokenize=False)
tokenized_dict = tokenizer(texts, max_length=256, padding=True, truncation=True)
return tokenized_dict
ds = ds.map(convert, batched=False, remove_columns=["prompt", "prompt_id", "messages"])

def collate_fn(batch):
input_ids = torch.tensor(batch[0]["input_ids"]).unsqueeze(0)
attention_mask = torch.tensor(batch[0]["attention_mask"]).unsqueeze(0)
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
return dict(
input_ids=input_ids,
position_ids=position_ids,
labels=input_ids,
attention_mask=attention_mask,
)

dl = torch.utils.data.DataLoader(
ds, batch_size=micro_batch_size, collate_fn=collate_fn, drop_last=True, shuffle=False
)

else: # fake dataset
samples = 16
seqlen = 256
input_ids = torch.arange(1, seqlen * samples + 1).view(-1, seqlen) + 100
position_ids = torch.arange(seqlen * samples).view(-1, seqlen)

ds = torch.utils.data.TensorDataset(input_ids, position_ids)

def collate_fn(batch):
input_ids, position_ids = batch[0]
return dict(
input_ids=input_ids.unsqueeze(0),
position_ids=position_ids.unsqueeze(0),
labels=input_ids.unsqueeze(0),
)

dl = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

rank = torch.distributed.get_rank()

if rank == 0:
print(f"DL orig: {len(dl)} samples")

model, optimizer, dl = accelerator.prepare(model, optimizer, dl)

if rank == 0:
print(f"DL w/ adapter: {len(dl)} samples")

cp_size = parallelism_config.cp_size if parallelism_config else 1
if cp_size > 1:
sp_group = accelerator.torch_device_mesh["cp"].get_group()
sp_world_size = parallelism_config.cp_size

unwrapped_model = accelerator.unwrap_model(model)

# Normal training loop
for iter, batch in enumerate(dl):
optimizer.zero_grad()

if rank == 0:
print(f"batch {iter}: seqlen: {len(batch['input_ids'][0])}")
batch = move_to_device(batch, model.device)
outputs = model(**batch)

shift_labels = batch["shift_labels"]
loss = unwrapped_model.loss_function(
logits=outputs.logits,
labels=None,
shift_labels=shift_labels,
vocab_size=unwrapped_model.config.vocab_size,
)

if cp_size > 1:
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special dealing with SFT that has prompt tokens that aren't used in loss computation
good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)

if rank == 0:
accelerator.print(f"{iter}: {loss=}")
accelerator.log(dict(train_loss=loss, step=iter))

accelerator.backward(loss)
optimizer.step()

accelerator.end_training()
8 changes: 8 additions & 0 deletions examples/alst_ulysses_sequence_parallelism/cp-alst.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export MASTER_ADDR=localhost
export MASTER_PORT=9998
python -u -m accelerate.commands.launch \
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--config_file cp-alst.accelerate-config.yml \
cp-alst.py
Loading
Loading