Skip to content

Commit 5b7a315

Browse files
quic-mamtamamtsing
andauthored
[QEff Finetune]: Enable PP+DDP (#394)
Added support for PP+DDP Command for PP only : QAIC_VISIBLE_DEVICES=0,1,2,3 python -m QEfficient.cloud.finetune --device qaic --enable_pp --num_pp_stages 4 (number of pipeline stages must be less than or equal to total available devices) Command for DDP only : QAIC_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 -m QEfficient.cloud.finetune --device qaic --enable_ddp Command for PP+DDP : For 4 qaic devices(1 Ultra) with 2 pipeline stages QAIC_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node 2 -m QEfficient.cloud.finetune --device qaic --enable_ddp --enable_pp --num_pp_stages 2 --------- Signed-off-by: Mamta Singh <[email protected]> Co-authored-by: Mamta Singh <[email protected]>
1 parent 5fb7532 commit 5b7a315

File tree

7 files changed

+242
-79
lines changed

7 files changed

+242
-79
lines changed

QEfficient/cloud/finetune.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,12 @@
2626
generate_peft_config,
2727
update_config,
2828
)
29-
from QEfficient.finetune.utils.dataset_utils import get_dataloader
30-
from QEfficient.finetune.utils.helper import Task_Mode
29+
from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
30+
from QEfficient.finetune.utils.device_map import get_device_map
31+
from QEfficient.finetune.utils.helper import Task_Mode, get_world_size
3132
from QEfficient.finetune.utils.logging_utils import logger
3233
from QEfficient.finetune.utils.parser import get_finetune_parser
33-
from QEfficient.finetune.utils.train_utils import (
34-
get_longest_seq_length,
35-
print_model_size,
36-
print_trainable_parameters,
37-
train,
38-
)
34+
from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
3935
from QEfficient.utils._utils import hf_download
4036

4137
# Try importing QAIC-specific module, proceed without it if unavailable
@@ -63,17 +59,27 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
6359
Raises:
6460
AssertionError: If device is CPU or includes an index with DDP enabled.
6561
"""
62+
63+
torch_device = torch.device(train_config.device)
64+
num_available_devices = getattr(torch, torch_device.type).device_count()
65+
assert get_world_size() * train_config.num_pp_stages <= num_available_devices, (
66+
"Number of devices required should be less than or equal to total available devices."
67+
)
68+
if train_config.enable_pp:
69+
assert train_config.num_pp_stages > 1, (
70+
f"For pipeline parallelism, num_pp_stages should be greater than 1. Got {train_config.num_pp_stages}"
71+
)
72+
6673
if not train_config.enable_ddp:
6774
return
6875

69-
torch_device = torch.device(train_config.device)
7076
assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
7177
assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
72-
7378
dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
7479
dist.init_process_group(backend=dist_backend_map[torch_device.type])
75-
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
76-
getattr(torch, torch_device.type).set_device(dist.get_rank())
80+
if not train_config.enable_pp:
81+
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
82+
getattr(torch, torch_device.type).set_device(dist.get_rank())
7783

7884

7985
def setup_seeds(seed: int) -> None:
@@ -85,6 +91,10 @@ def setup_seeds(seed: int) -> None:
8591
Notes:
8692
- Sets seeds for PyTorch, Python's random module, and NumPy.
8793
"""
94+
torch.use_deterministic_algorithms(True)
95+
# With this flag, PP+DDP works only for meta-llama/Llama-3.2-1B and mistralai/Mistral-7B-Instruct-v0.3
96+
# and throws error during loading model for meta-llama/Llama-3.1-8B and bigger size models.
97+
8898
torch.manual_seed(seed)
8999
random.seed(seed)
90100
np.random.seed(seed)
@@ -96,7 +106,7 @@ def load_model_and_tokenizer(
96106
"""Load the pre-trained model and tokenizer from Hugging Face.
97107
98108
Args:
99-
config (TrainConfig): Training configuration object containing model and tokenizer names.
109+
train_config (TrainConfig): Training configuration object containing model and tokenizer names.
100110
dataset_config (Any): A dataclass object representing dataset configuration.
101111
kwargs: Additional arguments to override PEFT config.
102112
@@ -112,7 +122,10 @@ def load_model_and_tokenizer(
112122
- Sets pad_token_id to eos_token_id if not defined in the tokenizer.
113123
"""
114124
logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
115-
pretrained_model_path = hf_download(train_config.model_name)
125+
pretrained_model_path = hf_download(
126+
train_config.model_name,
127+
ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.msgpack", "*.h5", "*.pth"],
128+
)
116129
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
117130
model = AutoModelForSequenceClassification.from_pretrained(
118131
pretrained_model_path,
@@ -131,13 +144,14 @@ def load_model_and_tokenizer(
131144
if param.requires_grad:
132145
param.data = param.data.to(torch.float32)
133146
else:
147+
device_map = get_device_map(train_config)
134148
model = AutoModelForCausalLM.from_pretrained(
135149
pretrained_model_path,
136150
use_cache=False,
137151
attn_implementation="sdpa",
138152
torch_dtype=torch.float16,
153+
device_map=device_map,
139154
)
140-
141155
tokenizer = AutoTokenizer.from_pretrained(
142156
train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
143157
)
@@ -290,11 +304,24 @@ def main(**kwargs) -> None:
290304
f"passed context length is {train_config.context_length} and overall model's context length is "
291305
f"{model.config.max_position_embeddings}"
292306
)
293-
model.to(train_config.device)
294-
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
307+
if not train_config.enable_pp:
308+
model.to(train_config.device)
309+
optimizer = optim.AdamW(
310+
model.parameters(),
311+
lr=train_config.lr,
312+
weight_decay=train_config.weight_decay,
313+
)
295314
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
296315
if train_config.enable_ddp:
297-
model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
316+
ignore_names = set()
317+
for name, param in model.named_parameters():
318+
if not param.requires_grad:
319+
ignore_names.add(name)
320+
# Adding params in ignore list will enforce DDP to ignore them during synchronization,
321+
# which will further reduce the tensor exchange across devices.
322+
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)
323+
model = nn.parallel.DistributedDataParallel(model)
324+
298325
results = train(
299326
model,
300327
tokenizer,
@@ -303,7 +330,6 @@ def main(**kwargs) -> None:
303330
optimizer,
304331
scheduler,
305332
train_config,
306-
dist.get_rank() if train_config.enable_ddp else None,
307333
)
308334
if train_config.enable_ddp:
309335
dist.destroy_process_group()

QEfficient/finetune/configs/training.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ class TrainConfig:
4747
save_metrics (bool): Save training metrics (default: True).
4848
intermediate_step_save (int): Steps between intermediate saves (default: 1000).
4949
batching_strategy (str): Batching strategy (default: "packing").
50-
enable_sorting_for_ddp (bool): Sort data for DDP (default: True).
5150
convergence_counter (int): Steps to check convergence (default: 5).
5251
convergence_loss (float): Loss threshold for convergence (default: 1e-4).
5352
use_profiler (bool): Enable profiling (default: False).
53+
enable_pp (bool): Enable training with pipeline parallelism (default: False).
54+
num_pp_stages (int): Number of stages in which model is split layerwise when training using pipeline (default: 1).
5455
enable_ddp (bool): Enable distributed data parallel (default: False).
56+
enable_sorting_for_ddp (bool): Sort data for DDP (default: True).
5557
opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
5658
dump_logs (bool): Whether to dump logs (default: True).
5759
log_level (str): logging level (default: logging.INFO)
@@ -87,8 +89,6 @@ class TrainConfig:
8789
save_metrics: bool = True # saves training metrics to a json file for later plotting
8890
intermediate_step_save: int = 1000
8991
batching_strategy: str = Batching_Strategy.PADDING.value
90-
enable_ddp: bool = False
91-
enable_sorting_for_ddp: bool = True
9292
convergence_counter: int = 5 # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps
9393
convergence_loss: float = (
9494
1e-4 # if loss value is <= convergence_loss for #convergence_counter consecutive steps, fine tuning stops
@@ -100,6 +100,11 @@ class TrainConfig:
100100
use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
101101
# profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
102102

103+
# dist-related
104+
enable_pp: bool = False
105+
num_pp_stages: int = 1
106+
enable_ddp: bool = False
107+
enable_sorting_for_ddp: bool = True
103108
opByOpVerifier: bool = False
104109

105110
dump_logs: bool = True

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55
#
66
# -----------------------------------------------------------------------------
7+
78
import logging
9+
from typing import Dict, List, Tuple
810

911
import datasets
1012
import torch
@@ -13,7 +15,7 @@
1315

1416
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
1517
from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
16-
from QEfficient.finetune.utils.helper import get_num_ddp_devices
18+
from QEfficient.finetune.utils.helper import get_world_size
1719
from QEfficient.finetune.utils.logging_utils import logger
1820

1921

@@ -68,7 +70,7 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
6870

6971

7072
def padding_dataset(train_config, dataset, batch_size):
71-
num_replicas = get_num_ddp_devices()
73+
num_replicas = get_world_size()
7274
remainder = len(dataset) % (num_replicas * batch_size)
7375
if remainder == 0:
7476
return dataset
@@ -125,3 +127,11 @@ def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"
125127
**dl_kwargs,
126128
)
127129
return dataloader
130+
131+
132+
def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
133+
# find out the minimum max_seq_length required during fine-tuning (saves memory!)
134+
lengths = [len(d["input_ids"]) for d in data]
135+
longest_seq_length = max(lengths)
136+
longest_seq_ix = lengths.index(longest_seq_length)
137+
return longest_seq_length, longest_seq_ix
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
9+
import numpy as np
10+
import torch
11+
from transformers import AutoConfig
12+
13+
from QEfficient.finetune.utils.helper import get_rank
14+
from QEfficient.utils._utils import get_num_layers_from_config
15+
16+
17+
def get_device_map(train_config):
18+
"""Returns device map for the given model.
19+
20+
Args:
21+
train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc.
22+
23+
Returns:
24+
Dict: A dictionary of layers and corresponding device id.
25+
"""
26+
torch_device = torch.device(train_config.device)
27+
num_available_devices = getattr(torch, torch_device.type).device_count()
28+
if train_config.enable_pp:
29+
if train_config.enable_ddp:
30+
device_map = custom_device_map(train_config)
31+
elif train_config.num_pp_stages < num_available_devices:
32+
device_map = custom_device_map(train_config)
33+
elif train_config.num_pp_stages == num_available_devices:
34+
device_map = "auto"
35+
else:
36+
device_map = None
37+
38+
return device_map
39+
40+
41+
def custom_device_map(train_config):
42+
"""Returns custom device map for model layers based number of pipeline stages and given process rank.
43+
44+
Args:
45+
train_config (TrainConfig): Training configuration object contaning model name and number of pipeline stages etc.
46+
47+
Returns:
48+
Dict: A dictionary of layers and corresponding device id.
49+
50+
Notes:
51+
- This device map structure is verified for llama models only.
52+
53+
Example:
54+
Configuration for meta-llama/Llama-3.2-1B
55+
Total devices: 4 (2x PP x 2x DDP)
56+
57+
PP (Pipeline Parallelism): Each copy of the model is split into 2 stages
58+
DDP (Distributed Data Parallel): 2 model copies run in parallel
59+
60+
|--------------------------------------------------------------------------|
61+
| Process Rank | Assigned Device IDs | Model Component |
62+
|--------------------------------------------------------------------------|
63+
| Rank 0 | 0 | model.embed_tokens |
64+
| | | model.lm_head |
65+
| | | model.layers.0 - model.layers.7 |
66+
|--------------------------------------------------------------------------|
67+
| Rank 0 | 1 | model.norm |
68+
| | | model.rotary_emb |
69+
| | | model.layers.8 - model.layers.15 |
70+
|--------------------------------------------------------------------------|
71+
| Rank 1 | 2 | model.embed_tokens |
72+
| | | model.lm_head |
73+
| | | model.layers.0 - model.layers.7 |
74+
|--------------------------------------------------------------------------|
75+
| Rank 1 | 3 | model.norm |
76+
| | | model.rotary_emb |
77+
| | | model.layers.8 - model.layers.15 |
78+
|--------------------------------------------------------------------------|
79+
"""
80+
81+
model_config = AutoConfig.from_pretrained(train_config.model_name)
82+
num_layers = get_num_layers_from_config(model_config)
83+
num_pp_stages = train_config.num_pp_stages
84+
rank = get_rank()
85+
first_device = rank * num_pp_stages
86+
last_device = rank * num_pp_stages + (num_pp_stages - 1)
87+
88+
if model_config.tie_word_embeddings:
89+
lm_head_device = first_device
90+
else:
91+
lm_head_device = last_device
92+
93+
device_map = {
94+
"model.embed_tokens": first_device,
95+
"lm_head": lm_head_device,
96+
"model.norm": last_device,
97+
"model.rotary_emb": last_device,
98+
}
99+
n_layer_per_stage = np.ceil(num_layers / num_pp_stages)
100+
101+
pp_stage_ids = np.arange(num_pp_stages)
102+
pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)
103+
104+
for i in range(num_layers):
105+
device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * num_pp_stages
106+
107+
return device_map

QEfficient/finetune/utils/helper.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55
#
66
# -----------------------------------------------------------------------------
7+
8+
import json
79
import os
810
from contextlib import nullcontext
911
from enum import Enum
@@ -37,14 +39,18 @@ class Task_Mode(str, Enum):
3739

3840

3941
def enum_names(enum_cls):
40-
return [member.value for member in enum_cls]
42+
return [x.value for x in enum_cls]
43+
44+
45+
def get_rank():
46+
return int(os.getenv("LOCAL_RANK", 0))
4147

4248

4349
def is_rank_zero():
44-
return int(os.getenv("LOCAL_RANK", 0)) == 0
50+
return get_rank() == 0
4551

4652

47-
def get_num_ddp_devices():
53+
def get_world_size():
4854
return int(os.getenv("WORLD_SIZE", 1))
4955

5056

@@ -77,3 +83,28 @@ def get_op_verifier_ctx(
7783
filter_config=filter_config,
7884
dump_root_dir=dump_dir,
7985
)
86+
87+
88+
def save_to_json(
89+
output_filename,
90+
train_step_loss,
91+
train_epoch_loss,
92+
train_step_metric,
93+
train_epoch_metric,
94+
val_step_loss,
95+
val_epoch_loss,
96+
val_step_metric,
97+
val_epoch_metric,
98+
):
99+
metrics_data = {
100+
"train_step_loss": train_step_loss,
101+
"train_epoch_loss": train_epoch_loss,
102+
"train_step_metric": train_step_metric,
103+
"train_epoch_metric": train_epoch_metric,
104+
"val_step_loss": val_step_loss,
105+
"val_epoch_loss": val_epoch_loss,
106+
"val_step_metric": val_step_metric,
107+
"val_epoch_metric": val_epoch_metric,
108+
}
109+
with open(output_filename, "w") as f:
110+
json.dump(metrics_data, f)

0 commit comments

Comments
 (0)