#
# -----------------------------------------------------------------------------

+ import logging
import random
import warnings
- from typing import Any, Dict, Optional, Union
+ from typing import Any, Optional, Union

import numpy as np
import torch
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
- from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+ from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
    generate_dataset_config,
    generate_peft_config,
    update_config,
)
- from QEfficient.finetune.utils.dataset_utils import get_dataloader
+ from QEfficient.finetune.utils.dataset_utils import get_dataloader, get_longest_seq_length
+ from QEfficient.finetune.utils.device_map import get_device_map
+ from QEfficient.finetune.utils.helper import Task_Mode, get_world_size
+ from QEfficient.finetune.utils.logging_utils import logger
from QEfficient.finetune.utils.parser import get_finetune_parser
- from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
- from QEfficient.utils._utils import login_and_download_hf_lm
+ from QEfficient.finetune.utils.train_utils import print_model_size, print_trainable_parameters, train
+ from QEfficient.utils._utils import hf_download

# Try importing QAIC-specific module, proceed without it if unavailable
try:
    import torch_qaic  # noqa: F401
except ImportError as e:
-     print(f"Warning: {e}. Proceeding without QAIC modules.")
-
+     logger.log_rank_zero(
+         f"Unable to import 'torch_qaic' package due to exception: {e}. Moving ahead without the torch_qaic extension.",
+         logging.WARNING,
+     )

- from transformers import AutoModelForSequenceClassification

# Suppress all warnings
warnings.filterwarnings("ignore")
@@ -57,17 +62,27 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
    Raises:
        AssertionError: If device is CPU or includes an index with DDP enabled.
    """
+
+     torch_device = torch.device(train_config.device)
+     num_available_devices = getattr(torch, torch_device.type).device_count()
+     assert get_world_size() * train_config.num_pp_stages <= num_available_devices, (
+         "Number of devices required should be less than or equal to total available devices."
+     )
+     if train_config.enable_pp:
+         assert train_config.num_pp_stages > 1, (
+             f"For pipeline parallelism, num_pp_stages should be greater than 1. Got {train_config.num_pp_stages}"
+         )
+
    if not train_config.enable_ddp:
        return

-     torch_device = torch.device(train_config.device)
    assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
    assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
-
    dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
    dist.init_process_group(backend=dist_backend_map[torch_device.type])
-     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-     getattr(torch, torch_device.type).set_device(dist.get_rank())
+     if not train_config.enable_pp:
+         # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+         getattr(torch, torch_device.type).set_device(dist.get_rank())


def setup_seeds(seed: int) -> None:
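As a quick sanity check of the device-budget assertion added above (the numbers below are hypothetical, not part of this commit): two DDP ranks combined with two pipeline stages require at least four visible QAIC/CUDA devices.

    # hypothetical illustration of: get_world_size() * num_pp_stages <= num_available_devices
    world_size, num_pp_stages, num_available_devices = 2, 2, 4
    assert world_size * num_pp_stages <= num_available_devices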
@@ -79,20 +94,23 @@ def setup_seeds(seed: int) -> None:
    Notes:
        - Sets seeds for PyTorch, Python's random module, and NumPy.
    """
+     torch.use_deterministic_algorithms(True)
+     # With this flag, PP+DDP works only for meta-llama/Llama-3.2-1B and mistralai/Mistral-7B-Instruct-v0.3,
+     # and it throws an error while loading the model for meta-llama/Llama-3.1-8B and larger models.
+
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def load_model_and_tokenizer(
-     train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+     train_config: TrainConfig, dataset_config: Any, **kwargs
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the pre-trained model and tokenizer from Hugging Face.

    Args:
-         config (TrainConfig): Training configuration object containing model and tokenizer names.
+         train_config (TrainConfig): Training configuration object containing model and tokenizer names.
        dataset_config (Any): A dataclass object representing dataset configuration.
-         peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
        kwargs: Additional arguments to override PEFT config.

    Returns:
@@ -106,8 +124,12 @@ def load_model_and_tokenizer(
        - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
        - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
    """
-     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
-     if train_config.task_type == "seq_classification":
+     logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
+     pretrained_model_path = hf_download(
+         train_config.model_name,
+         ignore_patterns=["*.txt", "*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.msgpack", "*.h5", "*.pth"],
+     )
+     if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_path,
            num_labels=dataset_config.num_labels,
@@ -116,7 +138,7 @@ def load_model_and_tokenizer(
        )

        if not hasattr(model, "base_model_prefix"):
-             raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+             logger.raise_error("Given huggingface model does not have 'base_model_prefix' attribute.", RuntimeError)

        for param in getattr(model, model.base_model_prefix).parameters():
            param.requires_grad = False
@@ -125,13 +147,14 @@ def load_model_and_tokenizer(
            if param.requires_grad:
                param.data = param.data.to(torch.float32)
    else:
+         device_map = get_device_map(train_config)
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_path,
            use_cache=False,
            attn_implementation="sdpa",
            torch_dtype=torch.float16,
+             device_map=device_map,
        )
-
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
    )
@@ -141,11 +164,10 @@ def load_model_and_tokenizer(
    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-         print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
+         logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)
        model.resize_token_embeddings(len(tokenizer))

-     # FIXME (Meet): Cover below line inside the logger once it is implemented.
-     print_model_size(model, train_config)
+     print_model_size(model)

    # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
    # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -155,25 +177,23 @@ def load_model_and_tokenizer(
    if train_config.gradient_checkpointing:
        # Note: below attribute and method is only available in HuggingFace Transformer models.
        if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
-             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
+             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": True})
        else:
-             raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+             logger.raise_error(
+                 "Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
+             )

-     model = apply_peft(model, train_config, peft_config_file, **kwargs)
+     model = apply_peft(model, train_config, **kwargs)

    return model, tokenizer


- def apply_peft(
-     model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
- ) -> Union[AutoModel, PeftModel]:
+ def apply_peft(model: AutoModel, train_config: TrainConfig, **kwargs) -> Union[AutoModel, PeftModel]:
    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.

    Args:
        model (AutoModel): Huggingface model.
        train_config (TrainConfig): Training configuration object.
-         peft_config_file (str, optional): Path to YAML/JSON file containing
-             PEFT (LoRA) config. Defaults to None.
        kwargs: Additional arguments to override PEFT config params.

    Returns:
@@ -190,9 +210,9 @@ def apply_peft(
        peft_config = model.peft_config
    # Generate the peft config and start fine-tuning from original model
    else:
-         peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
+         peft_config = generate_peft_config(train_config, **kwargs)
        model = get_peft_model(model, peft_config)
-         model.print_trainable_parameters()
+         print_trainable_parameters(model)

    return model

@@ -217,25 +237,26 @@ def setup_dataloaders(
            - Length of longest sequence in the dataset.

    Raises:
-         ValueError: If validation is enabled but the validation set is too small.
+         RuntimeError: If validation is enabled but the validation set is too small.

    Notes:
        - Applies a custom data collator if provided by get_custom_data_collator.
        - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
    """

    train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
-     print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+     logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
        if len(eval_dataloader) == 0:
-             raise ValueError(
-                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+             logger.raise_error(
+                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
+                 ValueError,
            )
        else:
-             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+             logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

    longest_seq_length, _ = get_longest_seq_length(
        torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -246,12 +267,11 @@ def setup_dataloaders(
    return train_dataloader, eval_dataloader, longest_seq_length


- def main(peft_config_file: str = None, **kwargs) -> None:
+ def main(**kwargs) -> None:
    """
    Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.

    Args:
-         peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
        kwargs: Additional arguments to override TrainConfig.

    Example:
@@ -274,23 +294,37 @@ def main(peft_config_file: str = None, **kwargs) -> None:
    dataset_config = generate_dataset_config(train_config.dataset)
    update_config(dataset_config, **kwargs)

+     logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)
+
    setup_distributed_training(train_config)
    setup_seeds(train_config.seed)
-     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
+     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, **kwargs)

    # Create DataLoaders for the training and validation dataset
    train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-     print(
+     logger.log_rank_zero(
        f"The longest sequence length in the train data is {longest_seq_length}, "
        f"passed context length is {train_config.context_length} and overall model's context length is "
        f"{model.config.max_position_embeddings}"
    )
-
-     model.to(train_config.device)
-     optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
+     if not train_config.enable_pp:
+         model.to(train_config.device)
+     optimizer = optim.AdamW(
+         model.parameters(),
+         lr=train_config.lr,
+         weight_decay=train_config.weight_decay,
+     )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
    if train_config.enable_ddp:
-         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
+         ignore_names = set()
+         for name, param in model.named_parameters():
+             if not param.requires_grad:
+                 ignore_names.add(name)
+         # Adding params to the ignore list forces DDP to skip them during synchronization,
+         # which further reduces tensor exchange across devices.
+         torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, ignore_names)
+         model = nn.parallel.DistributedDataParallel(model)
+
    results = train(
        model,
        tokenizer,
@@ -299,7 +333,6 @@ def main(peft_config_file: str = None, **kwargs) -> None:
        optimizer,
        scheduler,
        train_config,
-         dist.get_rank() if train_config.enable_ddp else None,
    )
    if train_config.enable_ddp:
        dist.destroy_process_group()
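A minimal usage sketch of the refactored entry point; the import path and the LoRA override names below are assumptions, not taken from this commit. With peft_config_file removed from main()'s signature, PEFT/LoRA overrides now flow in through **kwargs alongside the TrainConfig overrides.

    # hedged sketch; module path and override names (r, lora_alpha) are assumptions
    from QEfficient.cloud.finetune import main
    main(model_name="meta-llama/Llama-3.2-1B", enable_ddp=False, r=8, lora_alpha=16)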