
Conversation

@S1ro1 (Contributor) commented Sep 16, 2025

cc @stas00

from accelerate import Accelerator
from accelerate.utils import ParallelismConfig
from accelerate.utils.dataclasses import DeepSpeedContextParallelConfig
from deepspeed.runtime.utils import move_to_device
from torch import tensor
from transformers import AutoModelForCausalLM
import torch
import torch.distributed.nn  # needed for the differentiable torch.distributed.nn.functional.all_gather

MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM"
MICRO_BATCH_SIZE = 1

parallelism_config = ParallelismConfig(
    flavour="deepspeed",
    cp_size=2,
    dp_shard_size=4,
    cp_handler=DeepSpeedContextParallelConfig(
        max_length=64, attn_implementation="sdpa"
    ),
)

accelerator = Accelerator(parallelism_config=parallelism_config)

input_ids = tensor(
    [[1, 10, 10, 10, 2, 2], [1, 20, 20, 20, 2, 2]],
)
position_ids = tensor([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]])
ds = torch.utils.data.TensorDataset(input_ids, position_ids)


def collate_fn(batch):
    # batch_size=1: unpack the single (input_ids, position_ids) pair and re-add the batch dim
    input_ids, position_ids = batch[0]
    return dict(
        input_ids=input_ids.unsqueeze(0),
        position_ids=position_ids.unsqueeze(0),
        # labels and shift_labels reuse the toy input_ids so each rank can compute a loss
        labels=input_ids.unsqueeze(0),
        shift_labels=input_ids.unsqueeze(0),
    )


model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dl = torch.utils.data.DataLoader(ds, batch_size=MICRO_BATCH_SIZE, collate_fn=collate_fn)

model, optimizer, dl = accelerator.prepare(model, optimizer, dl)

sp_group = accelerator.torch_device_mesh["cp"].get_group()
sp_world_size = parallelism_config.cp_size

# Normal training loop
for step, batch in enumerate(dl):
    batch = move_to_device(batch, model.device)
    outputs = model(**batch)

    shift_labels = batch["shift_labels"]
    loss = model.module.loss_function(
        logits=outputs.logits,
        labels=None,
        shift_labels=shift_labels,
        vocab_size=model.module.config.vocab_size,
    )

    # differentiable weighted per-shard-loss aggregation across ranks
    losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
    # handle SFT-style data where prompt tokens (label -100) are excluded from the loss
    good_tokens = (shift_labels != -100).view(-1).sum()
    good_tokens_per_rank = torch.distributed.nn.functional.all_gather(
        good_tokens, group=sp_group
    )
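    # weight each rank's mean loss by its count of loss-contributing tokens so that the
    # result below equals the mean loss over all tokens across the sequence-parallel group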
    total_loss = sum(
        losses_per_rank[rank] * good_tokens_per_rank[rank]
        for rank in range(sp_world_size)
    )
    total_good_tokens = sum(good_tokens_per_rank)
    loss = total_loss / max(total_good_tokens, 1)

    accelerator.print(f"{step}: {loss=}")

    model.backward(loss)  # backward via the DeepSpeed engine returned by accelerator.prepare

@S1ro1 marked this pull request as draft on September 16, 2025 at 14:27.

@stas00 (Contributor) commented Sep 16, 2025

Probably, let's use:

sp_world_size = dist.get_world_size(group=sp_group)

so it's self-contained, in case this code is called where the parallelism_config variable is not in scope.
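
For illustration, a minimal self-contained variant of that part of the example could then look like this (assuming torch.distributed is imported as dist):

import torch.distributed as dist

sp_group = accelerator.torch_device_mesh["cp"].get_group()
# derive the sequence-parallel world size from the group instead of parallelism_config
sp_world_size = dist.get_world_size(group=sp_group)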

@github-actions (bot) commented
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sfc-gh-sbekman (Contributor) commented
unstale

@stas00 mentioned this pull request on Oct 22, 2025.
@stas00 (Contributor) commented Oct 22, 2025

Matej did all the heavy lifting, and I'm bringing it to the finish line. Since I couldn't take over this branch, I made a new PR #3817, so this one can be closed.

@SunMarc (Member) commented Oct 22, 2025

Thanks @stas00 !

@SunMarc closed this on Oct 22, 2025.