Merged (53 commits):
439fa7e  Feat: initial impl (S1ro1, Sep 16, 2025)
e238f7b  improve (sfc-gh-sbekman, Oct 22, 2025)
876fa2d  s/flavour/backend/ (sfc-gh-sbekman, Oct 22, 2025)
b016317  style + ver (sfc-gh-sbekman, Oct 22, 2025)
a63d094  better check (sfc-gh-sbekman, Oct 22, 2025)
73d3dbc  check (sfc-gh-sbekman, Oct 22, 2025)
8b4a4f4  docs + example (sfc-gh-sbekman, Oct 22, 2025)
76288a8  add tests (sfc-gh-sbekman, Oct 22, 2025)
209eab7  add tests (sfc-gh-sbekman, Oct 22, 2025)
c013677  cleanup (sfc-gh-sbekman, Oct 22, 2025)
685453c  cleanup (sfc-gh-sbekman, Oct 22, 2025)
a396904  Apply suggestions from code review (stas00, Oct 23, 2025)
fb80e02  add experimental notice (sfc-gh-sbekman, Oct 23, 2025)
21a4a2d  style (sfc-gh-sbekman, Oct 23, 2025)
05b6ac1  Merge branch 'alst-integration' of https://github.com/stas00/accelera… (sfc-gh-sbekman, Oct 23, 2025)
60f7493  new deepspeed version (sfc-gh-sbekman, Oct 23, 2025)
453fb55  additional checks + tests (sfc-gh-sbekman, Oct 23, 2025)
8677f23  more docs (sfc-gh-sbekman, Oct 23, 2025)
8ee9b03  more docs (sfc-gh-sbekman, Oct 23, 2025)
2e577d3  working now (sfc-gh-sbekman, Oct 28, 2025)
5c51897  style (sfc-gh-sbekman, Oct 28, 2025)
8317241  update docs (sfc-gh-sbekman, Oct 28, 2025)
94f558b  more robust config parsing (sfc-gh-sbekman, Oct 28, 2025)
a2388cd  fix (sfc-gh-sbekman, Oct 28, 2025)
e6e243f  Apply suggestions from code review (stas00, Nov 4, 2025)
2330dcd  check backend, integrate ulysses API improvement (sfc-gh-sbekman, Nov 5, 2025)
9dbcf91  style (sfc-gh-sbekman, Nov 5, 2025)
e79034f  fix default to match the doc (sfc-gh-sbekman, Nov 5, 2025)
61873c6  Apply suggestions from code review (stas00, Nov 5, 2025)
756bd9f  fix (sfc-gh-sbekman, Nov 5, 2025)
56df621  deepspeed=0.18.2 is out (sfc-gh-sbekman, Nov 5, 2025)
380747c  Apply suggestions from code review (stas00, Nov 10, 2025)
38c84fa  s/cp/sp (sfc-gh-sbekman, Nov 14, 2025)
5c2f34e  fixes (sfc-gh-sbekman, Nov 14, 2025)
190494b  Apply suggestions from code review (stas00, Nov 14, 2025)
285e24f  Update src/accelerate/parallelism_config.py (stas00, Nov 14, 2025)
a7d2e5d  suggestion (sfc-gh-sbekman, Nov 14, 2025)
d4ee156  Update docs/source/concept_guides/sequence_parallelism.md (stas00, Nov 17, 2025)
99b321a  Update sequence_parallelism.md (stas00, Nov 17, 2025)
b769fc8  fix (sfc-gh-sbekman, Nov 17, 2025)
04b4dc3  fix (sfc-gh-sbekman, Nov 17, 2025)
a4005e7  fix (sfc-gh-sbekman, Nov 18, 2025)
10978a0  Apply suggestion from @kashif (kashif, Nov 20, 2025)
4115257  Apply suggestion from @kashif (kashif, Nov 20, 2025)
891c702  Apply suggestion from @kashif (kashif, Nov 20, 2025)
71f61d6  Apply suggestion from @kashif (kashif, Nov 20, 2025)
1d5bc22  Apply suggestion from @kashif (kashif, Nov 20, 2025)
7282db8  Apply suggestion from @kashif (kashif, Nov 20, 2025)
84638eb  Apply suggestion from @kashif (kashif, Nov 20, 2025)
d0d8860  Apply suggestion from @kashif (kashif, Nov 20, 2025)
1e19b82  Apply suggestion from @kashif (kashif, Nov 20, 2025)
5d099cf  Apply suggestion from @kashif (kashif, Nov 20, 2025)
c3b2ce7  Apply suggestion from @kashif (kashif, Nov 20, 2025)
56 changes: 52 additions & 4 deletions src/accelerate/accelerator.py
@@ -1619,6 +1619,10 @@ def _prepare_cp(self, *args):
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method

if not self.is_fsdp2:
# DeepSpeed handles CP differently, via the ALST/Ulysses path in _prepare_deepspeed
return args

cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)

@@ -2075,6 +2079,10 @@ def _prepare_deepspeed(self, *args):

is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
tp_size = deepspeed_plugin.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0)

cp_size = self.parallelism_config.cp_size if self.parallelism_config else 1
cp_handler = self.parallelism_config.cp_handler if self.parallelism_config else None

if tp_size > 1:
if not compare_versions("deepspeed", ">=", "0.16.4"):
raise ImportError(
@@ -2146,7 +2154,10 @@ def _prepare_deepspeed(self, *args):
if batch_size_per_device is not None:
config_kwargs["train_micro_batch_size_per_gpu"] = batch_size_per_device
config_kwargs["train_batch_size"] = (
batch_size_per_device * deepspeed_plugin.get_value("gradient_accumulation_steps") * self.num_processes
batch_size_per_device
* deepspeed_plugin.get_value("gradient_accumulation_steps")
* self.num_processes
// cp_size
)

model = None
@@ -2162,6 +2173,29 @@
):
scheduler = obj

mpu = None
if cp_size > 1:
if is_dataloader_present and model is None:
raise ValueError(
"You cannot pass a dataloader to `accelerate.prepare()` without passing a model when using Context Parallelism."
)
ver_min_required = "0.18.0" # XXX: change to 0.18.1 when released
if not compare_versions("deepspeed", ">=", ver_min_required):
raise ImportError(
f"Deepspeed ALST/Ulysses requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`."
)

from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter
from deepspeed.utils import groups

mpu = UlyssesSPAttentionHF.register_with_transformers(
model_name_or_path=model,
sequence_parallel_size=cp_size,
max_length=cp_handler.max_length,
core_attn_implementation=cp_handler.attn_implementation or model.config.attn_implementation,
micro_batch_size=batch_size_per_device,
)

if optimizer is not None:
if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)):
raise ValueError(
@@ -2266,7 +2300,7 @@ def _prepare_deepspeed(self, *args):
)
deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs)
self.deepspeed_config = deepspeed_plugin.deepspeed_config
kwargs = dict(model=model, config_params=self.deepspeed_config)
kwargs = dict(model=model, config_params=self.deepspeed_config, mpu=mpu)
if optimizer is not None:
if isinstance(optimizer, (DummyOptim)):
kwargs["model_parameters"] = optimizer.params
@@ -2323,6 +2357,19 @@ def _prepare_deepspeed(self, *args):
type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
):
result[i] = scheduler
elif isinstance(result[i], torch.utils.data.DataLoader):
if cp_size > 1:
cp_group = groups._get_sequence_parallel_group()
cp_world_size = groups._get_sequence_parallel_world_size()
cp_rank = groups._get_sequence_parallel_rank()
result[i] = UlyssesSPDataLoaderAdapter(
result[i],
sp_rank=cp_rank,
sp_group=cp_group,
sp_world_size=cp_world_size,
device=model.device,
)

# pointing for deepspeed_engine_wrapped.backward()
if self.deepspeed_engine_wrapped is None:
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
@@ -3910,9 +3957,10 @@ def get_state_dict(self, model, unwrap=True):
tp_sharding = self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
if zero3_sharding or tp_sharding:
if model.zero_gather_16bit_weights_on_model_save():
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
ver_min_required = "0.16.4"
if tp_sharding and not compare_versions("deepspeed", ">=", ver_min_required):
raise ImportError(
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
f"Deepspeed TP requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`."
)
state_dict = (
model._consolidated_16bit_state_dict()
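Taken together, the accelerator.py changes route context parallelism through DeepSpeed's ALST/Ulysses integration: prepare() registers UlyssesSPAttentionHF with the model, passes the resulting mpu to deepspeed.initialize(), divides train_batch_size by cp_size, and wraps any prepared dataloader in UlyssesSPDataLoaderAdapter. A minimal usage sketch of that path follows; the model, optimizer, dataloader, and zero_stage=3 choices are illustrative assumptions, not part of this diff.

# Hypothetical end-to-end usage of the DeepSpeed context-parallel path added above.
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils.dataclasses import DeepSpeedContextParallelConfig

parallelism_config = ParallelismConfig(
    backend="deepspeed",  # new field introduced in this PR
    cp_size=2,            # context/sequence-parallel degree
    cp_handler=DeepSpeedContextParallelConfig(
        max_length=4096,                          # required maximum sequence length (assumed value)
        attn_implementation="flash_attention_2",  # or "flash_attention_3" / "sdpa"
    ),
)

accelerator = Accelerator(
    parallelism_config=parallelism_config,
    deepspeed_plugin=DeepSpeedPlugin(zero_stage=3),  # illustrative ZeRO setting
)

model = ...       # a transformers model (assumption: any HF causal LM)
optimizer = ...   # e.g. a torch.optim optimizer over model.parameters()
dataloader = ...  # a torch.utils.data.DataLoader over the training set

# prepare() registers UlyssesSPAttentionHF, builds the DeepSpeed engine with mpu=...,
# scales train_batch_size down by cp_size, and wraps the dataloader in
# UlyssesSPDataLoaderAdapter over the sequence-parallel group.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)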
26 changes: 19 additions & 7 deletions src/accelerate/parallelism_config.py
@@ -15,9 +15,14 @@
import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
from typing import TYPE_CHECKING, Literal, Optional, Union

from accelerate.utils.dataclasses import (
DeepSpeedContextParallelConfig,
DistributedType,
TorchContextParallelConfig,
TorchTensorParallelConfig,
)
from accelerate.utils.versions import is_torch_version


@@ -56,6 +61,7 @@ class ParallelismConfig:

"""

backend: Literal["torch", "deepspeed"] = "torch"
dp_replicate_size: Optional[int] = None
dp_shard_size: Optional[int] = None
tp_size: Optional[int] = None
@@ -254,11 +260,13 @@ def __post_init__(self):

if self.tp_size > 1:
if self.tp_handler is None:
self.tp_handler = TorchTensorParallelConfig()
self.tp_handler = TorchTensorParallelConfig() if self.backend == "torch" else None

if self.cp_size > 1:
if self.cp_handler is None:
self.cp_handler = TorchContextParallelConfig()
self.cp_handler = (
TorchContextParallelConfig() if self.backend == "torch" else DeepSpeedContextParallelConfig()
)

if self.dp_replicate_size < 1:
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
@@ -304,9 +312,13 @@ def _validate_accelerator(self, accelerator: "Accelerator"):
f"dp_shard_size/tp_size/cp_size."
)

if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
if self.total_size > 1 and not (
accelerator.is_fsdp2
or accelerator.multi_device
or accelerator.distributed_type == DistributedType.DEEPSPEED
):
raise ValueError(
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}."
)

for parallelism, size in self._sizes.items():
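The parallelism_config.py changes make the default cp_handler depend on the new backend field and let DistributedType.DEEPSPEED pass _validate_accelerator. A small sketch of the default-handler selection; the 4096 value and the use of the environment variable are assumptions (DeepSpeedContextParallelConfig requires max_length, which its __post_init__ can also read from PARALLELISM_CONFIG_CP_MAX_LENGTH), and the torch-backend default presumes a sufficiently recent PyTorch build.

# Sketch of the backend-dependent cp_handler defaults set in __post_init__ above.
import os
from accelerate.parallelism_config import ParallelismConfig

# backend="torch" (the default): cp_handler falls back to TorchContextParallelConfig.
pc_torch = ParallelismConfig(cp_size=2)

# backend="deepspeed": cp_handler falls back to DeepSpeedContextParallelConfig,
# whose max_length is supplied here via the environment variable (assumed value).
os.environ["PARALLELISM_CONFIG_CP_MAX_LENGTH"] = "4096"
pc_deepspeed = ParallelismConfig(backend="deepspeed", cp_size=2)

print(type(pc_torch.cp_handler).__name__)      # TorchContextParallelConfig
print(type(pc_deepspeed.cp_handler).__name__)  # DeepSpeedContextParallelConfig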
34 changes: 34 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -2189,6 +2189,40 @@ def __post_init__(self):
)


@dataclass
class DeepSpeedContextParallelConfig:
max_length: int = field(
default=None,
metadata={"help": "Maximum sequence length to process."},
)
attn_implementation: str = field(
default=None,
metadata={
"help": "Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. If not provided, default from model will be used."
},
)

def __post_init__(self):
if self.max_length is None:
if "PARALLELISM_CONFIG_CP_MAX_LENGTH" not in os.environ:
raise ValueError(
"max_length must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_CP_MAX_LENGTH"
)
self.max_length = int(os.environ["PARALLELISM_CONFIG_CP_MAX_LENGTH"])

if self.attn_implementation is None:
self.attn_implementation = os.environ.get("PARALLELISM_CONFIG_CP_ATTN_IMPLEMENTATION", None)

if self.attn_implementation is not None and self.attn_implementation not in [
"flash_attention_2",
"flash_attention_3",
"sdpa",
]:
raise ValueError(
f"Invalid attn_implementation: {self.attn_implementation}. Must be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'."
)


@dataclass
class TorchTensorParallelConfig:
"""
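The new DeepSpeedContextParallelConfig added in this file can be configured either explicitly or from the environment variables its __post_init__ reads; a short sketch follows, with all values illustrative.

# Sketch: configuring DeepSpeedContextParallelConfig directly or via the environment.
import os
from accelerate.utils.dataclasses import DeepSpeedContextParallelConfig

# Explicit construction.
cfg = DeepSpeedContextParallelConfig(max_length=8192, attn_implementation="sdpa")

# Equivalent construction driven by the environment variables read in __post_init__.
os.environ["PARALLELISM_CONFIG_CP_MAX_LENGTH"] = "8192"
os.environ["PARALLELISM_CONFIG_CP_ATTN_IMPLEMENTATION"] = "sdpa"
cfg = DeepSpeedContextParallelConfig()

# Anything other than "flash_attention_2", "flash_attention_3" or "sdpa" raises ValueError,
# and omitting max_length both here and in the environment raises ValueError as well.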