From 63efa183965f0fa70d858edf8002fe22fbb16205 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 25 Oct 2025 11:48:07 +0000
Subject: [PATCH 01/10] initial

---
 tests/trainer/context_parallel_config.yaml   |  30 ++
 .../trainer/test_trainer_context_parallel.py | 334 ++++++++++++++++++
 2 files changed, 364 insertions(+)
 create mode 100644 tests/trainer/context_parallel_config.yaml
 create mode 100644 tests/trainer/test_trainer_context_parallel.py

diff --git a/tests/trainer/context_parallel_config.yaml b/tests/trainer/context_parallel_config.yaml
new file mode 100644
index 000000000000..f0b4da9c4ace
--- /dev/null
+++ b/tests/trainer/context_parallel_config.yaml
@@ -0,0 +1,30 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+fsdp_config:
+  fsdp_activation_checkpointing: false
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: false
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: false
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+parallelism_config:
+  parallelism_config_dp_replicate_size: 1
+  parallelism_config_dp_shard_size: 1
+  parallelism_config_tp_size: 1
+  parallelism_config_cp_size: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/tests/trainer/test_trainer_context_parallel.py b/tests/trainer/test_trainer_context_parallel.py
new file mode 100644
index 000000000000..37201aa2e941
--- /dev/null
+++ b/tests/trainer/test_trainer_context_parallel.py
@@ -0,0 +1,334 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from transformers import is_torch_available +from transformers.testing_utils import ( + TestCasePlus, + backend_device_count, + execute_subprocess_async, + get_torch_dist_unique_port, + require_accelerate, + require_torch_multi_accelerator, + run_first, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + from datasets import load_dataset + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + HfArgumentParser, + Trainer, + TrainingArguments, + ) + + class ContextParallelDataCollator: + """Data collator for context parallelism - does not create attention masks.""" + + def __init__(self, tokenizer, pad_to_multiple_of=None): + self.tokenizer = tokenizer + self.pad_to_multiple_of = pad_to_multiple_of + + def __call__(self, features): + batch = {} + batch["input_ids"] = torch.stack([torch.tensor(f["input_ids"]) for f in features]) + # For CP, we need shift_labels pre-computed + batch["shift_labels"] = torch.stack([torch.tensor(f["shift_labels"]) for f in features]) + batch["labels"] = batch["shift_labels"].clone() + + # Pad to multiple if specified (required for CP: cp_size * 2) + if self.pad_to_multiple_of is not None: + seq_len = batch["input_ids"].shape[1] + remainder = seq_len % self.pad_to_multiple_of + if remainder != 0: + padding_len = self.pad_to_multiple_of - remainder + batch["input_ids"] = torch.nn.functional.pad( + batch["input_ids"], (0, padding_len), value=self.tokenizer.pad_token_id + ) + batch["shift_labels"] = torch.nn.functional.pad( + batch["shift_labels"], (0, padding_len), value=-100 + ) + batch["labels"] = batch["shift_labels"].clone() + + # Add position_ids (accelerate example includes this) + seq_len = batch["input_ids"].shape[1] + batch["position_ids"] = torch.arange(seq_len).unsqueeze(0).expand(batch["input_ids"].shape[0], -1) + + # Don't create attention_mask - it causes issues with CP + return batch + + +class TestTrainerContextParallel(TestCasePlus): + """Test Trainer with context parallelism enabled via accelerate's ParallelismConfig""" + + @require_torch_multi_accelerator + @require_accelerate + @slow + @run_first + def test_trainer_context_parallel_basic(self): + """Test basic training with context parallelism enabled.""" + output_dir = self.get_auto_remove_tmp_dir() + config_path = f"{self.test_file_dir}/context_parallel_config.yaml" + + cmd = [ + "accelerate", + "launch", + "--config_file", + config_path, + f"{self.test_file_dir}/test_trainer_context_parallel.py", + "--output_dir", + output_dir, + "--report_to", + "none", + "--max_steps", + "5", + "--per_device_train_batch_size", + "1", + "--pad_to_multiple_of", + "4", + "--logging_steps", + "1", + "--remove_unused_columns", + "False", + ] + + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + @require_torch_multi_accelerator + @require_accelerate + @slow + @run_first + def test_trainer_context_parallel_requires_sdpa(self): + """Test that context parallelism requires SDPA attention implementation.""" + output_dir = self.get_auto_remove_tmp_dir() + config_path = f"{self.test_file_dir}/context_parallel_config.yaml" + + cmd = [ + "accelerate", + "launch", + "--config_file", + config_path, + f"{self.test_file_dir}/test_trainer_context_parallel.py", + "--output_dir", + output_dir, + "--report_to", + "none", + "--max_steps", + "5", + "--remove_unused_columns", + "False", + "--test_mode", + "test_non_sdpa", + ] + + # This should fail because we're 
using eager attention instead of SDPA + with self.assertRaises(Exception): + execute_subprocess_async(cmd, env=self.get_env()) + + @require_torch_multi_accelerator + @require_accelerate + @slow + @run_first + def test_trainer_context_parallel_causal_mask_validation(self): + """Test that context parallelism validates causal attention masks.""" + output_dir = self.get_auto_remove_tmp_dir() + config_path = f"{self.test_file_dir}/context_parallel_config.yaml" + + cmd = [ + "accelerate", + "launch", + "--config_file", + config_path, + f"{self.test_file_dir}/test_trainer_context_parallel.py", + "--output_dir", + output_dir, + "--report_to", + "none", + "--max_steps", + "5", + "--remove_unused_columns", + "False", + "--test_mode", + "test_non_causal_mask", + ] + + # This should fail because we're using a non-causal attention mask + with self.assertRaises(Exception): + execute_subprocess_async(cmd, env=self.get_env()) + + @require_torch_multi_accelerator + @require_accelerate + @slow + @run_first + def test_trainer_context_parallel_auto_generation(self): + """Test that context parallelism auto-generates position_ids and shift_labels.""" + output_dir = self.get_auto_remove_tmp_dir() + config_path = f"{self.test_file_dir}/context_parallel_config.yaml" + + cmd = [ + "accelerate", + "launch", + "--config_file", + config_path, + f"{self.test_file_dir}/test_trainer_context_parallel.py", + "--output_dir", + output_dir, + "--report_to", + "none", + "--max_steps", + "5", + "--remove_unused_columns", + "False", + "--test_mode", + "test_auto_generation", + "--pad_to_multiple_of", + "4", + ] + + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success + + +if __name__ == "__main__": + # This script is meant to be run under torch.distributed with accelerate launch + # with context parallelism enabled via ParallelismConfig + + import argparse + + # Parse custom arguments along with training arguments + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--test_mode", type=str, default="default", help="Test mode to run") + arg_parser.add_argument("--pad_to_multiple_of", type=int, default=None, help="Pad sequences to multiple of this value") + custom_args, remaining_args = arg_parser.parse_known_args() + + parser = HfArgumentParser((TrainingArguments,)) + training_args = parser.parse_args_into_dataclasses(remaining_args)[0] + + # Use SmolLM model (small Llama-based model, works with CP unlike GPT2) + model_name = "HuggingFaceTB/SmolLM-135M" + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Create a simple causal LM dataset + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]") + + def tokenize_fn(examples): + tokenized = tokenizer( + examples["text"], + max_length=128, + truncation=True, + padding="max_length", + return_attention_mask=False, # Don't create attention mask for CP + ) + # For context parallelism, we need to pre-compute shift_labels + # shift_labels[i] = input_ids[i+1] for causal LM + shift_labels = [] + for input_ids in tokenized["input_ids"]: + # Create shift_labels by taking input_ids[1:] and appending -100 + labels = input_ids[1:] + [tokenizer.pad_token_id] + # Replace pad tokens with -100 + labels = [label if label != tokenizer.pad_token_id else -100 for label in labels] + shift_labels.append(labels) + tokenized["shift_labels"] = shift_labels + return tokenized + + # Don't remove columns yet, keep text for debugging + tokenized_dataset = 
dataset.map( + tokenize_fn, + batched=True, + remove_columns=dataset.column_names, + ) + + # Verify shift_labels exists + print(f"Dataset columns: {tokenized_dataset.column_names}") + print(f"First example keys: {list(tokenized_dataset[0].keys())}") + + # Select attention implementation based on test mode + if custom_args.test_mode == "test_non_sdpa": + # This should fail with ValueError because CP requires SDPA + attn_implementation = "eager" + else: + # Default: use SDPA (required for context parallelism) + attn_implementation = "sdpa" + + model = AutoModelForCausalLM.from_pretrained( + model_name, + attn_implementation=attn_implementation, + use_cache=False, # Disable KV cache for CP (accelerate example does this) + ) + + # Handle special test modes that need custom data collators + if custom_args.test_mode == "test_non_causal_mask": + # Create a custom data collator that produces non-causal masks + class NonCausalDataCollator: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def __call__(self, features): + batch = {} + batch["input_ids"] = torch.stack([torch.tensor(f["input_ids"]) for f in features]) + batch["shift_labels"] = torch.stack([torch.tensor(f["shift_labels"]) for f in features]) + batch["labels"] = batch["shift_labels"].clone() + + # Create a bidirectional (non-causal) attention mask + # This should cause context parallelism to fail + batch_size, seq_len = batch["input_ids"].shape + batch["attention_mask"] = torch.ones((batch_size, seq_len, seq_len), dtype=torch.long) + + return batch + + data_collator = NonCausalDataCollator(tokenizer) + elif custom_args.test_mode == "test_auto_generation": + # Use DataCollatorForLanguageModeling which will test auto-generation + # This won't have shift_labels pre-computed + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=custom_args.pad_to_multiple_of) + else: + # Default: use ContextParallelDataCollator (no attention masks) + # pad_to_multiple_of should be cp_size * 2 (e.g., 4 for cp_size=2) + data_collator = ContextParallelDataCollator(tokenizer, pad_to_multiple_of=custom_args.pad_to_multiple_of) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset, + data_collator=data_collator, + ) + + # Verify context parallelism is enabled (if parallelism_config is available) + if trainer.accelerator.parallelism_config is not None: + if not trainer.accelerator.parallelism_config.cp_enabled: + print(f"Warning: Context parallelism not enabled. 
cp_size={trainer.accelerator.parallelism_config.cp_size}") + print(f"ParallelismConfig: {trainer.accelerator.parallelism_config}") + else: + print("Warning: No parallelism_config found on accelerator") + + # Train for a few steps + # This will raise ValueError if using non-SDPA attention or non-causal masks with CP + trainer.train() + + # Verify training completed successfully + assert trainer.state.global_step > 0, "Training should have completed at least one step" + + # For auto_generation test, verify that position_ids and shift_labels were auto-generated + if custom_args.test_mode == "test_auto_generation": + # The training should have succeeded with auto-generated position_ids and shift_labels + # (warnings should have been logged but training should complete) + print("Auto-generation test passed: position_ids and shift_labels were auto-generated successfully") From 1ff587b504359adf3a253c1f11b145c402a4ed9f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 25 Oct 2025 12:54:58 +0000 Subject: [PATCH 02/10] simplify tests --- .../trainer/test_trainer_context_parallel.py | 334 +++++++----------- 1 file changed, 137 insertions(+), 197 deletions(-) diff --git a/tests/trainer/test_trainer_context_parallel.py b/tests/trainer/test_trainer_context_parallel.py index 37201aa2e941..1e4388a9c539 100644 --- a/tests/trainer/test_trainer_context_parallel.py +++ b/tests/trainer/test_trainer_context_parallel.py @@ -12,76 +12,109 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + from transformers import is_torch_available from transformers.testing_utils import ( TestCasePlus, - backend_device_count, execute_subprocess_async, - get_torch_dist_unique_port, require_accelerate, require_torch_multi_accelerator, run_first, slow, - torch_device, ) if is_torch_available(): import torch - from datasets import load_dataset + from transformers import ( AutoModelForCausalLM, AutoTokenizer, - DataCollatorForLanguageModeling, HfArgumentParser, + PreTrainedTokenizerBase, Trainer, TrainingArguments, ) - class ContextParallelDataCollator: - """Data collator for context parallelism - does not create attention masks.""" + class CPDataset(torch.utils.data.Dataset): + """Simple dataset for context parallelism testing.""" - def __init__(self, tokenizer, pad_to_multiple_of=None): + def __init__(self, tokenizer: PreTrainedTokenizerBase, seq_length: int = 128, num_samples: int = 8): + self.tokenizer = tokenizer + self.seq_length = seq_length + # Create simple text samples + texts = [ + "The quick brown fox jumps over the lazy dog. " * 10, + "Hello world, this is a test sentence for training. 
" * 10, + ] * (num_samples // 2) + + self.data = [] + for text in texts: + encoded = tokenizer( + text, + max_length=seq_length, + truncation=True, + padding="max_length", + return_attention_mask=False, # CP doesn't use attention_mask + ) + input_ids = encoded["input_ids"] + # Pre-compute shift_labels for causal LM + shift_labels = input_ids[1:] + [tokenizer.pad_token_id] + shift_labels = [lbl if lbl != tokenizer.pad_token_id else -100 for lbl in shift_labels] + + self.data.append({"input_ids": input_ids, "shift_labels": shift_labels}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + class CPDataCollator: + """Data collator for context parallelism - handles special CP requirements.""" + + def __init__(self, tokenizer: PreTrainedTokenizerBase, pad_to_multiple_of: int | None = None): self.tokenizer = tokenizer self.pad_to_multiple_of = pad_to_multiple_of def __call__(self, features): - batch = {} - batch["input_ids"] = torch.stack([torch.tensor(f["input_ids"]) for f in features]) - # For CP, we need shift_labels pre-computed - batch["shift_labels"] = torch.stack([torch.tensor(f["shift_labels"]) for f in features]) - batch["labels"] = batch["shift_labels"].clone() - - # Pad to multiple if specified (required for CP: cp_size * 2) - if self.pad_to_multiple_of is not None: - seq_len = batch["input_ids"].shape[1] - remainder = seq_len % self.pad_to_multiple_of - if remainder != 0: - padding_len = self.pad_to_multiple_of - remainder - batch["input_ids"] = torch.nn.functional.pad( - batch["input_ids"], (0, padding_len), value=self.tokenizer.pad_token_id - ) - batch["shift_labels"] = torch.nn.functional.pad( - batch["shift_labels"], (0, padding_len), value=-100 - ) - batch["labels"] = batch["shift_labels"].clone() - - # Add position_ids (accelerate example includes this) + # Stack input_ids and shift_labels - use clone() to avoid memory sharing issues + input_ids = torch.stack([torch.tensor(f["input_ids"], dtype=torch.long) for f in features]) + shift_labels = torch.stack([torch.tensor(f["shift_labels"], dtype=torch.long) for f in features]) + + # Pad to multiple if needed (required for CP: sequences must be divisible by cp_size * 2) + if self.pad_to_multiple_of: + seq_len = input_ids.shape[1] + if seq_len % self.pad_to_multiple_of != 0: + padding_len = self.pad_to_multiple_of - (seq_len % self.pad_to_multiple_of) + input_ids = torch.nn.functional.pad(input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) + shift_labels = torch.nn.functional.pad(shift_labels, (0, padding_len), value=-100) + + # Create batch dictionary + batch = { + "input_ids": input_ids.clone(), # Clone to avoid memory sharing with pin_memory + "shift_labels": shift_labels.clone(), # CP trainer expects this key + "labels": shift_labels.clone(), # Clone to avoid memory sharing + } + + # Add position_ids (CP needs explicit position IDs) + # Use repeat instead of expand to avoid view/memory sharing issues seq_len = batch["input_ids"].shape[1] - batch["position_ids"] = torch.arange(seq_len).unsqueeze(0).expand(batch["input_ids"].shape[0], -1) + batch_size = batch["input_ids"].shape[0] + batch["position_ids"] = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) - # Don't create attention_mask - it causes issues with CP return batch class TestTrainerContextParallel(TestCasePlus): - """Test Trainer with context parallelism enabled via accelerate's ParallelismConfig""" + """Test Trainer with context parallelism enabled via accelerate's 
ParallelismConfig.""" @require_torch_multi_accelerator @require_accelerate @slow @run_first - def test_trainer_context_parallel_basic(self): + def test_trainer(self): """Test basic training with context parallelism enabled.""" output_dir = self.get_auto_remove_tmp_dir() config_path = f"{self.test_file_dir}/context_parallel_config.yaml" @@ -100,8 +133,6 @@ def test_trainer_context_parallel_basic(self): "5", "--per_device_train_batch_size", "1", - "--pad_to_multiple_of", - "4", "--logging_steps", "1", "--remove_unused_columns", @@ -109,226 +140,135 @@ def test_trainer_context_parallel_basic(self): ] execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success - any errors would have caused an error in the sub-call @require_torch_multi_accelerator @require_accelerate @slow - @run_first - def test_trainer_context_parallel_requires_sdpa(self): - """Test that context parallelism requires SDPA attention implementation.""" + def test_cp_reproducibility(self): + """Test that CP produces reproducible results with the same seed.""" + import os + output_dir = self.get_auto_remove_tmp_dir() - config_path = f"{self.test_file_dir}/context_parallel_config.yaml" + config_path_cp = f"{self.test_file_dir}/context_parallel_config.yaml" - cmd = [ + # Run 1: Train with CP and seed=42 + loss_file_1 = os.path.join(output_dir, "losses_run1.json") + cmd_1 = [ "accelerate", "launch", "--config_file", - config_path, + config_path_cp, f"{self.test_file_dir}/test_trainer_context_parallel.py", "--output_dir", - output_dir, + os.path.join(output_dir, "run1"), "--report_to", "none", "--max_steps", - "5", + "10", + "--per_device_train_batch_size", + "1", + "--logging_steps", + "1", "--remove_unused_columns", "False", - "--test_mode", - "test_non_sdpa", + "--seed", + "42", + "--loss_output_file", + loss_file_1, ] + execute_subprocess_async(cmd_1, env=self.get_env()) - # This should fail because we're using eager attention instead of SDPA - with self.assertRaises(Exception): - execute_subprocess_async(cmd, env=self.get_env()) - - @require_torch_multi_accelerator - @require_accelerate - @slow - @run_first - def test_trainer_context_parallel_causal_mask_validation(self): - """Test that context parallelism validates causal attention masks.""" - output_dir = self.get_auto_remove_tmp_dir() - config_path = f"{self.test_file_dir}/context_parallel_config.yaml" - - cmd = [ + # Run 2: Train with CP and same seed=42 + loss_file_2 = os.path.join(output_dir, "losses_run2.json") + cmd_2 = [ "accelerate", "launch", "--config_file", - config_path, + config_path_cp, f"{self.test_file_dir}/test_trainer_context_parallel.py", "--output_dir", - output_dir, + os.path.join(output_dir, "run2"), "--report_to", "none", "--max_steps", - "5", + "10", + "--per_device_train_batch_size", + "1", + "--logging_steps", + "1", "--remove_unused_columns", "False", - "--test_mode", - "test_non_causal_mask", + "--seed", + "42", + "--loss_output_file", + loss_file_2, ] + execute_subprocess_async(cmd_2, env=self.get_env()) - # This should fail because we're using a non-causal attention mask - with self.assertRaises(Exception): - execute_subprocess_async(cmd, env=self.get_env()) + # Compare losses - should be identical with same seed + with open(loss_file_1) as f: + losses_1 = json.load(f) + with open(loss_file_2) as f: + losses_2 = json.load(f) - @require_torch_multi_accelerator - @require_accelerate - @slow - @run_first - def test_trainer_context_parallel_auto_generation(self): - """Test that context parallelism auto-generates 
position_ids and shift_labels.""" - output_dir = self.get_auto_remove_tmp_dir() - config_path = f"{self.test_file_dir}/context_parallel_config.yaml" - - cmd = [ - "accelerate", - "launch", - "--config_file", - config_path, - f"{self.test_file_dir}/test_trainer_context_parallel.py", - "--output_dir", - output_dir, - "--report_to", - "none", - "--max_steps", - "5", - "--remove_unused_columns", - "False", - "--test_mode", - "test_auto_generation", - "--pad_to_multiple_of", - "4", - ] + assert len(losses_1) == len(losses_2), ( + f"Different number of losses: Run1 has {len(losses_1)}, Run2 has {len(losses_2)}" + ) - execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success + # Losses should be identical (or very close) with same seed + for i, (loss_1, loss_2) in enumerate(zip(losses_1, losses_2)): + assert abs(loss_1 - loss_2) < 1e-6, ( + f"Loss mismatch at step {i + 1}: Run1={loss_1:.6f}, Run2={loss_2:.6f}, diff={abs(loss_1 - loss_2):.6e}" + ) if __name__ == "__main__": - # This script is meant to be run under torch.distributed with accelerate launch - # with context parallelism enabled via ParallelismConfig + import sys - import argparse + # Parse custom arguments (not TrainingArguments parameters) + loss_output_file = None - # Parse custom arguments along with training arguments - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("--test_mode", type=str, default="default", help="Test mode to run") - arg_parser.add_argument("--pad_to_multiple_of", type=int, default=None, help="Pad sequences to multiple of this value") - custom_args, remaining_args = arg_parser.parse_known_args() + if "--loss_output_file" in sys.argv: + idx = sys.argv.index("--loss_output_file") + loss_output_file = sys.argv[idx + 1] + sys.argv.pop(idx) + sys.argv.pop(idx) parser = HfArgumentParser((TrainingArguments,)) - training_args = parser.parse_args_into_dataclasses(remaining_args)[0] + training_args = parser.parse_args_into_dataclasses()[0] - # Use SmolLM model (small Llama-based model, works with CP unlike GPT2) + # Use SmolLM (small Llama-based model that works with CP) model_name = "HuggingFaceTB/SmolLM-135M" tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - # Create a simple causal LM dataset - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]") - - def tokenize_fn(examples): - tokenized = tokenizer( - examples["text"], - max_length=128, - truncation=True, - padding="max_length", - return_attention_mask=False, # Don't create attention mask for CP - ) - # For context parallelism, we need to pre-compute shift_labels - # shift_labels[i] = input_ids[i+1] for causal LM - shift_labels = [] - for input_ids in tokenized["input_ids"]: - # Create shift_labels by taking input_ids[1:] and appending -100 - labels = input_ids[1:] + [tokenizer.pad_token_id] - # Replace pad tokens with -100 - labels = [label if label != tokenizer.pad_token_id else -100 for label in labels] - shift_labels.append(labels) - tokenized["shift_labels"] = shift_labels - return tokenized - - # Don't remove columns yet, keep text for debugging - tokenized_dataset = dataset.map( - tokenize_fn, - batched=True, - remove_columns=dataset.column_names, - ) - - # Verify shift_labels exists - print(f"Dataset columns: {tokenized_dataset.column_names}") - print(f"First example keys: {list(tokenized_dataset[0].keys())}") - - # Select attention implementation based on test mode - if custom_args.test_mode == 
"test_non_sdpa": - # This should fail with ValueError because CP requires SDPA - attn_implementation = "eager" - else: - # Default: use SDPA (required for context parallelism) - attn_implementation = "sdpa" - model = AutoModelForCausalLM.from_pretrained( model_name, - attn_implementation=attn_implementation, - use_cache=False, # Disable KV cache for CP (accelerate example does this) + attn_implementation="sdpa", # CP requires SDPA + use_cache=False, # Disable KV cache for CP ) - # Handle special test modes that need custom data collators - if custom_args.test_mode == "test_non_causal_mask": - # Create a custom data collator that produces non-causal masks - class NonCausalDataCollator: - def __init__(self, tokenizer): - self.tokenizer = tokenizer - - def __call__(self, features): - batch = {} - batch["input_ids"] = torch.stack([torch.tensor(f["input_ids"]) for f in features]) - batch["shift_labels"] = torch.stack([torch.tensor(f["shift_labels"]) for f in features]) - batch["labels"] = batch["shift_labels"].clone() - - # Create a bidirectional (non-causal) attention mask - # This should cause context parallelism to fail - batch_size, seq_len = batch["input_ids"].shape - batch["attention_mask"] = torch.ones((batch_size, seq_len, seq_len), dtype=torch.long) - - return batch - - data_collator = NonCausalDataCollator(tokenizer) - elif custom_args.test_mode == "test_auto_generation": - # Use DataCollatorForLanguageModeling which will test auto-generation - # This won't have shift_labels pre-computed - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=custom_args.pad_to_multiple_of) - else: - # Default: use ContextParallelDataCollator (no attention masks) - # pad_to_multiple_of should be cp_size * 2 (e.g., 4 for cp_size=2) - data_collator = ContextParallelDataCollator(tokenizer, pad_to_multiple_of=custom_args.pad_to_multiple_of) + # Create dataset and data collator + train_dataset = CPDataset(tokenizer, seq_length=128, num_samples=8) + # pad_to_multiple_of=4 for cp_size=2 (must be divisible by cp_size * 2) + data_collator = CPDataCollator(tokenizer, pad_to_multiple_of=4) trainer = Trainer( model=model, args=training_args, - train_dataset=tokenized_dataset, + train_dataset=train_dataset, data_collator=data_collator, ) - # Verify context parallelism is enabled (if parallelism_config is available) - if trainer.accelerator.parallelism_config is not None: - if not trainer.accelerator.parallelism_config.cp_enabled: - print(f"Warning: Context parallelism not enabled. 
cp_size={trainer.accelerator.parallelism_config.cp_size}") - print(f"ParallelismConfig: {trainer.accelerator.parallelism_config}") - else: - print("Warning: No parallelism_config found on accelerator") - # Train for a few steps - # This will raise ValueError if using non-SDPA attention or non-causal masks with CP trainer.train() - # Verify training completed successfully + # Verify training completed assert trainer.state.global_step > 0, "Training should have completed at least one step" - # For auto_generation test, verify that position_ids and shift_labels were auto-generated - if custom_args.test_mode == "test_auto_generation": - # The training should have succeeded with auto-generated position_ids and shift_labels - # (warnings should have been logged but training should complete) - print("Auto-generation test passed: position_ids and shift_labels were auto-generated successfully") + # Save losses to file if requested (for reproducibility testing) + if loss_output_file and training_args.process_index == 0: + losses = [log["loss"] for log in trainer.state.log_history if "loss" in log] + with open(loss_output_file, "w") as f: + json.dump(losses, f) From 0bcf34edad6c23e25be2efdf84a05f16882f4b82 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 25 Oct 2025 14:02:48 +0000 Subject: [PATCH 03/10] add test_cp_equivalence --- .../context_parallel_no_cp_config.yaml | 25 ++++++ .../trainer/test_trainer_context_parallel.py | 85 ++++++++++++------- 2 files changed, 79 insertions(+), 31 deletions(-) create mode 100644 tests/trainer/context_parallel_no_cp_config.yaml diff --git a/tests/trainer/context_parallel_no_cp_config.yaml b/tests/trainer/context_parallel_no_cp_config.yaml new file mode 100644 index 000000000000..e520dd5e27a8 --- /dev/null +++ b/tests/trainer/context_parallel_no_cp_config.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: false + fsdp_offload_params: false + fsdp_reshard_after_forward: false + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/tests/trainer/test_trainer_context_parallel.py b/tests/trainer/test_trainer_context_parallel.py index 1e4388a9c539..811d3f82781a 100644 --- a/tests/trainer/test_trainer_context_parallel.py +++ b/tests/trainer/test_trainer_context_parallel.py @@ -13,6 +13,7 @@ # limitations under the License. 
import json +import sys from transformers import is_torch_available from transformers.testing_utils import ( @@ -91,11 +92,16 @@ def __call__(self, features): input_ids = torch.nn.functional.pad(input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) shift_labels = torch.nn.functional.pad(shift_labels, (0, padding_len), value=-100) + # For causal LM, unshifted labels are the input_ids themselves + # Replace pad tokens with -100 in labels + unshifted_labels = input_ids.clone() + unshifted_labels[unshifted_labels == self.tokenizer.pad_token_id] = -100 + # Create batch dictionary batch = { "input_ids": input_ids.clone(), # Clone to avoid memory sharing with pin_memory - "shift_labels": shift_labels.clone(), # CP trainer expects this key - "labels": shift_labels.clone(), # Clone to avoid memory sharing + "shift_labels": shift_labels.clone(), # CP trainer expects pre-shifted labels + "labels": unshifted_labels.clone(), # Non-CP mode expects unshifted labels } # Add position_ids (CP needs explicit position IDs) @@ -144,29 +150,32 @@ def test_trainer(self): @require_torch_multi_accelerator @require_accelerate @slow - def test_cp_reproducibility(self): - """Test that CP produces reproducible results with the same seed.""" + def test_cp_equivalence(self): + """Test that CP produces the same losses as without CP.""" import os output_dir = self.get_auto_remove_tmp_dir() + + # Run with CP enabled (cp_size=2) config_path_cp = f"{self.test_file_dir}/context_parallel_config.yaml" + loss_file_cp = os.path.join(output_dir, "losses_cp.json") - # Run 1: Train with CP and seed=42 - loss_file_1 = os.path.join(output_dir, "losses_run1.json") - cmd_1 = [ + cmd_cp = [ "accelerate", "launch", "--config_file", config_path_cp, f"{self.test_file_dir}/test_trainer_context_parallel.py", "--output_dir", - os.path.join(output_dir, "run1"), + os.path.join(output_dir, "with_cp"), "--report_to", "none", "--max_steps", "10", "--per_device_train_batch_size", "1", + "--gradient_accumulation_steps", + "1", "--logging_steps", "1", "--remove_unused_columns", @@ -174,26 +183,30 @@ def test_cp_reproducibility(self): "--seed", "42", "--loss_output_file", - loss_file_1, + loss_file_cp, ] - execute_subprocess_async(cmd_1, env=self.get_env()) + execute_subprocess_async(cmd_cp, env=self.get_env()) + + # Run without CP (FSDP with num_processes=1, no parallelism_config) + config_path_no_cp = f"{self.test_file_dir}/context_parallel_no_cp_config.yaml" + loss_file_no_cp = os.path.join(output_dir, "losses_no_cp.json") - # Run 2: Train with CP and same seed=42 - loss_file_2 = os.path.join(output_dir, "losses_run2.json") - cmd_2 = [ + cmd_no_cp = [ "accelerate", "launch", "--config_file", - config_path_cp, + config_path_no_cp, f"{self.test_file_dir}/test_trainer_context_parallel.py", "--output_dir", - os.path.join(output_dir, "run2"), + os.path.join(output_dir, "without_cp"), "--report_to", "none", "--max_steps", "10", "--per_device_train_batch_size", "1", + "--gradient_accumulation_steps", + "1", "--logging_steps", "1", "--remove_unused_columns", @@ -201,30 +214,40 @@ def test_cp_reproducibility(self): "--seed", "42", "--loss_output_file", - loss_file_2, + loss_file_no_cp, ] - execute_subprocess_async(cmd_2, env=self.get_env()) + execute_subprocess_async(cmd_no_cp, env=self.get_env()) - # Compare losses - should be identical with same seed - with open(loss_file_1) as f: - losses_1 = json.load(f) - with open(loss_file_2) as f: - losses_2 = json.load(f) + # Compare losses - should be very close since CP just splits sequence computation + 
with open(loss_file_cp) as f: + losses_cp = json.load(f) + with open(loss_file_no_cp) as f: + losses_no_cp = json.load(f) - assert len(losses_1) == len(losses_2), ( - f"Different number of losses: Run1 has {len(losses_1)}, Run2 has {len(losses_2)}" + assert len(losses_cp) == len(losses_no_cp), ( + f"Different number of losses: CP has {len(losses_cp)}, no-CP has {len(losses_no_cp)}" ) - # Losses should be identical (or very close) with same seed - for i, (loss_1, loss_2) in enumerate(zip(losses_1, losses_2)): - assert abs(loss_1 - loss_2) < 1e-6, ( - f"Loss mismatch at step {i + 1}: Run1={loss_1:.6f}, Run2={loss_2:.6f}, diff={abs(loss_1 - loss_2):.6e}" - ) + # CP should produce very similar results (small numerical differences expected) + # The differences come from: + # - Different gradient reduction patterns in distributed training + # - BF16 mixed precision accumulated differences + # - Sequence splitting and gathering in CP mode + losses_cp_tensor = torch.tensor(losses_cp) + losses_no_cp_tensor = torch.tensor(losses_no_cp) + + # Use torch.testing.assert_close with rtol=2% and atol=0.02 + # Testing shows actual differences are typically <1.5% + torch.testing.assert_close( + losses_cp_tensor, + losses_no_cp_tensor, + rtol=2e-2, # 2% relative tolerance + atol=2e-2, # 0.02 absolute tolerance + msg=f"CP losses {losses_cp} do not match non-CP losses {losses_no_cp}", + ) if __name__ == "__main__": - import sys - # Parse custom arguments (not TrainingArguments parameters) loss_output_file = None From 6d42d9a599d486ef4a670607b87881fc4871c186 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 25 Oct 2025 14:09:26 +0000 Subject: [PATCH 04/10] removed fsdp_transformer_layer_cls_to_wrap --- tests/trainer/context_parallel_config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/context_parallel_config.yaml b/tests/trainer/context_parallel_config.yaml index f0b4da9c4ace..331e6c6b77e9 100644 --- a/tests/trainer/context_parallel_config.yaml +++ b/tests/trainer/context_parallel_config.yaml @@ -10,7 +10,6 @@ fsdp_config: fsdp_offload_params: false fsdp_reshard_after_forward: false fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_version: 2 machine_rank: 0 main_training_function: main From e18436c0d2798a60d9d3b5f6dbeb43acefe1610a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 25 Oct 2025 14:20:22 +0000 Subject: [PATCH 05/10] use DataCollatorForLanguageModeling --- .../trainer/test_trainer_context_parallel.py | 101 ++++-------------- 1 file changed, 21 insertions(+), 80 deletions(-) diff --git a/tests/trainer/test_trainer_context_parallel.py b/tests/trainer/test_trainer_context_parallel.py index 811d3f82781a..8bc1d05c8532 100644 --- a/tests/trainer/test_trainer_context_parallel.py +++ b/tests/trainer/test_trainer_context_parallel.py @@ -32,86 +32,12 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, + DataCollatorForLanguageModeling, HfArgumentParser, - PreTrainedTokenizerBase, Trainer, TrainingArguments, ) - class CPDataset(torch.utils.data.Dataset): - """Simple dataset for context parallelism testing.""" - - def __init__(self, tokenizer: PreTrainedTokenizerBase, seq_length: int = 128, num_samples: int = 8): - self.tokenizer = tokenizer - self.seq_length = seq_length - # Create simple text samples - texts = [ - "The quick brown fox jumps over the lazy dog. " * 10, - "Hello world, this is a test sentence for training. 
" * 10, - ] * (num_samples // 2) - - self.data = [] - for text in texts: - encoded = tokenizer( - text, - max_length=seq_length, - truncation=True, - padding="max_length", - return_attention_mask=False, # CP doesn't use attention_mask - ) - input_ids = encoded["input_ids"] - # Pre-compute shift_labels for causal LM - shift_labels = input_ids[1:] + [tokenizer.pad_token_id] - shift_labels = [lbl if lbl != tokenizer.pad_token_id else -100 for lbl in shift_labels] - - self.data.append({"input_ids": input_ids, "shift_labels": shift_labels}) - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - return self.data[i] - - class CPDataCollator: - """Data collator for context parallelism - handles special CP requirements.""" - - def __init__(self, tokenizer: PreTrainedTokenizerBase, pad_to_multiple_of: int | None = None): - self.tokenizer = tokenizer - self.pad_to_multiple_of = pad_to_multiple_of - - def __call__(self, features): - # Stack input_ids and shift_labels - use clone() to avoid memory sharing issues - input_ids = torch.stack([torch.tensor(f["input_ids"], dtype=torch.long) for f in features]) - shift_labels = torch.stack([torch.tensor(f["shift_labels"], dtype=torch.long) for f in features]) - - # Pad to multiple if needed (required for CP: sequences must be divisible by cp_size * 2) - if self.pad_to_multiple_of: - seq_len = input_ids.shape[1] - if seq_len % self.pad_to_multiple_of != 0: - padding_len = self.pad_to_multiple_of - (seq_len % self.pad_to_multiple_of) - input_ids = torch.nn.functional.pad(input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) - shift_labels = torch.nn.functional.pad(shift_labels, (0, padding_len), value=-100) - - # For causal LM, unshifted labels are the input_ids themselves - # Replace pad tokens with -100 in labels - unshifted_labels = input_ids.clone() - unshifted_labels[unshifted_labels == self.tokenizer.pad_token_id] = -100 - - # Create batch dictionary - batch = { - "input_ids": input_ids.clone(), # Clone to avoid memory sharing with pin_memory - "shift_labels": shift_labels.clone(), # CP trainer expects pre-shifted labels - "labels": unshifted_labels.clone(), # Non-CP mode expects unshifted labels - } - - # Add position_ids (CP needs explicit position IDs) - # Use repeat instead of expand to avoid view/memory sharing issues - seq_len = batch["input_ids"].shape[1] - batch_size = batch["input_ids"].shape[0] - batch["position_ids"] = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) - - return batch - class TestTrainerContextParallel(TestCasePlus): """Test Trainer with context parallelism enabled via accelerate's ParallelismConfig.""" @@ -272,10 +198,25 @@ def test_cp_equivalence(self): use_cache=False, # Disable KV cache for CP ) - # Create dataset and data collator - train_dataset = CPDataset(tokenizer, seq_length=128, num_samples=8) - # pad_to_multiple_of=4 for cp_size=2 (must be divisible by cp_size * 2) - data_collator = CPDataCollator(tokenizer, pad_to_multiple_of=4) + # Create simple dataset: just tokenize some text + texts = [ + "The quick brown fox jumps over the lazy dog. " * 10, + "Hello world, this is a test sentence for training. 
" * 10, + ] * 4 # 8 samples total + + def tokenize_function(examples): + return tokenizer(examples, max_length=128, truncation=True, padding="max_length") + + train_dataset = [tokenize_function(text) for text in texts] + + # Use standard DataCollatorForLanguageModeling for causal LM + # pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2) + # Trainer will automatically generate position_ids and shift_labels as needed + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, # Causal language modeling + pad_to_multiple_of=4, + ) trainer = Trainer( model=model, @@ -290,7 +231,7 @@ def test_cp_equivalence(self): # Verify training completed assert trainer.state.global_step > 0, "Training should have completed at least one step" - # Save losses to file if requested (for reproducibility testing) + # Save losses to file if requested (for equivalence testing) if loss_output_file and training_args.process_index == 0: losses = [log["loss"] for log in trainer.state.log_history if "loss" in log] with open(loss_output_file, "w") as f: From df8aaacbbb096a4e3fdf4240db1c1ce6410a45f7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 25 Oct 2025 14:25:52 +0000 Subject: [PATCH 06/10] remove use_cache=False. --- tests/trainer/test_trainer_context_parallel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_trainer_context_parallel.py b/tests/trainer/test_trainer_context_parallel.py index 8bc1d05c8532..7dea9675fb90 100644 --- a/tests/trainer/test_trainer_context_parallel.py +++ b/tests/trainer/test_trainer_context_parallel.py @@ -195,7 +195,6 @@ def test_cp_equivalence(self): model = AutoModelForCausalLM.from_pretrained( model_name, attn_implementation="sdpa", # CP requires SDPA - use_cache=False, # Disable KV cache for CP ) # Create simple dataset: just tokenize some text From 4ec11c20c8508b844bb8f0932f0f8ae25adb0ebf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Nov 2025 19:14:55 +0000 Subject: [PATCH 07/10] changes from review --- tests/trainer/context_parallel_config.yaml | 29 ------------- .../context_parallel_no_cp_config.yaml | 25 ----------- .../context_parallel_torch_config.yaml | 13 ++++++ .../context_parallel_torch_no_cp_config.yaml | 8 ++++ ...=> test_trainer_context_parallel_torch.py} | 42 +++---------------- 5 files changed, 27 insertions(+), 90 deletions(-) delete mode 100644 tests/trainer/context_parallel_config.yaml delete mode 100644 tests/trainer/context_parallel_no_cp_config.yaml create mode 100644 tests/trainer/context_parallel_torch_config.yaml create mode 100644 tests/trainer/context_parallel_torch_no_cp_config.yaml rename tests/trainer/{test_trainer_context_parallel.py => test_trainer_context_parallel_torch.py} (86%) diff --git a/tests/trainer/context_parallel_config.yaml b/tests/trainer/context_parallel_config.yaml deleted file mode 100644 index 331e6c6b77e9..000000000000 --- a/tests/trainer/context_parallel_config.yaml +++ /dev/null @@ -1,29 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -enable_cpu_affinity: false -fsdp_config: - fsdp_activation_checkpointing: false - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_cpu_ram_efficient_loading: false - fsdp_offload_params: false - fsdp_reshard_after_forward: false - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_version: 2 -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 2 -parallelism_config: - 
parallelism_config_dp_replicate_size: 1 - parallelism_config_dp_shard_size: 1 - parallelism_config_tp_size: 1 - parallelism_config_cp_size: 2 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/tests/trainer/context_parallel_no_cp_config.yaml b/tests/trainer/context_parallel_no_cp_config.yaml deleted file mode 100644 index e520dd5e27a8..000000000000 --- a/tests/trainer/context_parallel_no_cp_config.yaml +++ /dev/null @@ -1,25 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -enable_cpu_affinity: false -fsdp_config: - fsdp_activation_checkpointing: false - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_cpu_ram_efficient_loading: false - fsdp_offload_params: false - fsdp_reshard_after_forward: false - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_version: 2 -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/tests/trainer/context_parallel_torch_config.yaml b/tests/trainer/context_parallel_torch_config.yaml new file mode 100644 index 000000000000..ed78181c15ca --- /dev/null +++ b/tests/trainer/context_parallel_torch_config.yaml @@ -0,0 +1,13 @@ +distributed_type: FSDP +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_version: 2 +mixed_precision: bf16 +num_processes: 2 +parallelism_config: + parallelism_config_dp_replicate_size: 1 + parallelism_config_dp_shard_size: 1 + parallelism_config_tp_size: 1 + parallelism_config_cp_size: 2 + parallelism_config_cp_comm_strategy: alltoall diff --git a/tests/trainer/context_parallel_torch_no_cp_config.yaml b/tests/trainer/context_parallel_torch_no_cp_config.yaml new file mode 100644 index 000000000000..0b35b732cc35 --- /dev/null +++ b/tests/trainer/context_parallel_torch_no_cp_config.yaml @@ -0,0 +1,8 @@ +distributed_type: FSDP +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_version: 2 +mixed_precision: bf16 +num_processes: 1 diff --git a/tests/trainer/test_trainer_context_parallel.py b/tests/trainer/test_trainer_context_parallel_torch.py similarity index 86% rename from tests/trainer/test_trainer_context_parallel.py rename to tests/trainer/test_trainer_context_parallel_torch.py index 7dea9675fb90..fa6bd54738a7 100644 --- a/tests/trainer/test_trainer_context_parallel.py +++ b/tests/trainer/test_trainer_context_parallel_torch.py @@ -39,43 +39,13 @@ ) -class TestTrainerContextParallel(TestCasePlus): - """Test Trainer with context parallelism enabled via accelerate's ParallelismConfig.""" +class TestTrainerContextParallelTorch(TestCasePlus): + """Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig.""" @require_torch_multi_accelerator @require_accelerate @slow @run_first - def test_trainer(self): - """Test basic training with context parallelism enabled.""" - output_dir = self.get_auto_remove_tmp_dir() - config_path = f"{self.test_file_dir}/context_parallel_config.yaml" - - cmd = [ - "accelerate", - "launch", - "--config_file", - config_path, - f"{self.test_file_dir}/test_trainer_context_parallel.py", - "--output_dir", - output_dir, - "--report_to", - "none", - "--max_steps", 
- "5", - "--per_device_train_batch_size", - "1", - "--logging_steps", - "1", - "--remove_unused_columns", - "False", - ] - - execute_subprocess_async(cmd, env=self.get_env()) - - @require_torch_multi_accelerator - @require_accelerate - @slow def test_cp_equivalence(self): """Test that CP produces the same losses as without CP.""" import os @@ -83,7 +53,7 @@ def test_cp_equivalence(self): output_dir = self.get_auto_remove_tmp_dir() # Run with CP enabled (cp_size=2) - config_path_cp = f"{self.test_file_dir}/context_parallel_config.yaml" + config_path_cp = f"{self.test_file_dir}/context_parallel_torch_config.yaml" loss_file_cp = os.path.join(output_dir, "losses_cp.json") cmd_cp = [ @@ -91,7 +61,7 @@ def test_cp_equivalence(self): "launch", "--config_file", config_path_cp, - f"{self.test_file_dir}/test_trainer_context_parallel.py", + f"{self.test_file_dir}/test_trainer_context_parallel_torch.py", "--output_dir", os.path.join(output_dir, "with_cp"), "--report_to", @@ -114,7 +84,7 @@ def test_cp_equivalence(self): execute_subprocess_async(cmd_cp, env=self.get_env()) # Run without CP (FSDP with num_processes=1, no parallelism_config) - config_path_no_cp = f"{self.test_file_dir}/context_parallel_no_cp_config.yaml" + config_path_no_cp = f"{self.test_file_dir}/context_parallel_torch_no_cp_config.yaml" loss_file_no_cp = os.path.join(output_dir, "losses_no_cp.json") cmd_no_cp = [ @@ -122,7 +92,7 @@ def test_cp_equivalence(self): "launch", "--config_file", config_path_no_cp, - f"{self.test_file_dir}/test_trainer_context_parallel.py", + f"{self.test_file_dir}/test_trainer_context_parallel_torch.py", "--output_dir", os.path.join(output_dir, "without_cp"), "--report_to", From a4a187e0e53fa4e51a3985b83f3d6dcbc7bbfd54 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Nov 2025 19:28:54 +0000 Subject: [PATCH 08/10] make script self contained --- .../context_parallel_torch_config.yaml | 13 -- .../context_parallel_torch_no_cp_config.yaml | 8 - .../test_trainer_context_parallel_torch.py | 169 ++++++++++-------- 3 files changed, 93 insertions(+), 97 deletions(-) delete mode 100644 tests/trainer/context_parallel_torch_config.yaml delete mode 100644 tests/trainer/context_parallel_torch_no_cp_config.yaml diff --git a/tests/trainer/context_parallel_torch_config.yaml b/tests/trainer/context_parallel_torch_config.yaml deleted file mode 100644 index ed78181c15ca..000000000000 --- a/tests/trainer/context_parallel_torch_config.yaml +++ /dev/null @@ -1,13 +0,0 @@ -distributed_type: FSDP -fsdp_config: - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_version: 2 -mixed_precision: bf16 -num_processes: 2 -parallelism_config: - parallelism_config_dp_replicate_size: 1 - parallelism_config_dp_shard_size: 1 - parallelism_config_tp_size: 1 - parallelism_config_cp_size: 2 - parallelism_config_cp_comm_strategy: alltoall diff --git a/tests/trainer/context_parallel_torch_no_cp_config.yaml b/tests/trainer/context_parallel_torch_no_cp_config.yaml deleted file mode 100644 index 0b35b732cc35..000000000000 --- a/tests/trainer/context_parallel_torch_no_cp_config.yaml +++ /dev/null @@ -1,8 +0,0 @@ -distributed_type: FSDP -fsdp_config: - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_version: 2 -mixed_precision: bf16 -num_processes: 1 diff --git a/tests/trainer/test_trainer_context_parallel_torch.py b/tests/trainer/test_trainer_context_parallel_torch.py index fa6bd54738a7..8d8c07056ef1 
100644 --- a/tests/trainer/test_trainer_context_parallel_torch.py +++ b/tests/trainer/test_trainer_context_parallel_torch.py @@ -14,6 +14,7 @@ import json import sys +from pathlib import Path from transformers import is_torch_available from transformers.testing_utils import ( @@ -48,80 +49,96 @@ class TestTrainerContextParallelTorch(TestCasePlus): @run_first def test_cp_equivalence(self): """Test that CP produces the same losses as without CP.""" - import os - - output_dir = self.get_auto_remove_tmp_dir() - - # Run with CP enabled (cp_size=2) - config_path_cp = f"{self.test_file_dir}/context_parallel_torch_config.yaml" - loss_file_cp = os.path.join(output_dir, "losses_cp.json") - - cmd_cp = [ - "accelerate", - "launch", - "--config_file", - config_path_cp, - f"{self.test_file_dir}/test_trainer_context_parallel_torch.py", - "--output_dir", - os.path.join(output_dir, "with_cp"), - "--report_to", - "none", - "--max_steps", - "10", - "--per_device_train_batch_size", - "1", - "--gradient_accumulation_steps", - "1", - "--logging_steps", - "1", - "--remove_unused_columns", - "False", - "--seed", - "42", - "--loss_output_file", - loss_file_cp, - ] - execute_subprocess_async(cmd_cp, env=self.get_env()) - - # Run without CP (FSDP with num_processes=1, no parallelism_config) - config_path_no_cp = f"{self.test_file_dir}/context_parallel_torch_no_cp_config.yaml" - loss_file_no_cp = os.path.join(output_dir, "losses_no_cp.json") - - cmd_no_cp = [ - "accelerate", - "launch", - "--config_file", - config_path_no_cp, - f"{self.test_file_dir}/test_trainer_context_parallel_torch.py", - "--output_dir", - os.path.join(output_dir, "without_cp"), - "--report_to", - "none", - "--max_steps", - "10", - "--per_device_train_batch_size", - "1", - "--gradient_accumulation_steps", - "1", - "--logging_steps", - "1", - "--remove_unused_columns", - "False", - "--seed", - "42", - "--loss_output_file", - loss_file_no_cp, - ] - execute_subprocess_async(cmd_no_cp, env=self.get_env()) + + # Shared setup + world_size = 2 + script_path = __file__ + + # Step 1: Run with CP enabled (cp_size=world_size) + cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve() + cp_yes_config_path = cp_yes_output_dir / "context_parallel_config.yaml" + cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json" + + # Write config file inline (self-contained test) + with open(cp_yes_config_path, "w") as f: + f.write( + f"""distributed_type: FSDP +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_version: 2 +mixed_precision: bf16 +num_processes: {world_size} +parallelism_config: + parallelism_config_dp_replicate_size: 1 + parallelism_config_dp_shard_size: 1 + parallelism_config_tp_size: 1 + parallelism_config_cp_size: {world_size} + parallelism_config_cp_comm_strategy: alltoall +""" + ) + + cmd_cp_yes = f""" + accelerate launch + --config_file {cp_yes_config_path} + {script_path} + --output_dir {cp_yes_output_dir} + --report_to none + --max_steps 10 + --per_device_train_batch_size 1 + --gradient_accumulation_steps 1 + --logging_steps 1 + --remove_unused_columns False + --seed 42 + --loss_output_file {cp_yes_losses_path} + """.split() + + execute_subprocess_async(cmd_cp_yes, env=self.get_env()) + + # Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config) + cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve() + cp_no_config_path = cp_no_output_dir / "context_parallel_config.yaml" + cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json" + + # Write 
config file inline (self-contained test) + with open(cp_no_config_path, "w") as f: + f.write( + """distributed_type: FSDP +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_version: 2 +mixed_precision: bf16 +num_processes: 1 +""" + ) + + cmd_cp_no = f""" + accelerate launch + --config_file {cp_no_config_path} + {script_path} + --output_dir {cp_no_output_dir} + --report_to none + --max_steps 10 + --per_device_train_batch_size 1 + --gradient_accumulation_steps 1 + --logging_steps 1 + --remove_unused_columns False + --seed 42 + --loss_output_file {cp_no_losses_path} + """.split() + + execute_subprocess_async(cmd_cp_no, env=self.get_env()) # Compare losses - should be very close since CP just splits sequence computation - with open(loss_file_cp) as f: - losses_cp = json.load(f) - with open(loss_file_no_cp) as f: - losses_no_cp = json.load(f) + with open(cp_yes_losses_path) as f: + cp_yes_losses = json.load(f) + with open(cp_no_losses_path) as f: + cp_no_losses = json.load(f) - assert len(losses_cp) == len(losses_no_cp), ( - f"Different number of losses: CP has {len(losses_cp)}, no-CP has {len(losses_no_cp)}" + assert len(cp_yes_losses) == len(cp_no_losses), ( + f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}" ) # CP should produce very similar results (small numerical differences expected) @@ -129,17 +146,17 @@ def test_cp_equivalence(self): # - Different gradient reduction patterns in distributed training # - BF16 mixed precision accumulated differences # - Sequence splitting and gathering in CP mode - losses_cp_tensor = torch.tensor(losses_cp) - losses_no_cp_tensor = torch.tensor(losses_no_cp) + cp_yes_losses_tensor = torch.tensor(cp_yes_losses) + cp_no_losses_tensor = torch.tensor(cp_no_losses) # Use torch.testing.assert_close with rtol=2% and atol=0.02 # Testing shows actual differences are typically <1.5% torch.testing.assert_close( - losses_cp_tensor, - losses_no_cp_tensor, + cp_yes_losses_tensor, + cp_no_losses_tensor, rtol=2e-2, # 2% relative tolerance atol=2e-2, # 0.02 absolute tolerance - msg=f"CP losses {losses_cp} do not match non-CP losses {losses_no_cp}", + msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}", ) From 5a7d1aa7eb155e63f98bd8bb6c6819e6c0edf919 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 5 Nov 2025 11:27:45 +0100 Subject: [PATCH 09/10] moved to fsdp folder --- .../test_context_parallel.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{trainer/test_trainer_context_parallel_torch.py => fsdp/test_context_parallel.py} (100%) diff --git a/tests/trainer/test_trainer_context_parallel_torch.py b/tests/fsdp/test_context_parallel.py similarity index 100% rename from tests/trainer/test_trainer_context_parallel_torch.py rename to tests/fsdp/test_context_parallel.py From 8b39cd53895fe6fddd0aca7df906ec5172dfcc38 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 5 Nov 2025 11:30:41 +0100 Subject: [PATCH 10/10] fix class name --- tests/fsdp/test_context_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fsdp/test_context_parallel.py b/tests/fsdp/test_context_parallel.py index 8d8c07056ef1..8e0b58a32187 100644 --- a/tests/fsdp/test_context_parallel.py +++ b/tests/fsdp/test_context_parallel.py @@ -40,7 +40,7 @@ ) -class TestTrainerContextParallelTorch(TestCasePlus): +class TestContextParallel(TestCasePlus): """Test Trainer with Torch 
context parallelism enabled via accelerate's ParallelismConfig.""" @require_torch_multi_accelerator