Conversation

@stas00 (Contributor) commented Oct 22, 2025

This completes the work started by @S1ro1 in #3782 to integrate ALST/Ulysses long-sequence training into HF Accelerate (paper: https://arxiv.org/abs/2506.13996). It is Matej's original code with a lot of additional work on top, plus docs and tests from me.

Here is the corresponding HF Trainer integration PR: huggingface/transformers#41832

If you want to try it out, please first install DeepSpeed from source (e.g. `pip install git+https://github.com/deepspeedai/DeepSpeed`), as DeepSpeed needed some tweaks to make this integration work.

To use this feature a user needs to:

1. create a `ParallelismConfig`:

   ```python
   parallelism_config = ParallelismConfig(
       sp_backend="deepspeed",
       sp_size=2,
       sp_handler=DeepSpeedSequenceParallelConfig(attn_implementation="flash_attention_2"),
   )

   accelerator = Accelerator(parallelism_config=parallelism_config)
   ```

2. add to their code the use of `shift_labels` and an aggregation of the loss across ranks:

   ```python
   shift_labels = batch["shift_labels"]
   loss = model.module.loss_function(
       logits=outputs.logits,
       labels=None,
       shift_labels=shift_labels,
       vocab_size=model.module.config.vocab_size,
   )

   # differentiable weighted per-shard-loss aggregation across ranks
   losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
   # special handling for SFT, which has prompt tokens that aren't used in loss computation
   good_tokens = (shift_labels != -100).view(-1).sum()
   good_tokens_per_rank = torch.distributed.nn.functional.all_gather(
       good_tokens, group=sp_group
   )
   total_loss = sum(
       losses_per_rank[rank] * good_tokens_per_rank[rank]
       for rank in range(sp_world_size)
   )
   total_good_tokens = sum(good_tokens_per_rank)
   loss = total_loss / max(total_good_tokens, 1)
   ```
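For intuition, the aggregation above is just a token-weighted mean: each rank contributes its mean loss weighted by its count of non-masked tokens (here `sp_group` and `sp_world_size` are assumed to come from the accelerator's sequence-parallel setup). A plain-Python sanity check of the arithmetic, using hypothetical per-rank values:

```python
# Hypothetical per-rank values (not from a real run)
losses_per_rank = [2.0, 4.0]      # mean loss computed on each SP rank
good_tokens_per_rank = [30, 10]   # tokens with labels != -100 on each rank

total_loss = sum(l * t for l, t in zip(losses_per_rank, good_tokens_per_rank))
total_good_tokens = sum(good_tokens_per_rank)
loss = total_loss / max(total_good_tokens, 1)
# equivalent to averaging the loss over all 40 non-masked tokens at once
```

The `max(..., 1)` guard simply avoids division by zero on a batch where every label is masked.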

Quality validation

I wrote 3 accelerate-based scripts (attached at the end of the OP):

  • 1 gpu
  • 4 gpus w/ fsdp
  • 4 gpus w/ deepspeed ulysses

The loss checks out, with very small variations due to precision loss in the aggregation.
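Illustratively, "checks out" here means the per-step losses agree within a small relative tolerance across the three setups; a sketch of such a comparison with made-up numbers (not the actual run data):

```python
# Made-up per-step losses for the three setups (illustration only)
single_gpu = [10.81, 9.95, 9.12]
fsdp_4gpu = [10.82, 9.94, 9.13]
ulysses_4gpu = [10.80, 9.96, 9.11]

def curves_agree(a, b, rtol=5e-3):
    """True if every pair of per-step losses matches within rtol."""
    return all(abs(x - y) <= rtol * max(abs(x), abs(y)) for x, y in zip(a, b))
```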

[Screenshot 2025-10-22 09:22:50: loss comparison across the three setups]

TODO

These items are really needed for the HF Trainer PR huggingface/transformers#41832, but since that PR anchors on Accelerate, let's make the dependency here instead.

Scripts used to perform the quality validation

The scripts and config files are:

You run the `.sh` files.

cc: @SunMarc

Signed-off-by: Stas Bekman <[email protected]>
@SunMarc (Member) left a comment


Thanks a lot, this is already looking quite nice! Left some minor comments. Please ping me when you have finished the integration!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

stas00 added a commit to deepspeedai/DeepSpeed that referenced this pull request Oct 22, 2025
Ulysses/ALST integration with HF Accelerate:
- Allow `UlyssesSPAttentionHF.register_with_transformers` to get a
`model` obj as an argument, to match HF accelerate's workflow
- Fix existing Ulysses tests to test z2 instead of z1
- Improve documentation
- Add a defensive check

The HF Accelerate PR that depends on this PR is here
huggingface/accelerate#3817

@stas00 (Contributor, Author) commented Oct 22, 2025

@SunMarc,

  • I did some more tweaks to improve/simplify UX
  • added docs - did I miss any places where I should mention this backend?
  • added tests - I didn't actually see any torch CP e2e tests - do they even exist? In any case I wrote a simple e2e test for this PR - the quality checks are already extensively tested in the deepspeed repo

So we just have a few conversations above to complete, and otherwise I'm waiting for deepspeed to release a new version so that we can anchor on it here. Other than that it's ready for your complete review - but don't merge just yet, until we get the new ds version here.

And then we can discuss the HF Trainer integration. Should we somehow mark this API as experimental to let users use it for a bit and possibly adjust things? If so please give me an example to follow.

@SunMarc (Member) commented Oct 23, 2025

> added tests - I didn't actually see any torch CP e2e tests - do they even exist? in any case I wrote a simple e2e test for this PR - the quality checks are already extensively tested in the deepspeed repo

There are some e2e examples in the example/torch_native_parallelism folder but we are not running them in the CI.

> And then we can discuss the HF Trainer integration. Should we somehow mark this API as experimental to let users use it for a bit and possibly adjust things? If so please give me an example to follow.

Let's try to integrate it into HF Trainer before merging this PR. Once it is tightly coupled to Trainer, even if the API is marked as experimental, we will most likely try to limit breaking changes. For experimental features, we just put a note in the docs, like for big model inference (we probably need to remove the warning for that feature):

<Tip warning={true}>

This API is quite new and still in its experimental stage. While we strive to provide a stable API, it's possible some small parts of the public API will change in the future.

</Tip>

@SunMarc (Member) left a comment


Thanks for adding the docs and the tests, this looks really nice. Just some minor nits

@stas00 (Contributor, Author) commented Oct 23, 2025

Taking this out of the code comments so that it doesn't disappear with 'Resolve conversation':

To give a sense of what ALST made possible: it allowed us to train Llama-8B in bf16 with 500K tokens on a single H100 GPU, 3.7M tokens on a single node, and 15M tokens using just four nodes. This feature of HF Accelerate enables only 1 of the 3 ALST components, so the achievable sequence length will be smaller. You'd want TiledMLP, activation checkpoint offload to CPU, and a few other things enabled to get the full power of ALST; for details please refer to this tutorial.

> @SunMarc: what would it take to enable the other features?

Those features belong to:

  1. HF Transformers
  2. HF Transformers and PyTorch
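For the TiledMLP component mentioned above, the core idea is to run the token-wise MLP over the sequence dimension in chunks; since the MLP is position-independent the result is mathematically identical, and the real implementation pairs this with recomputation in backward to actually save activation memory. A minimal forward-only sketch (function and names are illustrative, not the DeepSpeed API):

```python
import torch
import torch.nn as nn

def tiled_mlp_forward(mlp, hidden_states, num_tiles=4):
    # Split [batch, seq, hidden] along the sequence axis and apply the
    # MLP tile by tile; a token-wise MLP makes this exact.
    tiles = hidden_states.chunk(num_tiles, dim=1)
    return torch.cat([mlp(t) for t in tiles], dim=1)
```

On its own this chunking does not reduce peak memory; the savings come from recomputing each tile's activations during backward instead of keeping them all alive at once.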

@SunMarc (Member) commented Nov 20, 2025

Can we merge this PR @stas00 ?

@stas00 (Contributor, Author) commented Nov 20, 2025

Thank you for the further improvements/fixes, Kashif.

Yes, let's merge it. Thank you.

@kashif kashif merged commit c7e59dd into huggingface:main Nov 20, 2025
41 of 45 checks passed
@egangu commented Nov 26, 2025

Merging this PR seems to have invalidated the original CP implementation. In `_prepare_cp`, the function automatically passes when `parallelism_config.sp_backend` is `"deepspeed"`:

```python
if self.parallelism_config.sp_backend == "deepspeed":
```

However, `parallelism_config.sp_backend` can only be set to `"deepspeed"`:

```python
valid_sp_backends = ["deepspeed"]
```
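A minimal sketch of the reported problem and one possible fix (mock config and helper names are hypothetical, not the actual Accelerate code): the early-return should presumably only fire when DeepSpeed SP is actually enabled (`sp_size > 1`), not merely because `"deepspeed"` is the only valid backend value:

```python
from dataclasses import dataclass

@dataclass
class MockParallelismConfig:
    sp_backend: str = "deepspeed"  # the only valid value, so it is always set
    sp_size: int = 1
    cp_size: int = 1

def skips_cp_buggy(cfg):
    # reported behavior: torch-native CP setup is skipped unconditionally
    return cfg.sp_backend == "deepspeed"

def skips_cp_fixed(cfg):
    # hypothetical fix: skip only when DeepSpeed SP is actually in use
    return cfg.sp_backend == "deepspeed" and cfg.sp_size > 1
```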

@kashif (Contributor) commented Nov 26, 2025

@egangu so for CP one would need to add the appropriate `cp_size` and `cp_backend` configs

@egangu commented Nov 27, 2025

> @egangu so for CP one would need to add the appropriate cp_size and cp_backend configs

You're right. But what I mean is that the original CP cannot be enabled in the current version, no matter how the user sets the `cp_size` or `cp_backend` configuration. This is because the `_prepare_cp` function (the core code that enables the original CP) is automatically skipped by the current implementation, as I pointed out above. Rolling back Accelerate to version 1.11 re-enables the original CP.

@kashif (Contributor) commented Nov 27, 2025

Thanks for the report @egangu, let me test and fix.

@kashif kashif mentioned this pull request Nov 27, 2025