HF Trainer: ALST/Ulysses sequence parallelism integration via HF Accelerate #41832
File 1: testing utilities (likely `src/transformers/testing_utils.py`)

@@ -21,6 +21,7 @@
import gc
import importlib
import inspect
+import json
import logging
import multiprocessing
import os
@@ -2005,14 +2006,12 @@ def get_env(self):
        paths = [self.repo_root_dir_str, self.src_dir_str]
        if "/examples" in self.test_file_dir_str:
            paths.append(self.examples_dir_str)
        else:
            paths.append(self.tests_dir_str)
        paths.append(env.get("PYTHONPATH", ""))

        env["PYTHONPATH"] = ":".join(paths)
        return env

-    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None, return_pathlib_obj=False):
        """
        Args:
            tmp_dir (`string`, *optional*):

Review discussion attached to this change:

Contributor (PR author): this is a really old version. In the latest incarnation it always returns a `Path`. The latest version is here: https://github.com/stas00/ml-engineering/blob/master/testing/testing_utils.py. If wanted, you could switch to the latest version instead and adapt the tests to simplify.

Member: cc @ydshieh

Contributor (PR author): it's much better for it to always return a `Path`.
@@ -2032,6 +2031,8 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
            after (`bool`, *optional*):
                If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
                intact at the end of the test.
+            return_pathlib_obj (`bool`, *optional*):
+                If `True` will return a pathlib.Path object

        Returns:
            tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
@@ -2078,7 +2079,7 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
            # register for deletion
            self.teardown_tmp_dirs.append(tmp_dir)

-        return tmp_dir
+        return Path(tmp_dir).resolve() if return_pathlib_obj else tmp_dir

    def python_one_liner_max_rss(self, one_liner_str):
        """
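A minimal usage sketch of the new flag, assuming the `TestCasePlus` base class these helpers live on (the test name and file contents are made up):

```python
from pathlib import Path

from transformers.testing_utils import TestCasePlus


class TmpDirTest(TestCasePlus):
    def test_tmp_dir_return_types(self):
        # default behavior is unchanged: a plain string path is returned
        tmp_dir = self.get_auto_remove_tmp_dir()
        assert isinstance(tmp_dir, str)

        # with return_pathlib_obj=True a resolved pathlib.Path comes back instead
        tmp_path = self.get_auto_remove_tmp_dir(return_pathlib_obj=True)
        assert isinstance(tmp_path, Path)
        (tmp_path / "config.json").write_text("{}")
```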
@@ -4076,3 +4077,13 @@ def use_one_line_repr(obj):
    cache[(id(obj), indent, mode, prefix)] = output

    return output


+def write_file(file, content):
+    with open(file, "w") as f:
+        f.write(content)
+
+
+def read_json_file(file):
+    with open(file, "r") as fh:
+        return json.load(fh)
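A quick round trip with the two new module-level helpers (file name and contents are illustrative):

```python
import json

from transformers.testing_utils import read_json_file, write_file

write_file("metrics.json", json.dumps({"train_loss": 0.42}))
assert read_json_file("metrics.json") == {"train_loss": 0.42}
```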
File 2: the Trainer (likely `src/transformers/trainer.py`)
@@ -2159,6 +2159,17 @@ def train(
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

+    def get_cp_size(self) -> int:
+        """Get the context parallel size"""
+        if getattr(self.accelerator, "parallelism_config", None) is None:
+            return 1
+        pc = self.accelerator.parallelism_config
+        if pc.cp_backend == "deepspeed":
+            return pc.cp_size
+        # XXX: not sure if backend=="torch" needs to return cp_size here as well it wasn't originally (a bug?)
+        return 1
+
    def get_tp_size(self) -> int:
        """Get the tensor parallel size from either the model or DeepSpeed config."""
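For orientation, a rough sketch of how a user would select the DeepSpeed Ulysses backend that `get_cp_size` checks for. `cp_size` exists in Accelerate's `ParallelismConfig` today, while `cp_backend` only lands with the companion Accelerate changes, so treat the exact constructor arguments as an assumption:

```python
from accelerate.parallelism_config import ParallelismConfig
from transformers import TrainingArguments

# Assumed API: cp_size is the Ulysses sequence-parallel degree, cp_backend selects
# the ALST/UlyssesSP (DeepSpeed) implementation instead of the torch context-parallel one.
pc = ParallelismConfig(cp_size=4, cp_backend="deepspeed")

args = TrainingArguments(
    output_dir="out",
    parallelism_config=pc,  # forwarded to Accelerate in create_accelerator_and_postprocess
)
```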
@@ -2176,8 +2187,8 @@ def get_tp_size(self) -> int:
    def get_total_train_batch_size(self, args) -> int:
        """Calculates total batch size (micro_batch * grad_accum * dp_world_size).

-        Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
-        dp_world_size = args.world_size // self.get_tp_size()
+        Note: Only considers DP and TP and SP/CP (dp_world_size = world_size // tp_size // cp_size)."""
+        dp_world_size = args.world_size // self.get_tp_size() // self.get_cp_size()
        return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size

    def _inner_training_loop(
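For intuition, a small worked example of the updated formula with made-up numbers:

```python
# 16 ranks in total, tensor parallel degree 2, Ulysses SP/CP degree 4:
world_size = 16
tp_size = 2
cp_size = 4
dp_world_size = world_size // tp_size // cp_size        # 16 // 2 // 4 = 2

micro_batch = 1   # per-device train batch size
grad_accum = 8    # gradient accumulation steps
total_train_batch_size = micro_batch * grad_accum * dp_world_size   # 1 * 8 * 2 = 16
```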
@@ -2301,6 +2312,11 @@ def _inner_training_loop(
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

+        # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the DL wrapping for ALST/UlyssesSP, after model has been prepared
+        pc = getattr(self.accelerator, "parallelism_config", None)
+        if pc is not None and pc.cp_backend == "deepspeed":
+            train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model)
+
        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model
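Conceptually, the Ulysses dataloader adapter gives every sequence-parallel rank its own slice of the sequence dimension of each batch (the real work is done by DeepSpeed's `UlyssesSPDataLoaderAdapter`, called via Accelerate). A toy, non-distributed sketch of the sharding idea only:

```python
import torch


def shard_seq_dim(batch: dict, sp_rank: int, sp_world_size: int) -> dict:
    """Toy illustration only: slice (batch, seq, ...) tensors along dim 1 for one SP rank."""
    sharded = {}
    for key, tensor in batch.items():
        if isinstance(tensor, torch.Tensor) and tensor.dim() >= 2:
            chunk = tensor.size(1) // sp_world_size
            sharded[key] = tensor[:, sp_rank * chunk : (sp_rank + 1) * chunk]
        else:
            sharded[key] = tensor
    return sharded


batch = {"input_ids": torch.arange(16).reshape(1, 16)}
print(shard_seq_dim(batch, sp_rank=1, sp_world_size=4)["input_ids"])  # tensor([[4, 5, 6, 7]])
```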
@@ -3635,23 +3651,32 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
            getattr(self.accelerator, "parallelism_config", None) is not None
            and self.accelerator.parallelism_config.cp_enabled
        ):
-            if hasattr(model, "config"):
-                if model.config._attn_implementation != "sdpa":
-                    raise ValueError(
-                        f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
-                    )
+            if self.accelerator.parallelism_config.cp_backend == "torch":
+                if hasattr(model, "config"):
+                    if model.config._attn_implementation != "sdpa":
+                        raise ValueError(
+                            f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
+                        )
+
+                if "shift_labels" not in inputs:
+                    logger.warning_once("Shift labels not found in the inputs, shifting manually")
+                    if "labels" in inputs:
+                        _ignore_index = -100
+                        labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
+                        inputs["shift_labels"] = labels[:, 1:].contiguous()
+
+            # carve out space to make it clear there are other backends with different requirements, even though no code needs to be run at the moment
+            elif self.accelerator.parallelism_config.cp_backend == "deepspeed":
+                # - accelerator.parallelism_config performs the `model.config._attn_implementation` checks already and it supports more than `dspa`
+                # - UlyssesSPDataLoaderAdapter called from Accelerate performs the `shift_label` creation - must not interfere
+                # - position_ids generation should be done by HF Trainer if it wasn't done by the user
+                pass

            if "position_ids" not in inputs:
                logger.warning_once("Position IDs not found in the inputs, generating manually")
                inputs["position_ids"] = torch.arange(
                    inputs["input_ids"].size(1), device=inputs["input_ids"].device
                ).expand(inputs["input_ids"].size(0), -1)
-            if "shift_labels" not in inputs:
-                logger.warning_once("Shift labels not found in the inputs, shifting manually")
-                if "labels" in inputs:
-                    _ignore_index = -100
-                    labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
-                    inputs["shift_labels"] = labels[:, 1:].contiguous()

            buffers = []
            buffer_seq_dims = []
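The manual `shift_labels` construction is the usual next-token shift: pad the labels on the right with the ignore index, then drop the first position. A tiny standalone illustration with made-up values:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[10, 11, 12, 13]])
padded = F.pad(labels, (0, 1), value=-100)   # tensor([[  10,   11,   12,   13, -100]])
shift_labels = padded[:, 1:].contiguous()    # tensor([[  11,   12,   13, -100]])
# position t of shift_labels is the token the model should predict from position t
```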
@@ -3820,6 +3845,33 @@ def compute_loss(
        Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
        make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation.
        """
+        pc = getattr(self.accelerator, "parallelism_config", None)
+        if pc is not None and pc.cp_backend == "deepspeed":
+            unwrapped_model = self.accelerator.unwrap_model(model)
+
+            outputs = model(**inputs)
+            shift_labels = inputs["shift_labels"]
+            loss = unwrapped_model.loss_function(
+                logits=outputs.logits,
+                labels=None,
+                shift_labels=shift_labels,
+                vocab_size=unwrapped_model.config.vocab_size,
+            )
+
+            if pc.cp_size > 1:
+                sp_group = self.accelerator.torch_device_mesh["cp"].get_group()
+                sp_world_size = pc.cp_size
+                # differentiable weighted per-shard-loss aggregation across ranks
+                losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
+                # special dealing with SFT that has prompt tokens that aren't used in loss computation
+                good_tokens = (shift_labels != -100).view(-1).sum()
+                good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
+                total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
+                total_good_tokens = sum(good_tokens_per_rank)
+                loss = total_loss / max(total_good_tokens, 1)
+
+            return (loss, outputs) if return_outputs else loss
+
        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
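The cross-rank block above is a token-weighted mean: each SP rank holds the mean loss over the valid (non `-100`) tokens of its own sequence shard, and the shards are recombined weighted by their valid-token counts, with `torch.distributed.nn.functional.all_gather` keeping the reduction differentiable. A plain, non-distributed illustration of the arithmetic (numbers are made up):

```python
# rank 0: mean loss 2.0 over 30 valid tokens, rank 1: mean loss 1.0 over 10 valid tokens
losses_per_rank = [2.0, 1.0]
good_tokens_per_rank = [30, 10]

total_loss = sum(l * n for l, n in zip(losses_per_rank, good_tokens_per_rank))  # 70.0
total_good_tokens = sum(good_tokens_per_rank)                                   # 40
loss = total_loss / max(total_good_tokens, 1)                                   # 1.75
# the same value a single rank would get averaging over all 40 tokens at once
```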
@@ -3913,7 +3965,11 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
        elif getattr(self.accelerator, "parallelism_config", None) is not None:
-            if self.accelerator.should_save_model:
+            # deepspeed already takes care of saving the checkpoint below, so we need this only for the torch cp backend
+            if (
+                self.accelerator.should_save_model
+                and getattr(self.accelerator, "parallelism_config").cp_backend == "torch"
+            ):
                self._save(output_dir)
        # If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
        elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
@@ -4981,9 +5037,11 @@ def create_accelerator_and_postprocess(self):
        # We defer compatibility checks to accelerator
        if self.args.parallelism_config is not None:
-            if not is_accelerate_available("1.10.1"):
+            # XXX: this will need to change once https://github.com/huggingface/accelerate/pull/3817 is merged and 1.11.1 is out
+            min_accelerate_version = "1.10.1"
+            if not is_accelerate_available(min_accelerate_version):
                raise ImportError(
-                    "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
+                    f"ParallelismConfig requires accelerate>={min_accelerate_version}). Please upgrade accelerate to use this feature."
                )
            args["parallelism_config"] = self.args.parallelism_config
@@ -5136,7 +5194,9 @@ def _get_num_items_in_batch(self, batch_samples: list, device: torch.device) ->
            # In the DataParallel case, convert the scalar tensor into a 2-dim tensor with the same value repeated
            num_items_in_batch = num_items_in_batch.unsqueeze(0).expand(self.args.n_gpu, -1)
        # Divide by number of devices with the same batch
-        if pc := getattr(self.accelerator, "parallelism_config", None):
+        # XXX: double check if this is only for fsdp
+        pc = getattr(self.accelerator, "parallelism_config", None)
+        if pc is not None and pc.cp_backend == "torch":
            num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size

        return num_items_in_batch
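A small arithmetic illustration of the division, assuming `non_data_parallel_size` is the product of the non-data-parallel degrees (here tp_size * cp_size) as in Accelerate's `ParallelismConfig`:

```python
tp_size, cp_size = 2, 2
non_data_parallel_size = tp_size * cp_size   # 4

# illustrative aggregated token count; ranks in the same TP/CP group process the same
# samples, so the aggregate would otherwise be inflated by non_data_parallel_size
num_items_in_batch = 4096
num_items_in_batch = num_items_in_batch // non_data_parallel_size   # 1024
```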