Add support for context parallelism #35983

@lewtun

Description

Feature request

Long-context models like Qwen/Qwen2.5-7B-Instruct-1M support up to 1M tokens. However, fine-tuning such models in transformers leads to OOM errors, and special methods like Ring Attention are needed. A similar issue arises during inference, where generating on a 1M-token prefill gives OOM.

It would be very exciting to have support for context parallelism, where in each layer the attention (QKV) computation is split across GPUs along the sequence dimension.
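To make the idea concrete, here is a minimal single-process sketch of the online-softmax trick that ring-style context parallelism builds on: each "rank" owns one query chunk and streams the key/value chunks past it, merging partial results with a running softmax so the output matches dense attention. All names here are illustrative and not part of the transformers API.

```python
import torch

def ring_attention_reference(q, k, v, num_chunks):
    """Chunked attention with an online softmax; numerically equivalent to dense attention."""
    # q, k, v: (seq_len, head_dim); split along the sequence dimension
    q_chunks = q.chunk(num_chunks, dim=0)
    k_chunks = k.chunk(num_chunks, dim=0)
    v_chunks = v.chunk(num_chunks, dim=0)
    scale = q.shape[-1] ** -0.5
    outputs = []
    for qc in q_chunks:  # in real context parallelism, each GPU owns one query chunk
        acc = torch.zeros_like(qc)                          # running sum of prob-weighted values
        row_max = torch.full((qc.shape[0], 1), float("-inf"))
        denom = torch.zeros(qc.shape[0], 1)
        for kc, vc in zip(k_chunks, v_chunks):               # KV blocks "rotate" past this rank
            scores = (qc @ kc.T) * scale
            block_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(row_max, block_max)
            correction = torch.exp(row_max - new_max)        # rescale old stats to the new max
            probs = torch.exp(scores - new_max)
            acc = acc * correction + probs @ vc
            denom = denom * correction + probs.sum(dim=-1, keepdim=True)
            row_max = new_max
        outputs.append(acc / denom)
    return torch.cat(outputs, dim=0)

# Sanity check against ordinary dense attention.
q, k, v = (torch.randn(128, 64) for _ in range(3))
dense = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(ring_attention_reference(q, k, v, num_chunks=4), dense, atol=1e-4)
```

In an actual multi-GPU implementation the inner loop becomes point-to-point sends/receives of the KV blocks around a ring of devices, which is what keeps the full long-context KV cache sharded across GPUs.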

As far as the API goes, having something like attn_implementation="ring" in from_pretrained() would likely be the simplest way to support this feature.
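For reference, usage under the proposed API could look roughly like the snippet below. The "ring" value does not exist today and is exactly what this issue asks for; the rest mirrors how existing backends such as "flash_attention_2" are selected.

```python
from transformers import AutoModelForCausalLM

# Hypothetical: "ring" is the backend this issue proposes, not an existing option.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct-1M",
    attn_implementation="ring",  # proposed: shard attention along the sequence axis across GPUs
    torch_dtype="auto",
)
```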

Links to papers and code:

Motivation

The main motivation is two-fold: to support fine-tuning of long-context models, and to enable online RL methods like GRPO to scale better in TRL (where we generate potentially long CoTs during training).

Your contribution

Happy to review / test the feature on the TRL side.
