Merged
55 commits
cfee9c9
HF Trainer: ALST/Ulysses sequence parallelism integration via HF Acce…
sfc-gh-sbekman Oct 23, 2025
6e28ca8
make it work + tests
sfc-gh-sbekman Oct 28, 2025
86a09b9
cleanup
sfc-gh-sbekman Oct 28, 2025
bb902f9
Merge branch 'main' into alst-integration
stas00 Oct 28, 2025
c0e8e0d
undo
sfc-gh-sbekman Oct 28, 2025
101eaff
normalize
sfc-gh-sbekman Nov 5, 2025
d8770d5
always return cp_size
sfc-gh-sbekman Nov 5, 2025
4f416a4
cleanup
sfc-gh-sbekman Nov 5, 2025
ce5e392
extract code into _deepspeed_cp_compute_loss
sfc-gh-sbekman Nov 5, 2025
3ceaa94
fix
sfc-gh-sbekman Nov 5, 2025
607e166
Merge branch 'main' into alst-integration
stas00 Nov 5, 2025
211b6df
ALST/Ulysses sequence parallelism docs
kashif Nov 9, 2025
34b208c
typo
kashif Nov 10, 2025
816cc96
add link to UlyssesSPDataLoaderAdapter
kashif Nov 10, 2025
b3cbfb1
Merge pull request #3 from kashif/alst-doc
stas00 Nov 10, 2025
674db46
Merge remote-tracking branch 'origin/main' into alst-integration
sfc-gh-sbekman Nov 17, 2025
b12249a
adapt to renaming to SP
sfc-gh-sbekman Nov 17, 2025
4be7619
improve
sfc-gh-sbekman Nov 17, 2025
21ec5e5
fix
sfc-gh-sbekman Nov 17, 2025
bc32a16
Update docs/source/en/deepspeed.md
stas00 Nov 17, 2025
a850a3a
Merge branch 'main' into alst-integration
stas00 Nov 18, 2025
0127933
address comments
sfc-gh-sbekman Nov 18, 2025
a50c89c
Merge branch 'alst-integration' of https://github.com/stas00/transfor…
sfc-gh-sbekman Nov 18, 2025
5e29dd9
address comments
sfc-gh-sbekman Nov 18, 2025
6ce745d
Update src/transformers/trainer.py
stas00 Nov 18, 2025
59972a3
address comments
sfc-gh-sbekman Nov 18, 2025
f554277
address comments
sfc-gh-sbekman Nov 18, 2025
0eef76f
Update src/transformers/trainer.py
stas00 Nov 18, 2025
c277586
Merge branch 'main' into alst-integration
stas00 Nov 18, 2025
8201150
Update src/transformers/trainer.py
stas00 Nov 18, 2025
854fd51
style
sfc-gh-sbekman Nov 18, 2025
76ee3ad
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
083ca01
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
6929fb2
Account for Sequence Parallelism (SP) dataloader adapter effect
kashif Nov 19, 2025
6ae2bb0
Update src/transformers/trainer.py
stas00 Nov 19, 2025
407f34a
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
363909b
Update docs/source/en/deepspeed.md
stas00 Nov 19, 2025
8f62f14
Merge branch 'main' into alst-integration
stas00 Nov 19, 2025
6c5b00c
Merge pull request #4 from kashif/sp_len
stas00 Nov 19, 2025
49c5ed7
model_accepts_loss_kwargs to False
kashif Nov 19, 2025
4cafb9b
better comment
kashif Nov 19, 2025
7c1abd5
Merge pull request #5 from kashif/loss_kwargs
stas00 Nov 19, 2025
58e4e13
Apply suggestion from @kashif
kashif Nov 19, 2025
d8d53c2
Apply suggestion from @kashif
kashif Nov 19, 2025
a05eb52
Apply suggestions from code review
kashif Nov 19, 2025
ad61079
Merge branch 'main' into alst-integration
kashif Nov 20, 2025
3fd097d
Apply suggestion from @kashif
kashif Nov 20, 2025
e3d8eda
Apply suggestion from @kashif
kashif Nov 20, 2025
ef59f3e
Apply suggestion from @kashif
kashif Nov 20, 2025
59487a8
Update src/transformers/trainer.py
kashif Nov 20, 2025
2444728
Update src/transformers/training_args.py
kashif Nov 20, 2025
4f33c2f
Merge branch 'main' into alst-integration
kashif Nov 21, 2025
7d09b28
Apply suggestion from @kashif
kashif Nov 21, 2025
2e52913
Apply suggestion from @kashif
kashif Nov 21, 2025
7a5c45e
Merge branch 'main' into alst-integration
SunMarc Nov 21, 2025
19 changes: 15 additions & 4 deletions src/transformers/testing_utils.py
@@ -21,6 +21,7 @@
import gc
import importlib
import inspect
import json
import logging
import multiprocessing
import os
@@ -2005,14 +2006,12 @@ def get_env(self):
paths = [self.repo_root_dir_str, self.src_dir_str]
if "/examples" in self.test_file_dir_str:
paths.append(self.examples_dir_str)
else:
paths.append(self.tests_dir_str)
paths.append(env.get("PYTHONPATH", ""))

env["PYTHONPATH"] = ":".join(paths)
return env

def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
Contributor Author
@stas00 stas00 Oct 28, 2025
This is a really old version. In the latest incarnation it always returns a Path object. But to keep BC, I added a new flag here instead. The tests are less clunky that way.

The latest version is here: https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py

If wanted, you could switch to the latest version instead and adapt the tests to simplify them.

Contributor Author
it's much better for it to always return a pathlib.Path object but you'd need to tweak a few tests which use this API.
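
A minimal sketch of how a test could exercise the new flag (the test class and file name are made up for illustration; it only assumes `TestCasePlus` from `transformers.testing_utils` and the `return_pathlib_obj` flag added in this diff):

    from pathlib import Path

    from transformers.testing_utils import TestCasePlus


    class TmpDirFlagTest(TestCasePlus):
        def test_tmp_dir_as_path(self):
            # default behavior (kept for BC): a plain string path
            tmp_str = self.get_auto_remove_tmp_dir()
            assert isinstance(tmp_str, str)

            # opt-in behavior from this diff: a resolved pathlib.Path object
            tmp_path = self.get_auto_remove_tmp_dir(return_pathlib_obj=True)
            assert isinstance(tmp_path, Path)
            (tmp_path / "dummy.json").write_text("{}")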

"""
Args:
tmp_dir (`string`, *optional*):
@@ -2032,6 +2031,8 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
after (`bool`, *optional*):
If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
intact at the end of the test.
return_pathlib_obj (`bool`, *optional*):
If `True`, will return a `pathlib.Path` object instead of a string path.

Returns:
tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
@@ -2078,7 +2079,7 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
# register for deletion
self.teardown_tmp_dirs.append(tmp_dir)

return tmp_dir
return Path(tmp_dir).resolve() if return_pathlib_obj else tmp_dir

def python_one_liner_max_rss(self, one_liner_str):
"""
@@ -4076,3 +4077,13 @@ def use_one_line_repr(obj):
cache[(id(obj), indent, mode, prefix)] = output

return output


def write_file(file, content):
with open(file, "w") as f:
f.write(content)


def read_json_file(file):
with open(file, "r") as fh:
return json.load(fh)
94 changes: 77 additions & 17 deletions src/transformers/trainer.py
@@ -2159,6 +2159,17 @@ def train(
ignore_keys_for_eval=ignore_keys_for_eval,
)

def get_cp_size(self) -> int:
"""Get the context parallel size"""
if getattr(self.accelerator, "parallelism_config", None) is None:
return 1
pc = self.accelerator.parallelism_config
if pc.cp_backend == "deepspeed":
return pc.cp_size
# XXX: not sure if backend == "torch" needs to return cp_size here as well; it didn't originally (a bug?)

return 1

def get_tp_size(self) -> int:
"""Get the tensor parallel size from either the model or DeepSpeed config."""

@@ -2176,8 +2187,8 @@ def get_tp_size(self) -> int:
def get_total_train_batch_size(self, args) -> int:
"""Calculates total batch size (micro_batch * grad_accum * dp_world_size).

Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
dp_world_size = args.world_size // self.get_tp_size()
Note: Only considers DP, TP and SP/CP (dp_world_size = world_size // tp_size // cp_size)."""
dp_world_size = args.world_size // self.get_tp_size() // self.get_cp_size()
return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size

def _inner_training_loop(
@@ -2301,6 +2312,11 @@ def _inner_training_loop(
else:
self.optimizer = self.accelerator.prepare(self.optimizer)

# since the DataLoader was prepared by Accelerate w/o a model arg in the same call, we now have to complete the DL wrapping for ALST/UlyssesSP after the model has been prepared
pc = getattr(self.accelerator, "parallelism_config", None)
if pc is not None and pc.cp_backend == "deepspeed":
train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)

if self.is_fsdp_enabled:
self.model = self.model_wrapped = model

@@ -3635,23 +3651,32 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
getattr(self.accelerator, "parallelism_config", None) is not None
and self.accelerator.parallelism_config.cp_enabled
):
if hasattr(model, "config"):
if model.config._attn_implementation != "sdpa":
raise ValueError(
f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
)
if self.accelerator.parallelism_config.cp_backend == "torch":
if hasattr(model, "config"):
if model.config._attn_implementation != "sdpa":
raise ValueError(
f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
)

if "shift_labels" not in inputs:
logger.warning_once("Shift labels not found in the inputs, shifting manually")
if "labels" in inputs:
_ignore_index = -100
labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
inputs["shift_labels"] = labels[:, 1:].contiguous()

# carve out space to make it clear there are other backends with different requirements, even though no code needs to be run at the moment
elif self.accelerator.parallelism_config.cp_backend == "deepspeed":
# - accelerator.parallelism_config performs the `model.config._attn_implementation` checks already and it supports more than `sdpa`
# - UlyssesSPDataLoaderAdapter called from Accelerate performs the `shift_label` creation - must not interfere
# - position_ids generation should be done by HF Trainer if it wasn't done by the user
pass

if "position_ids" not in inputs:
logger.warning_once("Position IDs not found in the inputs, generating manually")
inputs["position_ids"] = torch.arange(
inputs["input_ids"].size(1), device=inputs["input_ids"].device
).expand(inputs["input_ids"].size(0), -1)
if "shift_labels" not in inputs:
logger.warning_once("Shift labels not found in the inputs, shifting manually")
if "labels" in inputs:
_ignore_index = -100
labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
inputs["shift_labels"] = labels[:, 1:].contiguous()

buffers = []
buffer_seq_dims = []
@@ -3820,6 +3845,33 @@ def compute_loss(
Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculation might be slightly inaccurate when performing gradient accumulation.
"""
pc = getattr(self.accelerator, "parallelism_config", None)
if pc is not None and pc.cp_backend == "deepspeed":
unwrapped_model = self.accelerator.unwrap_model(model)

outputs = model(**inputs)
shift_labels = inputs["shift_labels"]
loss = unwrapped_model.loss_function(
logits=outputs.logits,
labels=None,
shift_labels=shift_labels,
vocab_size=unwrapped_model.config.vocab_size,
)

if pc.cp_size > 1:
sp_group = self.accelerator.torch_device_mesh["cp"].get_group()
sp_world_size = pc.cp_size
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special handling for SFT, where prompt tokens aren't used in the loss computation
good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)
Member
We probably don't need to do this if num_items_in_batch is computed and passed to unwrapped_model.loss_function. num_items_in_batch was introduced to fix gradient accumulation (https://unsloth.ai/blog/gradient). num_items_in_batch is basically total_good_tokens if grad_acc = 1.

Contributor Author
sorry, what is not needed?

This code block is because we need to compute the correct loss across SP ranks. If you just average those it'll be incorrect in the case of -100 masked tokens (SFT), since each rank is likely to process a different number of unmasked tokens (this is not DP averaging).

Unless what you mean is that we don't need to calculate total_good_tokens since num_items_in_batch is already that, but the rest of the code remains - did I understand you correctly?

Member
@SunMarc SunMarc Nov 4, 2025
If you pass num_items_in_batch in loss_function, it will sum the loss then divide it by num_items_in_batch directly. This way I think we don't actually need to recalculate the total_loss from the averaged losses and the good_tokens_per_rank. Maybe I'm wrong so please correct me! But I think this might solve the grad acc issue. In any case, we will keep the current code as not all models accept num_items_in_batch when calculating the loss.

total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))

def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss

Contributor Author
@stas00 stas00 Nov 4, 2025
If you pass num_items_in_batch you indeed don't need to do the local loss calculation since it'll do that already. But we need to calculate a loss that is combined across ranks.

Here is an example: Let's take a 2k tokens sample SP-split across 2 ranks using SFT:

  1. SP rank0 - 900 masked and 100 non-masked tokens (a long initial prompt that is -100 masked out)
  2. SP rank1 - 100 masked and 900 non-masked tokens

So each rank produces the correct loss if we use num_items_in_batch - but how do you combine the losses of the 2 ranks? A straight average will give a very skewed result, because rank0's loss contributes 9x fewer non-masked tokens.

Let's take it to a more telling example:

  1. SP rank0 - 1000 masked and 0 non-masked tokens (a long initial prompt that is masked out)
  2. SP rank1 - 0 masked and 1000 non-masked tokens

Here rank0 can't contribute anything to the total loss - a plain average of the 2 losses would be completely broken, since rank0's loss function would return a NaN or None and you'd be averaging with undefined behavior.

Member
> So each rank produces the correct loss if we use num_items_in_batch - but how do you combine the losses of the 2 ranks? A straight average will give a very skewed result, because rank0's loss contributes 9x fewer non-masked tokens.

The denominator of both losses is num_items_in_batch; the value of each loss already takes into account the number of non-masked tokens since we use reduction="sum". So we just sum them to get the final loss. In your first example, num_items_in_batch will be equal to 1000. For rank0, the loss will be (l1+...+l100)/1000 and for rank1 it will be (l1+...+l900)/1000.

Contributor Author
I have a feeling we are missing each other. I'm talking about differentiable loss combination across ranks and I think you're talking about the local rank's loss.

Could you please point me to the code in HF Trainer that performs a differentiable loss combination across multiple ranks? I couldn't find any.
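
To make the above concrete, here is a small single-process sketch (made-up numbers, no torch.distributed; in the actual diff the per-rank values come from torch.distributed.nn.functional.all_gather so the combination stays differentiable) contrasting a straight average of per-rank mean losses with the token-weighted aggregation this PR performs:

    import torch

    # Pretend a 2k-token SFT sample was split across 2 SP ranks; per-token CE losses
    # are kept only for non-masked (non -100) tokens.
    torch.manual_seed(0)
    rank_token_losses = [
        torch.rand(100),  # rank0: long prompt mostly masked out -> 100 good tokens
        torch.rand(900),  # rank1: 900 good tokens
    ]

    # What each rank computes locally: a mean over its own good tokens.
    per_rank_loss = torch.stack([t.mean() for t in rank_token_losses])
    good_tokens = torch.tensor([float(t.numel()) for t in rank_token_losses])

    # Straight average: rank0's 100 tokens weigh as much as rank1's 900 -> skewed.
    naive = per_rank_loss.mean()

    # Token-weighted aggregation, as in the PR's compute_loss for the SP group.
    weighted = (per_rank_loss * good_tokens).sum() / good_tokens.sum().clamp(min=1)

    # Reference: mean over all good tokens as if the sequence had never been sharded.
    reference = torch.cat(rank_token_losses).mean()

    print(f"naive: {naive:.4f}  weighted: {weighted:.4f}  reference: {reference:.4f}")
    assert torch.allclose(weighted, reference)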


return (loss, outputs) if return_outputs else loss

if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
@@ -3913,7 +3965,11 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
Path(os.path.join(output_dir, "user_content.pt")).touch()
# We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
elif getattr(self.accelerator, "parallelism_config", None) is not None:
if self.accelerator.should_save_model:
# deepspeed already takes care of saving the checkpoint below, so we need this only for the torch cp backend
if (
self.accelerator.should_save_model
and getattr(self.accelerator, "parallelism_config").cp_backend == "torch"
):
Member
update this for sp deepspeed

Contributor
I don't see anything to update here, what am I missing? This code is to be run for torch.

Contributor Author
Unless you propose to revert the code part? I thought being more explicit would make it easier down the road.

Member
@SunMarc SunMarc Nov 19, 2025
The thing is that cp_backend's default value is "torch" even if you decide to use SP with sp_backend="deepspeed", so this branch will always run. Maybe change it to something like this?

pc = getattr(self.accelerator, "parallelism_config")
if self.accelerator.should_save_model and not (pc.sp_enabled and  pc.sp_backend == "deepspeed"):

self._save(output_dir)
# If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
@@ -4981,9 +5037,11 @@ def create_accelerator_and_postprocess(self):

# We defer compatibility checks to accelerator
if self.args.parallelism_config is not None:
if not is_accelerate_available("1.10.1"):
# XXX: this will need to change once https://github.com/huggingface/accelerate/pull/3817 is merged and 1.11.1 is out
min_accelerate_version = "1.10.1"
if not is_accelerate_available(min_accelerate_version):
raise ImportError(
"ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
f"ParallelismConfig requires accelerate>={min_accelerate_version}). Please upgrade accelerate to use this feature."
)
args["parallelism_config"] = self.args.parallelism_config

@@ -5136,7 +5194,9 @@ def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) ->
# In the DataParallel case, convert the scalar tensor into a 2-dim tensor with the same value repeated
num_items_in_batch = num_items_in_batch.unsqueeze(0).expand(self.args.n_gpu, -1)
# Divide by number of devices with the same batch
if pc := getattr(self.accelerator, "parallelism_config", None):
# XXX: double check if this is only for fsdp
pc = getattr(self.accelerator, "parallelism_config", None)
if pc is not None and pc.cp_backend == "torch":
num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size

return num_items_in_batch
2 changes: 1 addition & 1 deletion src/transformers/training_args.py
@@ -1133,7 +1133,7 @@ class TrainingArguments:
)
parallelism_config: Optional[ParallelismConfig] = field(
default=None,
metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")},
metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.11.1`")},
)
deepspeed: Optional[Union[dict, str]] = field(
default=None,