`docs/source/distributing_training.md` (150 additions, 0 deletions)

- [Hugging Face Blog: Enabling Long-Context Training with Sequence Parallelism in Axolotl](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl)
- [Snowflake Engineering Blog: Arctic Long Sequence Training (ALST) — Scalable and Efficient Training for Multi-Million Token Sequences (Note that they use a different strategy)](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/)

## ALST/Ulysses Sequence Parallelism (DeepSpeed)

ALST (Arctic Long Sequence Training) / Ulysses is an alternative Context Parallelism implementation available through DeepSpeed. Unlike the FSDP2-based approach described above, ALST/Ulysses uses DeepSpeed's sequence parallelism infrastructure to split long sequences across GPUs.
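
Conceptually, every rank in the sequence-parallel group stores only a slice of each sequence, and attention stays exact because activations are temporarily re-partitioned across attention heads instead of tokens. The snippet below is a rough, single-process sketch of the shapes involved; it is illustrative only, not DeepSpeed's actual implementation, which moves these shards between ranks with all-to-all collectives:

```python
# Illustrative shapes only -- real ALST/Ulysses exchanges these shards between
# ranks with all-to-all collectives inside the sequence-parallel group.
import torch

cp_size = 2                                   # sequence-parallel degree
batch, seq_len, num_heads, head_dim = 1, 4096, 32, 128
hidden = torch.randn(batch, seq_len, num_heads, head_dim)

# Outside attention: each rank keeps a contiguous seq_len / cp_size slice of tokens.
token_shards = torch.chunk(hidden, cp_size, dim=1)
print(token_shards[0].shape)                  # torch.Size([1, 2048, 32, 128])

# Inside attention: data is regrouped so each rank sees the full sequence but
# only num_heads / cp_size heads, so attention itself is computed exactly.
head_shards = torch.chunk(hidden, cp_size, dim=2)
print(head_shards[0].shape)                   # torch.Size([1, 4096, 16, 128])
```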

### Key Differences from FSDP2 Context Parallelism

| Feature | FSDP2 CP | ALST/Ulysses (DeepSpeed) |
|---------|----------|-------------------------|
| Backend | PyTorch FSDP2 | DeepSpeed ZeRO |
| Attention | SDPA only | Flash Attention 2 or SDPA |
| Minimum Accelerate | 1.10+ | 1.10+ |
| Minimum DeepSpeed | N/A | 0.18.1+ |
| Sequence Divisibility | `cp_size * 2` | `cp_size` |
| ZeRO Stage | N/A | ZeRO Stage 1/2/3 |
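
A quick, illustrative way to confirm that the installed packages meet the minimum versions in the table:

```python
# Illustrative check of the minimum versions listed above.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("deepspeed")) >= Version("0.18.1"), "DeepSpeed >= 0.18.1 is required"
assert Version(version("accelerate")) >= Version("1.10.0"), "Accelerate >= 1.10 is required"
```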

### Requirements and Limitations

ALST/Ulysses has specific requirements:

1. **DeepSpeed 0.18.1 or higher** is required
2. **Accelerate 1.10 or higher** for parallelism configuration support
3. **Attention implementation** - Flash Attention 2 is recommended; SDPA works as a fallback but emits packing-related warnings
4. **Sequence length divisibility** - sequences must be divisible by `cp_size`. Use `pad_to_multiple_of` in your training config (see the sketch below).
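
The divisibility rule in point 4 is plain round-up arithmetic; the sketch below (with a hypothetical helper name) shows what `pad_to_multiple_of` guarantees:

```python
import math

def padded_length(seq_len: int, cp_size: int) -> int:
    """Smallest length >= seq_len that is divisible by cp_size."""
    return math.ceil(seq_len / cp_size) * cp_size

print(padded_length(4097, cp_size=2))  # 4098 -> two shards of 2049 tokens each
```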

### Configuration

To enable ALST/Ulysses, you need to configure both Accelerate and your training arguments:

#### Accelerate Configuration

Use the provided accelerate config file ([`alst_ulysses_4gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/alst_ulysses_4gpu.yaml)):

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  zero_stage: 3
  seq_parallel_communication_data_type: bf16
distributed_type: DEEPSPEED
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Number of GPUs
```

#### Training Configuration

```python
from accelerate.utils import DeepSpeedContextParallelConfig, ParallelismConfig
from trl import SFTConfig

# Setup ALST/Ulysses parallelism
cp_handler = DeepSpeedContextParallelConfig(
    cp_seq_length=4096,
    cp_seq_length_is_variable=True,
    cp_attn_implementation="flash_attention_2",  # or "sdpa" as fallback
)

parallelism_config = ParallelismConfig(
    cp_backend="deepspeed",
    cp_size=2,  # Number of GPUs for sequence parallelism
    cp_handler=cp_handler,
)

# Training configuration
training_args = SFTConfig(
    max_length=4096,
    packing=True,
    pad_to_multiple_of=2,  # Must equal cp_size
    parallelism_config=parallelism_config,
    gradient_checkpointing=True,
    attn_implementation="flash_attention_2",
    per_device_train_batch_size=1,
    ...
)
```

Then, launch your training script:

```bash
accelerate launch --config_file examples/accelerate_configs/alst_ulysses_4gpu.yaml train.py
```

### 2D Parallelism

The 4-GPU configuration above enables 2D parallelism by combining Data Parallelism (DP) with Context Parallelism (CP). The `dp_shard_size` is derived from the number of available GPUs and the chosen `cp_size`, and can be set explicitly:

```python
import os

num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
cp_size = 2
dp_shard_size = num_gpus // cp_size  # Derive the DP degree from total GPUs and cp_size

parallelism_config = ParallelismConfig(
    cp_backend="deepspeed",
    cp_size=cp_size,
    dp_shard_size=dp_shard_size,  # Enable 2D parallelism
    cp_handler=cp_handler,
)
```

Scaling configurations:

| GPUs | cp_size | dp_shard_size | Use Case |
|------|---------|---------------|----------|
| 4 | 2 | 2 | Balanced - longer sequences + more data |
| 4 | 4 | 1 | Pure CP for maximum sequence length |
| 8 | 2 | 4 | Large-scale training |
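
Whichever row you choose, `cp_size * dp_shard_size` must equal the number of processes you launch. A small, hypothetical sanity check (not part of TRL):

```python
import os

def check_layout(cp_size: int, dp_shard_size: int) -> None:
    """Assert that the 2D layout matches the launched world size."""
    world_size = int(os.environ.get("WORLD_SIZE", str(cp_size * dp_shard_size)))
    assert cp_size * dp_shard_size == world_size, (
        f"cp_size ({cp_size}) x dp_shard_size ({dp_shard_size}) must equal "
        f"WORLD_SIZE ({world_size})"
    )

check_layout(cp_size=2, dp_shard_size=2)  # the 4-GPU row above
```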

### Best Practices

1. **Use `pad_to_multiple_of`** to ensure sequences are divisible by `cp_size`
2. **Use Flash Attention 2** for clean output (SDPA works but shows packing warnings)
3. **Start with `cp_size=2`** before scaling to larger values
4. **Use DeepSpeed ZeRO Stage 3** for large models
5. **Combine with memory optimizations** like Liger kernels and gradient checkpointing, as sketched below
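
Taken together, practices 1, 2, and 5 might look like the sketch below in the training arguments. It mirrors the fields used in the Configuration section above; `use_liger_kernel` is the standard `transformers` training flag and requires the `liger-kernel` package to be installed:

```python
from trl import SFTConfig

# Illustrative only -- adapt values to your run; pair with the ParallelismConfig
# from the Configuration section (practice 3) and the ZeRO-3 accelerate config (practice 4).
training_args = SFTConfig(
    max_length=4096,
    packing=True,
    pad_to_multiple_of=2,                     # practice 1: multiple of cp_size
    attn_implementation="flash_attention_2",  # practice 2
    gradient_checkpointing=True,              # practice 5
    use_liger_kernel=True,                    # practice 5 (requires liger-kernel)
    per_device_train_batch_size=1,
    output_dir="output-alst",
)
```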

### Complete Example

Here's how to run ALST/Ulysses training using the built-in [`sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) script with 4 GPUs:

```bash
accelerate launch --config_file examples/accelerate_configs/alst_ulysses_4gpu.yaml \
trl/scripts/sft.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/Capybara \
--learning_rate 2e-4 \
--max_steps 100 \
--max_seq_length 4096 \
--packing \
--packing_strategy wrapped \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--attn_implementation flash_attention_2 \
--output_dir output-alst-4gpu \
--logging_steps 10 \
--report_to trackio
```

This command automatically:
- Configures 2D parallelism (CP=2, DP=2) across 4 GPUs
- Uses Flash Attention 2 for clean training
- Enables packing with automatic padding to ensure sequence divisibility
- Leverages DeepSpeed ZeRO Stage 3 for memory efficiency

### Further Reading

- [DeepSpeed Sequence Parallelism Documentation](https://www.deepspeed.ai/tutorials/ds-sequence/)


## Multi-Node Training

We're working on a guide for multi-node training. Stay tuned! 🚀
`examples/accelerate_configs/alst_ulysses_4gpu.yaml` (37 additions, 0 deletions)

# ALST/Ulysses Sequence Parallelism with 2D Parallelism (DP + CP) for 4 GPUs
#
# This configuration enables 2D parallelism:
# - Context Parallelism (cp_size=2): Sequences split across 2 GPUs
# - Data Parallelism (dp_shard_size=2): Model/optimizer sharded across 2 GPUs
# - Total: 4 GPUs (2 × 2)
#
# Set parallelism_config in your training script:
# parallelism_config = ParallelismConfig(
# cp_backend="deepspeed",
# cp_size=2,
# dp_shard_size=2, # Calculated as: num_gpus // cp_size
# cp_handler=DeepSpeedContextParallelConfig(...)
# )

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  zero_stage: 3
  seq_parallel_communication_data_type: bf16
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4 # Total number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false