224 changes: 224 additions & 0 deletions tests/trainer/test_trainer_context_parallel_torch.py
@@ -0,0 +1,224 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
from pathlib import Path

from transformers import is_torch_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_accelerate,
    require_torch_multi_accelerator,
    run_first,
    slow,
)


if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        HfArgumentParser,
        Trainer,
        TrainingArguments,
    )


class TestTrainerContextParallelTorch(TestCasePlus):
"""Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig."""

@require_torch_multi_accelerator
@require_accelerate
@slow
@run_first
def test_cp_equivalence(self):
"""Test that CP produces the same losses as without CP."""

        # Shared setup
        world_size = 2
        script_path = __file__

        # Step 1: Run with CP enabled (cp_size=world_size)
        cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        cp_yes_config_path = cp_yes_output_dir / "context_parallel_config.yaml"
        cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json"

        # Write config file inline (self-contained test)
        with open(cp_yes_config_path, "w") as f:
            f.write(
                f"""distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
mixed_precision: bf16
num_processes: {world_size}
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
  parallelism_config_cp_size: {world_size}
  parallelism_config_cp_comm_strategy: alltoall
"""
            )
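        # With dp_replicate, dp_shard and tp all set to 1 and cp_size equal to the world size,
        # both processes see the same batch and only the sequence dimension is sharded, so the
        # run should match the single-process baseline below up to numerical noise.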

cmd_cp_yes = f"""
accelerate launch
--config_file {cp_yes_config_path}
{script_path}
--output_dir {cp_yes_output_dir}
--report_to none
--max_steps 10
--per_device_train_batch_size 1
--gradient_accumulation_steps 1
--logging_steps 1
--remove_unused_columns False
--seed 42
--loss_output_file {cp_yes_losses_path}
""".split()

execute_subprocess_async(cmd_cp_yes, env=self.get_env())

        # Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config)
        cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        cp_no_config_path = cp_no_output_dir / "context_parallel_config.yaml"
        cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json"

        # Write config file inline (self-contained test)
        with open(cp_no_config_path, "w") as f:
            f.write(
                """distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
mixed_precision: bf16
num_processes: 1
"""
            )
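        # The baseline uses the same seed and hyperparameters as the CP run, so its logged
        # losses can be compared step for step.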

cmd_cp_no = f"""
accelerate launch
--config_file {cp_no_config_path}
{script_path}
--output_dir {cp_no_output_dir}
--report_to none
--max_steps 10
--per_device_train_batch_size 1
--gradient_accumulation_steps 1
--logging_steps 1
--remove_unused_columns False
--seed 42
--loss_output_file {cp_no_losses_path}
""".split()

execute_subprocess_async(cmd_cp_no, env=self.get_env())

        # Compare losses - they should be very close since CP only splits the sequence computation
        with open(cp_yes_losses_path) as f:
            cp_yes_losses = json.load(f)
        with open(cp_no_losses_path) as f:
            cp_no_losses = json.load(f)

        assert len(cp_yes_losses) == len(cp_no_losses), (
            f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
        )

        # CP should produce very similar results (small numerical differences expected).
        # The differences come from:
        # - Different gradient reduction patterns in distributed training
        # - Accumulated BF16 mixed-precision rounding differences
        # - Sequence splitting and gathering in CP mode
        cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
        cp_no_losses_tensor = torch.tensor(cp_no_losses)

        # Use torch.testing.assert_close with rtol=2% and atol=0.02
        # Testing shows actual differences are typically <1.5%
        torch.testing.assert_close(
            cp_yes_losses_tensor,
            cp_no_losses_tensor,
            rtol=2e-2,  # 2% relative tolerance
            atol=2e-2,  # 0.02 absolute tolerance
            msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
        )


if __name__ == "__main__":
    # Parse custom arguments (not TrainingArguments parameters)
    loss_output_file = None

    if "--loss_output_file" in sys.argv:
        idx = sys.argv.index("--loss_output_file")
        loss_output_file = sys.argv[idx + 1]
        sys.argv.pop(idx)
        sys.argv.pop(idx)
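    # Both the flag and its value are removed from sys.argv so that HfArgumentParser below
    # only sees arguments that TrainingArguments defines.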

    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]

    # Use SmolLM (small Llama-based model that works with CP)
    model_name = "HuggingFaceTB/SmolLM-135M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",  # CP requires SDPA
    )
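    # SmolLM uses the Llama architecture, which is what the baseline config's
    # fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer refers to.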

    # Create a simple dataset: just tokenize some text
    texts = [
        "The quick brown fox jumps over the lazy dog. " * 10,
        "Hello world, this is a test sentence for training. " * 10,
    ] * 4  # 8 samples total

    def tokenize_function(examples):
        return tokenizer(examples, max_length=128, truncation=True, padding="max_length")

    train_dataset = [tokenize_function(text) for text in texts]

    # Use the standard DataCollatorForLanguageModeling for causal LM
    # pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2)
    # Trainer will automatically generate position_ids and shift_labels as needed
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal language modeling
        pad_to_multiple_of=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    # Train for a few steps
    trainer.train()

    # Verify training completed
    assert trainer.state.global_step > 0, "Training should have completed at least one step"

    # Save losses to file if requested (for equivalence testing)
    if loss_output_file and training_args.process_index == 0:
        losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
        with open(loss_output_file, "w") as f:
            json.dump(losses, f)