
Commit d87a249: address comments

Signed-off-by: Mamta Singh <[email protected]>
Parent: d3e3029

4 files changed (+81, -66 lines)

QEfficient/cloud/finetune.py

Lines changed: 15 additions & 25 deletions
```diff
@@ -26,9 +26,9 @@
     generate_peft_config,
     update_config,
 )
-from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
 from QEfficient.finetune.utils.device_map import get_device_map
-from QEfficient.finetune.utils.helper import Task_Mode, get_longest_seq_length
+from QEfficient.finetune.utils.helper import Task_Mode
 from QEfficient.finetune.utils.logging_utils import logger
 from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
@@ -67,7 +67,7 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
     dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
     dist.init_process_group(backend=dist_backend_map[torch_device.type])
-    if train_config.enable_pp:
+    if train_config.num_pp_stages > 1:
         assert dist.get_world_size() * train_config.num_pp_stages == getattr(torch, torch_device.type).device_count(), (
             "Total available devices should be multiple of number of pipeline stages."
         )
@@ -97,7 +97,7 @@ def load_model_and_tokenizer(
     """Load the pre-trained model and tokenizer from Hugging Face.
 
     Args:
-        config (TrainConfig): Training configuration object containing model and tokenizer names.
+        train_config (TrainConfig): Training configuration object containing model and tokenizer names.
         dataset_config (Any): A dataclass object representing dataset configuration.
         kwargs: Additional arguments to override PEFT config.
 
@@ -135,26 +135,14 @@ def load_model_and_tokenizer(
             if param.requires_grad:
                 param.data = param.data.to(torch.float32)
     else:
-        if train_config.enable_pp:
-            if train_config.enable_ddp:
-                device_map = get_device_map(train_config.model_name, train_config.num_pp_stages, rank=dist.get_rank())
-            else:
-                device_map = "auto"
-            model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_path,
-                use_cache=False,
-                attn_implementation="sdpa",
-                torch_dtype=torch.float16,
-                device_map=device_map,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                pretrained_model_path,
-                use_cache=False,
-                attn_implementation="sdpa",
-                torch_dtype=torch.float16,
-            )
-
+        device_map = get_device_map(train_config)
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+            device_map=device_map,
+        )
     tokenizer = AutoTokenizer.from_pretrained(
         train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
     )
@@ -307,7 +295,7 @@ def main(**kwargs) -> None:
             f"passed context length is {train_config.context_length} and overall model's context length is "
             f"{model.config.max_position_embeddings}"
         )
-    if not train_config.enable_pp:
+    if train_config.num_pp_stages == 1:
         model.to(train_config.device)
     optimizer = optim.AdamW(
         model.parameters(),
@@ -320,6 +308,8 @@ def main(**kwargs) -> None:
         for name, param in model.named_parameters():
             if not param.requires_grad:
                 ignore_names.add(name)
+        # Adding params in ignore list will enforce DDP to ignore them during synchronization,
+        # which will further reduce the tensor exchange across devices.
        torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)
         model = nn.parallel.DistributedDataParallel(model)
```

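The interesting change in finetune.py is the ignore-list registered before wrapping the model in DDP: parameter names that never receive gradients are handed to DistributedDataParallel so they are left out of gradient synchronization. Below is a minimal, standalone sketch of that pattern; the toy nn.Sequential model and the freezing rule are illustrative stand-ins, not code from this commit.

```python
# Sketch only: demonstrates the DDP ignore-list pattern used above on a toy model.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# Freeze everything except the last layer, mimicking PEFT-style fine-tuning
# where most base-model weights stay frozen.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("1.")

# Collect the names of parameters that will never produce gradients.
ignore_names = {name for name, param in model.named_parameters() if not param.requires_grad}

# Registering these names makes DDP skip the tensors when building its
# gradient buckets, so frozen weights are never all-reduced across ranks.
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)

# In the training script the model is then wrapped (this needs an initialized
# process group, so it is left commented out here):
# model = nn.parallel.DistributedDataParallel(model)
```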
QEfficient/finetune/utils/dataset_utils.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -4,6 +4,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+from typing import Dict, List, Tuple
+
 import datasets
 import torch
 import torch.distributed as dist
@@ -116,3 +119,11 @@ def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"
         **dl_kwargs,
     )
     return dataloader
+
+
+def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
+    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
+    lengths = [len(d["input_ids"]) for d in data]
+    longest_seq_length = max(lengths)
+    longest_seq_ix = lengths.index(longest_seq_length)
+    return longest_seq_length, longest_seq_ix
```

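For reference, the relocated helper returns the longest tokenized sequence and its index, which the training flow uses to size max_seq_length no larger than the data requires. A quick self-contained sketch on made-up samples (the token ids below are illustrative only):

```python
# Sketch only: toy tokenized samples to show what get_longest_seq_length returns.
from typing import Dict, List, Tuple


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
    # Same logic as the helper moved into dataset_utils.py by this commit.
    lengths = [len(d["input_ids"]) for d in data]
    longest_seq_length = max(lengths)
    longest_seq_ix = lengths.index(longest_seq_length)
    return longest_seq_length, longest_seq_ix


samples = [
    {"input_ids": [101, 2054, 2003, 102]},              # 4 tokens
    {"input_ids": [101, 2023, 2003, 1037, 2742, 102]},  # 6 tokens
    {"input_ids": [101, 102]},                           # 2 tokens
]

longest_len, longest_ix = get_longest_seq_length(samples)
print(longest_len, longest_ix)  # 6 1 -> pad/allocate for 6 tokens rather than a larger fixed context
```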
QEfficient/finetune/utils/device_map.py

Lines changed: 55 additions & 32 deletions
```diff
@@ -5,20 +5,40 @@
 #
 # -----------------------------------------------------------------------------
 
-import math
+import os
 
+import numpy as np
 from transformers import AutoConfig
 
 from QEfficient.utils._utils import get_num_layers_from_config
 
 
-def get_device_map(model_name, num_pp_stages, rank):
-    """Returns device map for model layers based number of pipeline stages and given process rank.
+def get_device_map(train_config):
+    """Returns device map for the given model.
 
     Args:
-        model_name (str): model name to get the device map for.
-        num_pp_stages (int): number of stages in pipeline
-        rank (int): process rank
+        train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc.
+
+    Returns:
+        Dict: A dictionary of layers and corresponding device id.
+    """
+
+    if train_config.num_pp_stages > 1:
+        if train_config.enable_ddp:
+            device_map = custom_device_map(train_config)
+        else:
+            device_map = "auto"
+    else:
+        device_map = None
+
+    return device_map
+
+
+def custom_device_map(train_config):
+    """Returns custom device map for model layers based number of pipeline stages and given process rank.
+
+    Args:
+        train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc.
 
     Returns:
         Dict: A dictionary of layers and corresponding device id.
@@ -33,32 +53,34 @@ def get_device_map(model_name, num_pp_stages, rank):
     PP (Pipeline Parallelism): Each copy of the model is split into 2 stages
     DDP (Distributed Data Parallel): 2 model copies run in parallel
 
-    |-------------------------------------------------------------------------------
-    | Process Rank | Assigned Device IDs | Model Component                   |
-    |-------------------------------------------------------------------------------
-    | Rank 0       | 0                   | model.embed_tokens                |
-    |              |                     | model.lm_head                     |
-    |              |                     | model.layers.0 - model.layers.7   |
-    |-------------------------------------------------------------------------------
-    | Rank 0       | 1                   | model.norm                        |
-    |              |                     | model.rotary_emb                  |
-    |              |                     | model.layers.8 - model.layers.15  |
-    |-------------------------------------------------------------------------------
-    | Rank 1       | 2                   | model.embed_tokens                |
-    |              |                     | model.lm_head                     |
-    |              |                     | model.layers.0 - model.layers.7   |
-    |-------------------------------------------------------------------------------
-    | Rank 1       | 3                   | model.norm                        |
-    |              |                     | model.rotary_emb                  |
-    |              |                     | model.layers.8 - model.layers.15  |
-    |-------------------------------------------------------------------------------
+    |--------------------------------------------------------------------------|
+    | Process Rank | Assigned Device IDs | Model Component                     |
+    |--------------------------------------------------------------------------|
+    | Rank 0       | 0                   | model.embed_tokens                  |
+    |              |                     | model.lm_head                       |
+    |              |                     | model.layers.0 - model.layers.7     |
+    |--------------------------------------------------------------------------|
+    | Rank 0       | 1                   | model.norm                          |
+    |              |                     | model.rotary_emb                    |
+    |              |                     | model.layers.8 - model.layers.15    |
+    |--------------------------------------------------------------------------|
+    | Rank 1       | 2                   | model.embed_tokens                  |
+    |              |                     | model.lm_head                       |
+    |              |                     | model.layers.0 - model.layers.7     |
+    |--------------------------------------------------------------------------|
+    | Rank 1       | 3                   | model.norm                          |
+    |              |                     | model.rotary_emb                    |
+    |              |                     | model.layers.8 - model.layers.15    |
+    |--------------------------------------------------------------------------|
     """
 
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(train_config.model_name)
     num_layers = get_num_layers_from_config(config)
 
-    first_device = rank * num_pp_stages
-    last_device = rank * num_pp_stages + (num_pp_stages - 1)
+    n = train_config.num_pp_stages
+    rank = int(os.getenv("LOCAL_RANK", 0))
+    first_device = rank * n
+    last_device = rank * n + (n - 1)
 
     if config.tie_word_embeddings:
         lm_head_device = first_device
@@ -71,11 +93,12 @@ def get_device_map(model_name, num_pp_stages, rank):
         "model.norm": last_device,
         "model.rotary_emb": last_device,
     }
+    n_layer_per_stage = np.ceil(num_layers / n)
 
-    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
+    pp_stage_ids = np.arange(n)
+    pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)
 
-    for j in range(num_pp_stages):
-        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
-            device_map[f"model.layers.{i}"] = first_device + j
+    for i in range(num_layers):
+        device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * n
 
     return device_map
```

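To see what the vectorised mapping produces, here is a standalone sketch that mirrors the arithmetic of custom_device_map for the docstring scenario (16 decoder layers, 2 pipeline stages, DDP rank 1). The hard-coded values stand in for the TrainConfig and AutoConfig lookups done by the real helper, and only the per-layer assignment is reproduced.

```python
# Sketch only: reproduces the layer-to-device arithmetic from custom_device_map().
import numpy as np

num_layers = 16  # stand-in for get_num_layers_from_config(config)
n = 2            # stand-in for train_config.num_pp_stages
rank = 1         # stand-in for int(os.getenv("LOCAL_RANK", 0))

first_device = rank * n           # 2 -> embeddings/lm_head go here in the real helper
last_device = rank * n + (n - 1)  # 3 -> norm/rotary_emb go here in the real helper

n_layer_per_stage = int(np.ceil(num_layers / n))            # 8 (cast to int for np.repeat)
pp_stage_ids = np.arange(n)                                 # [0, 1]
pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)  # eight 0s followed by eight 1s

device_map = {}
for i in range(num_layers):
    device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * n

print(device_map["model.layers.0"], device_map["model.layers.7"])   # 2 2
print(device_map["model.layers.8"], device_map["model.layers.15"])  # 3 3
```

Layers 0-7 land on device 2 and layers 8-15 on device 3, matching the Rank 1 rows of the docstring table.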
QEfficient/finetune/utils/helper.py

Lines changed: 0 additions & 9 deletions
```diff
@@ -9,7 +9,6 @@
 import os
 from contextlib import nullcontext
 from enum import Enum
-from typing import Dict, List, Tuple
 
 import torch
 
@@ -82,14 +81,6 @@ def get_op_verifier_ctx(
     )
 
 
-def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
-    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
-    lengths = [len(d["input_ids"]) for d in data]
-    longest_seq_length = max(lengths)
-    longest_seq_ix = lengths.index(longest_seq_length)
-    return longest_seq_length, longest_seq_ix
-
-
 def save_to_json(
     output_filename,
     train_step_loss,
```
