 from tqdm import tqdm

 from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.utils.logging_utils import logger

 try:
     import torch_qaic  # noqa: F401
     import torch_qaic.utils as qaic_utils  # noqa: F401
     from torch.qaic.amp import GradScaler as QAicGradScaler
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    logger.warning(f"{e}. Moving ahead without these qaic modules.")

 from torch.amp import GradScaler
@@ -116,26 +117,26 @@ def train(
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
             if train_config.enable_ddp:
-                print(
+                logger.info(
                     f"Not proceeding with epoch {epoch + 1} on device {local_rank} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break
             else:
-                print(
+                logger.info(
                     f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break

         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
             if epoch < intermediate_epoch:
-                print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+                logger.info(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
                 # to bring the count of train_step in sync with where it left off
                 total_train_steps += len(train_dataloader)
                 continue

-        print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        print(f"train_config.max_train_step: {train_config.max_train_step}")
+        logger.info(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
+        logger.info(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
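For context on the early-stop messages above: loss_0_counter counts consecutive steps whose loss stayed at or below train_config.convergence_loss, and training halts once it reaches train_config.convergence_counter. A minimal sketch of that bookkeeping, with the threshold values and training_step as illustrative stand-ins:

    import torch

    def training_step(step):  # dummy stand-in for the real forward/backward pass
        return max(0.0, 1.0 - 0.01 * step)

    convergence_loss = 1e-4   # illustrative value for train_config.convergence_loss
    convergence_counter = 5   # illustrative value for train_config.convergence_counter
    loss_0_counter = torch.tensor([0])  # a tensor so DDP ranks can all-reduce it

    for step in range(1000):
        loss = training_step(step)
        if loss <= convergence_loss:
            loss_0_counter += 1                  # extend the streak of converged steps
        else:
            loss_0_counter = torch.tensor([0])   # streak broken, reset
        if loss_0_counter.item() == convergence_counter:
            break  # the epoch loop re-checks this counter and stops, as logged above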
@@ -162,7 +163,7 @@ def train(
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
                     total_train_steps += intermediate_step
-                    print(
+                    logger.info(
                         f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
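The intermediate_epoch and intermediate_step values above are recovered from the checkpoint path itself. Judging by the split("/") indexing in the diff, the layout is something like .../epoch_<n>/step_<m>; a small worked example (the concrete path is invented for illustration):

    # Invented example path; the parsing only needs .../epoch_<n>/step_<m>.
    from_peft_checkpoint = "results/peft/epoch_2/step_500"

    # "epoch_2" -> 2, minus 1 for zero-based indexing: epochs 0 and 1 are done.
    intermediate_epoch = int(from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
    # "step_500" -> 500: the first 500 steps of the resumed epoch are skipped.
    intermediate_step = int(from_peft_checkpoint.split("/")[-1].split("_")[-1])

    print(intermediate_epoch, intermediate_step)  # 1 500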
@@ -197,7 +198,7 @@ def train(
                             labels = batch["labels"][:, 0]
                             preds = torch.nn.functional.softmax(logits, dim=-1)
                             acc_helper.forward(preds, labels)
-                    print("Mismatches detected:", verifier.get_perop_mismatch_count())
+                    logger.info(f"Mismatches detected: {verifier.get_perop_mismatch_count()}")
                 else:
                     model_outputs = model(**batch)
                     loss = model_outputs.loss  # Forward call
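Note the fix on the mismatch line: logger.info does not take print-style positional values, so the count is interpolated into the message instead. Separately, in the seq_classification branch above, acc_helper accumulates accuracy from the softmax of the logits against the first column of the labels tensor. A self-contained sketch of the same pattern; using torchmetrics for acc_helper is an assumption, since its construction is outside this diff:

    import torch
    import torchmetrics

    num_classes = 2  # illustrative; the real value comes from the task config
    acc_helper = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    logits = torch.randn(8, num_classes)           # stands in for model_outputs.logits
    labels = torch.randint(0, num_classes, (8,))   # stands in for batch["labels"][:, 0]
    preds = torch.nn.functional.softmax(logits, dim=-1)
    acc_helper.forward(preds, labels)  # updates running state, returns batch accuracy

    print(acc_helper.compute())  # aggregate accuracy over every batch seen so far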
@@ -279,13 +280,13 @@ def train(
                 )
             if train_config.enable_ddp:
                 if loss_0_counter.item() == train_config.convergence_counter:
-                    print(
+                    logger.info(
                         f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning on device {local_rank}."
                     )
                     break
             else:
                 if loss_0_counter.item() == train_config.convergence_counter:
-                    print(
+                    logger.info(
                         f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
                     )
                     break
@@ -347,15 +348,15 @@ def train(
         if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+                logger.info(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
-            print(
+            logger.info(
                 f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
-            print(
+            logger.info(
                 f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )

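The best-loss comparison above implies best_val_loss starts at infinity and the history lists collect plain floats. A compact sketch of that bookkeeping, with the initial value and the toy per-epoch results as assumptions:

    best_val_loss = float("inf")  # assumed init; any first epoch improves on it
    val_loss, val_metric = [], []

    for eval_epoch_loss, eval_metric in [(2.3, 10.0), (1.9, 6.7), (2.0, 7.4)]:
        if eval_epoch_loss < best_val_loss:
            best_val_loss = eval_epoch_loss  # epochs 1 and 2 improve; epoch 3 does not
        # float(...) converts tensor values so no autograd graph or device
        # memory is retained across epochs
        val_loss.append(float(eval_epoch_loss))
        val_metric.append(float(eval_metric))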
@@ -459,7 +460,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     eval_metric = torch.exp(eval_epoch_loss)

     # Print evaluation metrics
-    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    logger.info(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

     return eval_metric, eval_epoch_loss, val_step_loss, val_step_metric
@@ -489,9 +490,9 @@ def print_model_size(model, config) -> None:
         model_name (str): Name of the model.
     """

-    print(f"--> Model {config.model_name}")
+    logger.info(f"--> Model {config.model_name}")
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+    logger.info(f"--> {config.model_name} has {total_params / 1e6} Million params")


 def save_to_json(
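print_model_size above counts only tensors with requires_grad=True, so under PEFT it reports the trainable adapter parameters rather than the full model. (The padding newlines were dropped from the logger call, since the logging handler formats each record on its own line.) A quick illustration with a frozen toy layer:

    import torch.nn as nn

    model = nn.Linear(4096, 4096)        # toy stand-in for the fine-tuned model
    model.weight.requires_grad = False   # emulate a frozen base weight under PEFT

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_params / 1e6} Million params")  # 0.004096: only the bias counts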