26
26
generate_peft_config ,
27
27
update_config ,
28
28
)
29
- from QEfficient .finetune .utils .dataset_utils import get_dataloader
30
- from QEfficient .finetune .utils .helper import Task_Mode
29
+ from QEfficient .finetune .utils .dataset_utils import get_dataloader , get_longest_seq_length
30
+ from QEfficient .finetune .utils .device_map import get_device_map
31
+ from QEfficient .finetune .utils .helper import Task_Mode , get_world_size
31
32
from QEfficient .finetune .utils .logging_utils import logger
32
33
from QEfficient .finetune .utils .parser import get_finetune_parser
33
- from QEfficient .finetune .utils .train_utils import (
34
- get_longest_seq_length ,
35
- print_model_size ,
36
- print_trainable_parameters ,
37
- train ,
38
- )
34
+ from QEfficient .finetune .utils .train_utils import print_model_size , print_trainable_parameters , train
39
35
from QEfficient .utils ._utils import hf_download
40
36
41
37
# Try importing QAIC-specific module, proceed without it if unavailable
@@ -63,17 +59,27 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
63
59
Raises:
64
60
AssertionError: If device is CPU or includes an index with DDP enabled.
65
61
"""
62
+
63
+ torch_device = torch .device (train_config .device )
64
+ num_available_devices = getattr (torch , torch_device .type ).device_count ()
65
+ assert get_world_size () * train_config .num_pp_stages <= num_available_devices , (
66
+ "Number of devices required should be less than or equal to total available devices."
67
+ )
68
+ if train_config .enable_pp :
69
+ assert train_config .num_pp_stages > 1 , (
70
+ f"For pipeline parallelism, num_pp_stages should be greater than 1. Got { train_config .num_pp_stages } "
71
+ )
72
+
66
73
if not train_config .enable_ddp :
67
74
return
68
75
69
- torch_device = torch .device (train_config .device )
70
76
assert torch_device .type != "cpu" , "Host doesn't support single-node DDP"
71
77
assert torch_device .index is None , f"DDP requires only device type, got: { torch_device } "
72
-
73
78
dist_backend_map = {"cpu" : "gloo" , "qaic" : "qccl" , "cuda" : "gloo" }
74
79
dist .init_process_group (backend = dist_backend_map [torch_device .type ])
75
- # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
76
- getattr (torch , torch_device .type ).set_device (dist .get_rank ())
80
+ if not train_config .enable_pp :
81
+ # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
82
+ getattr (torch , torch_device .type ).set_device (dist .get_rank ())
77
83
78
84
79
85
def setup_seeds (seed : int ) -> None :
@@ -85,6 +91,10 @@ def setup_seeds(seed: int) -> None:
85
91
Notes:
86
92
- Sets seeds for PyTorch, Python's random module, and NumPy.
87
93
"""
94
+ torch .use_deterministic_algorithms (True )
95
+ # With this flag, PP+DDP works only for meta-llama/Llama-3.2-1B and mistralai/Mistral-7B-Instruct-v0.3
96
+ # and throws error during loading model for meta-llama/Llama-3.1-8B and bigger size models.
97
+
88
98
torch .manual_seed (seed )
89
99
random .seed (seed )
90
100
np .random .seed (seed )
@@ -96,7 +106,7 @@ def load_model_and_tokenizer(
96
106
"""Load the pre-trained model and tokenizer from Hugging Face.
97
107
98
108
Args:
99
- config (TrainConfig): Training configuration object containing model and tokenizer names.
109
+ train_config (TrainConfig): Training configuration object containing model and tokenizer names.
100
110
dataset_config (Any): A dataclass object representing dataset configuration.
101
111
kwargs: Additional arguments to override PEFT config.
102
112
@@ -112,7 +122,10 @@ def load_model_and_tokenizer(
112
122
- Sets pad_token_id to eos_token_id if not defined in the tokenizer.
113
123
"""
114
124
logger .log_rank_zero (f"Loading HuggingFace model for { train_config .model_name } " )
115
- pretrained_model_path = hf_download (train_config .model_name )
125
+ pretrained_model_path = hf_download (
126
+ train_config .model_name ,
127
+ ignore_patterns = ["*.txt" , "*.onnx" , "*.ot" , "*.md" , "*.tflite" , "*.pdf" , "*.msgpack" , "*.h5" , "*.pth" ],
128
+ )
116
129
if train_config .task_mode == Task_Mode .SEQ_CLASSIFICATION :
117
130
model = AutoModelForSequenceClassification .from_pretrained (
118
131
pretrained_model_path ,
@@ -131,13 +144,14 @@ def load_model_and_tokenizer(
131
144
if param .requires_grad :
132
145
param .data = param .data .to (torch .float32 )
133
146
else :
147
+ device_map = get_device_map (train_config )
134
148
model = AutoModelForCausalLM .from_pretrained (
135
149
pretrained_model_path ,
136
150
use_cache = False ,
137
151
attn_implementation = "sdpa" ,
138
152
torch_dtype = torch .float16 ,
153
+ device_map = device_map ,
139
154
)
140
-
141
155
tokenizer = AutoTokenizer .from_pretrained (
142
156
train_config .model_name if train_config .tokenizer_name is None else train_config .tokenizer_name
143
157
)
@@ -290,11 +304,24 @@ def main(**kwargs) -> None:
290
304
f"passed context length is { train_config .context_length } and overall model's context length is "
291
305
f"{ model .config .max_position_embeddings } "
292
306
)
293
- model .to (train_config .device )
294
- optimizer = optim .AdamW (model .parameters (), lr = train_config .lr , weight_decay = train_config .weight_decay )
307
+ if not train_config .enable_pp :
308
+ model .to (train_config .device )
309
+ optimizer = optim .AdamW (
310
+ model .parameters (),
311
+ lr = train_config .lr ,
312
+ weight_decay = train_config .weight_decay ,
313
+ )
295
314
scheduler = StepLR (optimizer , step_size = 1 , gamma = train_config .gamma )
296
315
if train_config .enable_ddp :
297
- model = nn .parallel .DistributedDataParallel (model , device_ids = [dist .get_rank ()])
316
+ ignore_names = set ()
317
+ for name , param in model .named_parameters ():
318
+ if not param .requires_grad :
319
+ ignore_names .add (name )
320
+ # Adding params in ignore list will enforce DDP to ignore them during synchronization,
321
+ # which will further reduce the tensor exchange across devices.
322
+ torch .nn .parallel .DistributedDataParallel ._set_params_and_buffers_to_ignore_for_model (model , ignore_names )
323
+ model = nn .parallel .DistributedDataParallel (model )
324
+
298
325
results = train (
299
326
model ,
300
327
tokenizer ,
@@ -303,7 +330,6 @@ def main(**kwargs) -> None:
303
330
optimizer ,
304
331
scheduler ,
305
332
train_config ,
306
- dist .get_rank () if train_config .enable_ddp else None ,
307
333
)
308
334
if train_config .enable_ddp :
309
335
dist .destroy_process_group ()
0 commit comments