 #
 # -----------------------------------------------------------------------------

+import logging
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
     update_config,
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.logging_utils import logger
 from QEfficient.finetune.utils.parser import get_finetune_parser
-from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.finetune.utils.train_utils import (
+    get_longest_seq_length,
+    print_model_size,
+    print_trainable_parameters,
+    train,
+)
+from QEfficient.utils._utils import hf_download

 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    print(f"Warning: {e}. Proceeding without QAIC modules.")
-
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.", logging.WARNING)

-from transformers import AutoModelForSequenceClassification

 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -106,7 +111,8 @@ def load_model_and_tokenizer(
         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
         - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
     """
-    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+    logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
+    pretrained_model_path = hf_download(train_config.model_name)
     if train_config.task_type == "seq_classification":
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_path,
@@ -116,7 +122,7 @@ def load_model_and_tokenizer(
         )

         if not hasattr(model, "base_model_prefix"):
-            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+            logger.raise_error("Given huggingface model does not have 'base_model_prefix' attribute.", RuntimeError)

         for param in getattr(model, model.base_model_prefix).parameters():
             param.requires_grad = False
@@ -141,11 +147,10 @@ def load_model_and_tokenizer(
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)
         model.resize_token_embeddings(len(tokenizer))

-    # FIXME (Meet): Cover below line inside the logger once it is implemented.
-    print_model_size(model, train_config)
+    print_model_size(model)

     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -157,7 +162,9 @@ def load_model_and_tokenizer(
     if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
     else:
-        raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+        logger.raise_error(
+            "Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
+        )

     model = apply_peft(model, train_config, peft_config_file, **kwargs)

@@ -192,7 +199,7 @@ def apply_peft(
     else:
         peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
         model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        print_trainable_parameters(model)

     return model

@@ -217,25 +224,26 @@ def setup_dataloaders(
         - Length of longest sequence in the dataset.

     Raises:
-        ValueError: If validation is enabled but the validation set is too small.
+        RuntimeError: If validation is enabled but the validation set is too small.

     Notes:
         - Applies a custom data collator if provided by get_custom_data_collator.
         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """

     train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
-    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

     eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
-            raise ValueError(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            logger.raise_error(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
+                ValueError,
             )
         else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

     longest_seq_length, _ = get_longest_seq_length(
         torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -274,13 +282,15 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     dataset_config = generate_dataset_config(train_config.dataset)
     update_config(dataset_config, **kwargs)

+    logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)
+
     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)
     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)

     # Create DataLoaders for the training and validation dataset
     train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-    print(
+    logger.log_rank_zero(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
         f"{model.config.max_position_embeddings} "
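
The added lines route all console output through a shared `logger` object imported from `QEfficient.finetune.utils.logging_utils`, using `log_rank_zero`, `raise_error`, and `prepare_for_logs`. The sketch below is only a minimal illustration of the kind of rank-zero-aware wrapper those calls assume; it is not the actual QEfficient implementation, and the `RankZeroLogger` class name and the `LOCAL_RANK` check are assumptions made for clarity.

# Illustrative sketch only -- not the real logging_utils module.
import logging
import os


class RankZeroLogger:
    def __init__(self, name: str = "qefficient-finetune") -> None:
        self._logger = logging.getLogger(name)

    def _is_rank_zero(self) -> bool:
        # torchrun sets LOCAL_RANK; a single-process run defaults to rank zero.
        return int(os.environ.get("LOCAL_RANK", "0")) == 0

    def log_rank_zero(self, msg: str, level: int = logging.INFO) -> None:
        # Emit the message from the rank-zero process only, so multi-device
        # finetuning does not print duplicate lines.
        if self._is_rank_zero():
            self._logger.log(level, msg)

    def raise_error(self, msg: str, exc_type: type = RuntimeError) -> None:
        # Log the failure once, then raise it on every rank so training stops.
        self.log_rank_zero(msg, logging.ERROR)
        raise exc_type(msg)


logger = RankZeroLogger()
logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)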