Description
Feature request
Long-context models like Qwen/Qwen2.5-7B-Instruct-1M support up to 1M tokens. However, fine-tuning such models in transformers leads to OOM errors, so special methods like Ring Attention are needed. A similar issue arises during inference, where generating from a 1M-token prefill gives OOM.
It would be very exciting to have support for context parallelism, where in each layer the attention (QKV) computation is split across GPUs along the sequence dimension.
As far as an API goes, something like attn_implementation="ring" in from_pretrained() would likely be the simplest way to expose this feature.
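To make the proposal concrete, a minimal usage sketch is below; the `attn_implementation="ring"` value is the hypothetical addition being requested here and does not exist in transformers today:

```python
# Sketch of the *proposed* API. attn_implementation="ring" is hypothetical;
# it reuses the existing attn_implementation kwarg of from_pretrained().
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct-1M",
    attn_implementation="ring",  # proposed: shard the sequence across GPUs in each attention layer
    torch_dtype="auto",
)
```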
Links to papers and code:
- Ring Attention: https://arxiv.org/abs/2310.01889
- Reference code: https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/adapters/hf_adapter.py
- Picotron code: https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py
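For reference, here is a minimal single-process sketch (my own, not taken from the repos above) of the core idea behind ring attention: the sequence is split into chunks, each query chunk processes the key/value chunks one at a time as they would arrive around the ring, and partial results are merged with an online (log-sum-exp) softmax so the full attention matrix is never materialized:

```python
# Single-process sketch of ring attention over plain scaled dot-product attention.
# Each chunk would live on a different GPU in the real setting; here the "ring"
# is simulated by iterating over the KV chunks for every query chunk.
import torch

def ring_attention_reference(q, k, v, num_chunks=4):
    # q, k, v: (seq_len, head_dim); seq_len must be divisible by num_chunks
    scale = q.shape[-1] ** -0.5
    q_chunks, k_chunks, v_chunks = q.chunk(num_chunks), k.chunk(num_chunks), v.chunk(num_chunks)
    outputs = []
    for qi in q_chunks:
        acc = torch.zeros_like(qi)                          # unnormalized running output
        lse = torch.full((qi.shape[0], 1), float("-inf"))   # running log-sum-exp per query
        for kj, vj in zip(k_chunks, v_chunks):               # one "ring step" per KV chunk
            scores = qi @ kj.T * scale
            chunk_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
            new_lse = torch.logaddexp(lse, chunk_lse)
            # Rescale the previous accumulator, then add this chunk's contribution.
            acc = acc * torch.exp(lse - new_lse) + torch.exp(scores - new_lse) @ vj
            lse = new_lse
        outputs.append(acc)
    return torch.cat(outputs)

# Sanity check against full (non-chunked) attention.
torch.manual_seed(0)
q, k, v = (torch.randn(16, 8) for _ in range(3))
full = torch.softmax(q @ k.T * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(ring_attention_reference(q, k, v), full, atol=1e-5)
```

In the actual distributed setting, each chunk sits on its own GPU and the KV blocks are exchanged with send/recv around the ring (with flash-attention kernels for the block computation), which is what the reference implementations linked above do.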
Motivation
The main motivation is two-fold: to support fine-tuning long-context models, and to enable online RL methods like GRPO to scale better in TRL (where we generate potentially long CoTs during training).
Your contribution
Happy to review / test the feature on the TRL side.