
Commit c0b4e18

mamtsing authored and quic-mamta committed
address comments
Signed-off-by: Mamta Singh <[email protected]>
1 parent d3e3029 commit c0b4e18

6 files changed: +79 −73 lines changed


QEfficient/cloud/finetune.py

Lines changed: 15 additions & 25 deletions
@@ -26,9 +26,9 @@
     generate_peft_config,
     update_config,
 )
-from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
 from QEfficient.finetune.utils.device_map import get_device_map
-from QEfficient.finetune.utils.helper import Task_Mode, get_longest_seq_length
+from QEfficient.finetune.utils.helper import Task_Mode
 from QEfficient.finetune.utils.logging_utils import logger
 from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
@@ -67,7 +67,7 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
     dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
     dist.init_process_group(backend=dist_backend_map[torch_device.type])
-    if train_config.enable_pp:
+    if train_config.num_pp_stages > 1:
         assert dist.get_world_size() * train_config.num_pp_stages == getattr(torch, torch_device.type).device_count(), (
             "Total available devices should be multiple of number of pipeline stages."
         )
@@ -97,7 +97,7 @@ def load_model_and_tokenizer(
     """Load the pre-trained model and tokenizer from Hugging Face.

     Args:
-        config (TrainConfig): Training configuration object containing model and tokenizer names.
+        train_config (TrainConfig): Training configuration object containing model and tokenizer names.
         dataset_config (Any): A dataclass object representing dataset configuration.
         kwargs: Additional arguments to override PEFT config.

@@ -135,26 +135,14 @@ def load_model_and_tokenizer(
         if param.requires_grad:
             param.data = param.data.to(torch.float32)
     else:
-        if train_config.enable_pp:
-            if train_config.enable_ddp:
-                device_map = get_device_map(train_config.model_name, train_config.num_pp_stages, rank=dist.get_rank())
-            else:
-                device_map = "auto"
-            model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_path,
-                use_cache=False,
-                attn_implementation="sdpa",
-                torch_dtype=torch.float16,
-                device_map=device_map,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_path,
-                use_cache=False,
-                attn_implementation="sdpa",
-                torch_dtype=torch.float16,
-            )
-
+        device_map = get_device_map(train_config)
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+            device_map=device_map,
+        )
     tokenizer = AutoTokenizer.from_pretrained(
         train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
     )
@@ -307,7 +295,7 @@ def main(**kwargs) -> None:
             f"passed context length is {train_config.context_length} and overall model's context length is "
             f"{model.config.max_position_embeddings}"
         )
-    if not train_config.enable_pp:
+    if train_config.num_pp_stages == 1:
         model.to(train_config.device)
     optimizer = optim.AdamW(
         model.parameters(),
@@ -320,6 +308,8 @@ def main(**kwargs) -> None:
         for name, param in model.named_parameters():
             if not param.requires_grad:
                 ignore_names.add(name)
+        # Adding params in ignore list will enforce DDP to ignore them during synchronization,
+        # which will further reduce the tensor exchange across devices.
        torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)
        model = nn.parallel.DistributedDataParallel(model)
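
With this change, pipeline parallelism in QEfficient/cloud/finetune.py is driven entirely by num_pp_stages: any value above 1 implies PP, and model loading always routes through get_device_map(train_config). Below is a minimal sketch of the resulting load path, assuming (as in this commit) that get_device_map returns None for a single stage, "auto" for PP without DDP, and a per-rank map for PP with DDP; load_base_model is a hypothetical wrapper name, not a function in the repository.

# Minimal sketch of the refactored load path (load_base_model is hypothetical).
import torch
from transformers import AutoModelForCausalLM

from QEfficient.finetune.utils.device_map import get_device_map


def load_base_model(train_config, pretrained_model_path):
    # None -> single device, "auto" -> HF places layers, dict -> per-rank PP placement.
    device_map = get_device_map(train_config)
    return AutoModelForCausalLM.from_pretrained(
        pretrained_model_path,
        use_cache=False,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
        device_map=device_map,
    )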

QEfficient/finetune/configs/training.py

Lines changed: 0 additions & 2 deletions
@@ -50,7 +50,6 @@ class TrainConfig:
         convergence_counter (int): Steps to check convergence (default: 5).
         convergence_loss (float): Loss threshold for convergence (default: 1e-4).
         use_profiler (bool): Enable profiling (default: False).
-        enable_pp (bool): Enable training with pipeline parallelism (default: False).
         num_pp_stages (int): Number of stages in which model is split layerwise when training using pipeline (default: 1).
         enable_ddp (bool): Enable distributed data parallel (default: False).
         enable_sorting_for_ddp (bool): Sort data for DDP (default: True).
@@ -101,7 +100,6 @@ class TrainConfig:
     # profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

     # dist-related
-    enable_pp: bool = False
     num_pp_stages: int = 1
     enable_ddp: bool = False
     enable_sorting_for_ddp: bool = True
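
Since TrainConfig no longer carries an enable_pp field, turning pipeline parallelism on is purely a matter of raising num_pp_stages above its default of 1. A hedged sketch of overriding the dataclass follows; it assumes all other fields keep their defaults and that the import path mirrors the file path above.

# Hedged sketch: PP is implied by num_pp_stages > 1; the old enable_pp flag is gone.
from QEfficient.finetune.configs.training import TrainConfig

train_config = TrainConfig(num_pp_stages=2, enable_ddp=True)
assert not hasattr(train_config, "enable_pp")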

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 11 additions & 0 deletions
@@ -4,6 +4,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+from typing import Dict, List, Tuple
+
 import datasets
 import torch
 import torch.distributed as dist
@@ -116,3 +119,11 @@ def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"
         **dl_kwargs,
     )
     return dataloader
+
+
+def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
+    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
+    lengths = [len(d["input_ids"]) for d in data]
+    longest_seq_length = max(lengths)
+    longest_seq_ix = lengths.index(longest_seq_length)
+    return longest_seq_length, longest_seq_ix
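
get_longest_seq_length now lives next to the dataloader utilities and simply scans already-tokenized samples for the longest input_ids entry, so callers can size the maximum sequence length tightly. A small illustrative usage with made-up token ids:

# Illustrative usage of get_longest_seq_length on made-up tokenized samples.
from QEfficient.finetune.utils.dataset_utils import get_longest_seq_length

samples = [
    {"input_ids": [101, 7592, 102]},             # length 3
    {"input_ids": [101, 7592, 2088, 999, 102]},  # length 5 (longest)
    {"input_ids": [101, 102]},                   # length 2
]
longest_len, longest_ix = get_longest_seq_length(samples)
print(longest_len, longest_ix)  # 5 1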

QEfficient/finetune/utils/device_map.py

Lines changed: 53 additions & 31 deletions
@@ -5,20 +5,40 @@
 #
 # -----------------------------------------------------------------------------

-import math
+import os

+import numpy as np
 from transformers import AutoConfig

 from QEfficient.utils._utils import get_num_layers_from_config


-def get_device_map(model_name, num_pp_stages, rank):
-    """Returns device map for model layers based number of pipeline stages and given process rank.
+def get_device_map(train_config):
+    """Returns device map for the given model.

     Args:
-        model_name (str): model name to get the device map for.
-        num_pp_stages (int): number of stages in pipeline
-        rank (int): process rank
+        train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc.
+
+    Returns:
+        Dict: A dictionary of layers and corresponding device id.
+    """
+
+    if train_config.num_pp_stages > 1:
+        if train_config.enable_ddp:
+            device_map = custom_device_map(train_config)
+        else:
+            device_map = "auto"
+    else:
+        device_map = None
+
+    return device_map
+
+
+def custom_device_map(train_config):
+    """Returns custom device map for model layers based number of pipeline stages and given process rank.
+
+    Args:
+        train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc.

     Returns:
         Dict: A dictionary of layers and corresponding device id.
@@ -33,30 +53,31 @@ def get_device_map(model_name, num_pp_stages, rank):
     PP (Pipeline Parallelism): Each copy of the model is split into 2 stages
     DDP (Distributed Data Parallel): 2 model copies run in parallel

-    |-------------------------------------------------------------------------------
-    | Process Rank | Assigned Device IDs | Model Component                          |
-    |-------------------------------------------------------------------------------
-    | Rank 0       | 0                   | model.embed_tokens                       |
-    |              |                     | model.lm_head                            |
-    |              |                     | model.layers.0 - model.layers.7          |
-    |-------------------------------------------------------------------------------
-    | Rank 0       | 1                   | model.norm                               |
-    |              |                     | model.rotary_emb                         |
-    |              |                     | model.layers.8 - model.layers.15         |
-    |-------------------------------------------------------------------------------
-    | Rank 1       | 2                   | model.embed_tokens                       |
-    |              |                     | model.lm_head                            |
-    |              |                     | model.layers.0 - model.layers.7          |
-    |-------------------------------------------------------------------------------
-    | Rank 1       | 3                   | model.norm                               |
-    |              |                     | model.rotary_emb                         |
-    |              |                     | model.layers.8 - model.layers.15         |
-    |-------------------------------------------------------------------------------
+    |--------------------------------------------------------------------------|
+    | Process Rank | Assigned Device IDs | Model Component                     |
+    |--------------------------------------------------------------------------|
+    | Rank 0       | 0                   | model.embed_tokens                  |
+    |              |                     | model.lm_head                       |
+    |              |                     | model.layers.0 - model.layers.7     |
+    |--------------------------------------------------------------------------|
+    | Rank 0       | 1                   | model.norm                          |
+    |              |                     | model.rotary_emb                    |
+    |              |                     | model.layers.8 - model.layers.15    |
+    |--------------------------------------------------------------------------|
+    | Rank 1       | 2                   | model.embed_tokens                  |
+    |              |                     | model.lm_head                       |
+    |              |                     | model.layers.0 - model.layers.7     |
+    |--------------------------------------------------------------------------|
+    | Rank 1       | 3                   | model.norm                          |
+    |              |                     | model.rotary_emb                    |
+    |              |                     | model.layers.8 - model.layers.15    |
+    |--------------------------------------------------------------------------|
     """

-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(train_config.model_name)
     num_layers = get_num_layers_from_config(config)
-
+    num_pp_stages = train_config.num_pp_stages
+    rank = int(os.getenv("LOCAL_RANK", 0))
     first_device = rank * num_pp_stages
     last_device = rank * num_pp_stages + (num_pp_stages - 1)

@@ -71,11 +92,12 @@ def get_device_map(model_name, num_pp_stages, rank):
         "model.norm": last_device,
         "model.rotary_emb": last_device,
     }
+    n_layer_per_stage = np.ceil(num_layers / num_pp_stages)

-    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
+    pp_stage_ids = np.arange(num_pp_stages)
+    pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)

-    for j in range(num_pp_stages):
-        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
-            device_map[f"model.layers.{i}"] = first_device + j
+    for i in range(num_layers):
+        device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * num_pp_stages

     return device_map
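
The rewritten loop assigns layers to pipeline stages with numpy instead of nested loops: np.repeat spreads each stage id over ceil(num_layers / num_pp_stages) layers, and the rank offset (rank * num_pp_stages) shifts the whole map onto that rank's devices. A standalone sketch of the same arithmetic for the 16-layer, 2-stage, 2-rank setup from the docstring (int() casts added here for clarity, not taken from the commit):

# Standalone sketch of the layer-to-device arithmetic for rank 1 of the
# 16-layer / 2-stage / 2-rank example in the docstring above.
import numpy as np

num_layers, num_pp_stages, rank = 16, 2, 1
n_layer_per_stage = int(np.ceil(num_layers / num_pp_stages))  # 8
pp_stage_ids = np.arange(num_pp_stages)                       # [0, 1]
pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)    # eight 0s, then eight 1s

device_map = {
    f"model.layers.{i}": int(pp_device_map[i]) + rank * num_pp_stages
    for i in range(num_layers)
}
# layers 0-7 land on device 2, layers 8-15 on device 3, matching the table above.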

QEfficient/finetune/utils/helper.py

Lines changed: 0 additions & 9 deletions
@@ -9,7 +9,6 @@
 import os
 from contextlib import nullcontext
 from enum import Enum
-from typing import Dict, List, Tuple

 import torch

@@ -82,14 +81,6 @@ def get_op_verifier_ctx(
     )


-def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
-    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
-    lengths = [len(d["input_ids"]) for d in data]
-    longest_seq_length = max(lengths)
-    longest_seq_ix = lengths.index(longest_seq_length)
-    return longest_seq_length, longest_seq_ix
-
-
 def save_to_json(
     output_filename,
     train_step_loss,

QEfficient/finetune/utils/parser.py

Lines changed: 0 additions & 6 deletions
@@ -262,12 +262,6 @@ def get_finetune_parser():
         action="store_true",
         help="Enable distributed data parallel training. This will load the replicas of model on given number of devices and train the model. This should be used using torchrun interface. Please check docs for exact usage.",
     )
-    parser.add_argument(
-        "--enable_pp",
-        "--enable-pp",
-        action="store_true",
-        help="Enable pipeline parallel training. This will split the of model layerwise in given number of stages and train the model.",
-    )
     parser.add_argument(
         "--num_pp_stages",
         "--num-pp-stages",
