-
Notifications
You must be signed in to change notification settings - Fork 31.4k
[tests] Add Context-parallel CI tests #41860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
63efa18
intial
kashif 1ff587b
simplify tests
kashif 0bcf34e
add test_cp_equivalence
kashif 6d42d9a
removed fsdp_transformer_layer_cls_to_wrap
kashif e18436c
use DataCollatorForLanguageModeling
kashif df8aaac
remove use_cache=False.
kashif 4ec11c2
changes from review
kashif a4a187e
make script self contained
kashif 977c586
Merge branch 'main' into cp-ci-tests
kashif 5a7d1aa
moved to fsdp folder
kashif 8b39cd5
fix class name
kashif 9bd5dae
Merge branch 'main' into cp-ci-tests
kashif File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| compute_environment: LOCAL_MACHINE | ||
| debug: false | ||
| distributed_type: FSDP | ||
| downcast_bf16: 'no' | ||
| enable_cpu_affinity: false | ||
| fsdp_config: | ||
| fsdp_activation_checkpointing: false | ||
| fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP | ||
| fsdp_cpu_ram_efficient_loading: false | ||
| fsdp_offload_params: false | ||
| fsdp_reshard_after_forward: false | ||
| fsdp_state_dict_type: SHARDED_STATE_DICT | ||
| fsdp_version: 2 | ||
| machine_rank: 0 | ||
| main_training_function: main | ||
| mixed_precision: bf16 | ||
| num_machines: 1 | ||
| num_processes: 2 | ||
| parallelism_config: | ||
| parallelism_config_dp_replicate_size: 1 | ||
| parallelism_config_dp_shard_size: 1 | ||
| parallelism_config_tp_size: 1 | ||
| parallelism_config_cp_size: 2 | ||
| rdzv_backend: static | ||
| same_network: true | ||
| tpu_env: [] | ||
| tpu_use_cluster: false | ||
| tpu_use_sudo: false | ||
| use_cpu: false | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| compute_environment: LOCAL_MACHINE | ||
| debug: false | ||
| distributed_type: FSDP | ||
| downcast_bf16: 'no' | ||
| enable_cpu_affinity: false | ||
| fsdp_config: | ||
| fsdp_activation_checkpointing: false | ||
| fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP | ||
| fsdp_cpu_ram_efficient_loading: false | ||
| fsdp_offload_params: false | ||
| fsdp_reshard_after_forward: false | ||
| fsdp_state_dict_type: SHARDED_STATE_DICT | ||
| fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer | ||
| fsdp_version: 2 | ||
| machine_rank: 0 | ||
| main_training_function: main | ||
| mixed_precision: bf16 | ||
| num_machines: 1 | ||
| num_processes: 1 | ||
| rdzv_backend: static | ||
| same_network: true | ||
| tpu_env: [] | ||
| tpu_use_cluster: false | ||
| tpu_use_sudo: false | ||
| use_cpu: false |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,237 @@ | ||
| # Copyright 2025 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import json | ||
| import sys | ||
|
|
||
| from transformers import is_torch_available | ||
| from transformers.testing_utils import ( | ||
| TestCasePlus, | ||
| execute_subprocess_async, | ||
| require_accelerate, | ||
| require_torch_multi_accelerator, | ||
| run_first, | ||
| slow, | ||
| ) | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| import torch | ||
|
|
||
| from transformers import ( | ||
| AutoModelForCausalLM, | ||
| AutoTokenizer, | ||
| DataCollatorForLanguageModeling, | ||
| HfArgumentParser, | ||
| Trainer, | ||
| TrainingArguments, | ||
| ) | ||
|
|
||
|
|
||
| class TestTrainerContextParallel(TestCasePlus): | ||
| """Test Trainer with context parallelism enabled via accelerate's ParallelismConfig.""" | ||
|
|
||
| @require_torch_multi_accelerator | ||
| @require_accelerate | ||
| @slow | ||
| @run_first | ||
| def test_trainer(self): | ||
| """Test basic training with context parallelism enabled.""" | ||
| output_dir = self.get_auto_remove_tmp_dir() | ||
| config_path = f"{self.test_file_dir}/context_parallel_config.yaml" | ||
|
|
||
| cmd = [ | ||
| "accelerate", | ||
| "launch", | ||
| "--config_file", | ||
| config_path, | ||
| f"{self.test_file_dir}/test_trainer_context_parallel.py", | ||
| "--output_dir", | ||
| output_dir, | ||
| "--report_to", | ||
| "none", | ||
| "--max_steps", | ||
| "5", | ||
| "--per_device_train_batch_size", | ||
| "1", | ||
| "--logging_steps", | ||
| "1", | ||
| "--remove_unused_columns", | ||
| "False", | ||
| ] | ||
|
|
||
| execute_subprocess_async(cmd, env=self.get_env()) | ||
|
|
||
| @require_torch_multi_accelerator | ||
| @require_accelerate | ||
| @slow | ||
| def test_cp_equivalence(self): | ||
| """Test that CP produces the same losses as without CP.""" | ||
| import os | ||
|
|
||
| output_dir = self.get_auto_remove_tmp_dir() | ||
|
|
||
| # Run with CP enabled (cp_size=2) | ||
| config_path_cp = f"{self.test_file_dir}/context_parallel_config.yaml" | ||
| loss_file_cp = os.path.join(output_dir, "losses_cp.json") | ||
|
|
||
| cmd_cp = [ | ||
| "accelerate", | ||
| "launch", | ||
| "--config_file", | ||
| config_path_cp, | ||
| f"{self.test_file_dir}/test_trainer_context_parallel.py", | ||
| "--output_dir", | ||
| os.path.join(output_dir, "with_cp"), | ||
| "--report_to", | ||
| "none", | ||
| "--max_steps", | ||
| "10", | ||
| "--per_device_train_batch_size", | ||
| "1", | ||
| "--gradient_accumulation_steps", | ||
| "1", | ||
| "--logging_steps", | ||
| "1", | ||
| "--remove_unused_columns", | ||
| "False", | ||
| "--seed", | ||
| "42", | ||
| "--loss_output_file", | ||
| loss_file_cp, | ||
| ] | ||
| execute_subprocess_async(cmd_cp, env=self.get_env()) | ||
|
|
||
| # Run without CP (FSDP with num_processes=1, no parallelism_config) | ||
| config_path_no_cp = f"{self.test_file_dir}/context_parallel_no_cp_config.yaml" | ||
| loss_file_no_cp = os.path.join(output_dir, "losses_no_cp.json") | ||
|
|
||
| cmd_no_cp = [ | ||
| "accelerate", | ||
| "launch", | ||
| "--config_file", | ||
| config_path_no_cp, | ||
| f"{self.test_file_dir}/test_trainer_context_parallel.py", | ||
| "--output_dir", | ||
| os.path.join(output_dir, "without_cp"), | ||
| "--report_to", | ||
| "none", | ||
| "--max_steps", | ||
| "10", | ||
| "--per_device_train_batch_size", | ||
| "1", | ||
| "--gradient_accumulation_steps", | ||
| "1", | ||
| "--logging_steps", | ||
| "1", | ||
| "--remove_unused_columns", | ||
| "False", | ||
| "--seed", | ||
| "42", | ||
| "--loss_output_file", | ||
| loss_file_no_cp, | ||
| ] | ||
| execute_subprocess_async(cmd_no_cp, env=self.get_env()) | ||
|
|
||
| # Compare losses - should be very close since CP just splits sequence computation | ||
| with open(loss_file_cp) as f: | ||
| losses_cp = json.load(f) | ||
| with open(loss_file_no_cp) as f: | ||
| losses_no_cp = json.load(f) | ||
|
|
||
| assert len(losses_cp) == len(losses_no_cp), ( | ||
| f"Different number of losses: CP has {len(losses_cp)}, no-CP has {len(losses_no_cp)}" | ||
| ) | ||
|
|
||
| # CP should produce very similar results (small numerical differences expected) | ||
| # The differences come from: | ||
| # - Different gradient reduction patterns in distributed training | ||
| # - BF16 mixed precision accumulated differences | ||
| # - Sequence splitting and gathering in CP mode | ||
| losses_cp_tensor = torch.tensor(losses_cp) | ||
| losses_no_cp_tensor = torch.tensor(losses_no_cp) | ||
|
|
||
| # Use torch.testing.assert_close with rtol=2% and atol=0.02 | ||
| # Testing shows actual differences are typically <1.5% | ||
| torch.testing.assert_close( | ||
| losses_cp_tensor, | ||
| losses_no_cp_tensor, | ||
| rtol=2e-2, # 2% relative tolerance | ||
| atol=2e-2, # 0.02 absolute tolerance | ||
| msg=f"CP losses {losses_cp} do not match non-CP losses {losses_no_cp}", | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| # Parse custom arguments (not TrainingArguments parameters) | ||
| loss_output_file = None | ||
|
|
||
| if "--loss_output_file" in sys.argv: | ||
| idx = sys.argv.index("--loss_output_file") | ||
| loss_output_file = sys.argv[idx + 1] | ||
| sys.argv.pop(idx) | ||
| sys.argv.pop(idx) | ||
|
|
||
| parser = HfArgumentParser((TrainingArguments,)) | ||
| training_args = parser.parse_args_into_dataclasses()[0] | ||
|
|
||
| # Use SmolLM (small Llama-based model that works with CP) | ||
| model_name = "HuggingFaceTB/SmolLM-135M" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
| if tokenizer.pad_token is None: | ||
| tokenizer.pad_token = tokenizer.eos_token | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_name, | ||
| attn_implementation="sdpa", # CP requires SDPA | ||
| ) | ||
|
|
||
| # Create simple dataset: just tokenize some text | ||
| texts = [ | ||
| "The quick brown fox jumps over the lazy dog. " * 10, | ||
| "Hello world, this is a test sentence for training. " * 10, | ||
| ] * 4 # 8 samples total | ||
|
|
||
| def tokenize_function(examples): | ||
| return tokenizer(examples, max_length=128, truncation=True, padding="max_length") | ||
|
|
||
| train_dataset = [tokenize_function(text) for text in texts] | ||
|
|
||
| # Use standard DataCollatorForLanguageModeling for causal LM | ||
| # pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2) | ||
| # Trainer will automatically generate position_ids and shift_labels as needed | ||
| data_collator = DataCollatorForLanguageModeling( | ||
| tokenizer=tokenizer, | ||
| mlm=False, # Causal language modeling | ||
| pad_to_multiple_of=4, | ||
| ) | ||
|
|
||
| trainer = Trainer( | ||
| model=model, | ||
| args=training_args, | ||
| train_dataset=train_dataset, | ||
| data_collator=data_collator, | ||
| ) | ||
|
|
||
| # Train for a few steps | ||
| trainer.train() | ||
|
|
||
| # Verify training completed | ||
| assert trainer.state.global_step > 0, "Training should have completed at least one step" | ||
|
|
||
| # Save losses to file if requested (for equivalence testing) | ||
| if loss_output_file and training_args.process_index == 0: | ||
| losses = [log["loss"] for log in trainer.state.log_history if "loss" in log] | ||
| with open(loss_output_file, "w") as f: | ||
| json.dump(losses, f) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.