 #
 # -----------------------------------------------------------------------------

-import json
 import os
 import time
 from datetime import datetime
 from functools import partial
-from typing import Dict, List, Tuple

 import torch
 import torch.distributed as dist
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero, save_to_json
 from QEfficient.finetune.utils.logging_utils import logger

 try:
@@ -42,24 +40,24 @@ def train(
     optimizer,
     lr_scheduler,
     train_config: TrainConfig,
-    local_rank=None,
 ):
     """
     Trains the model on the given dataloader

     Args:
         model: The model to be trained
-        tokenizer: tokenizer used in the eval for decoding the predicitons
+        tokenizer: tokenizer used in the eval for decoding the predictions
         train_dataloader: The dataloader containing the training data
         eval_dataloader: The dataloader containing the eval data
         optimizer: The optimizer used for training
         lr_scheduler: The learning rate scheduler
         train_config: The training configuration
-        local_rank: The rank of the current node in a distributed setting

     Returns: results dictionary containing average training and validation perplexity and loss
     """
     device = train_config.device
+    device_type = torch.device(device).type
+    local_rank = int(os.getenv("LOCAL_RANK", 0))

     train_metric = []
     train_loss = []
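Note on the local_rank change above: torchrun and the torch.distributed launchers export LOCAL_RANK into each worker's environment, so the function can read it directly instead of threading it through as an argument. A minimal standalone sketch of the same pattern (the CUDA device string is an illustrative assumption; this repo targets qaic devices):

import os

import torch

# torchrun sets LOCAL_RANK for every spawned process; a plain
# single-process run falls back to rank 0.
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")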
@@ -89,8 +87,6 @@ def train(
         tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
         tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)

-    device_type = torch.device(device).type
-
     if train_config.grad_scaler:
         if device.startswith("qaic"):
             scaler = QAicGradScaler()
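QAicGradScaler follows the stock torch AMP grad-scaler pattern; for comparison, a sketch of the standard CUDA equivalent (plain PyTorch API, not this repo's code):

import torch

# stock AMP loss scaler; QAicGradScaler is the qaic counterpart
scaler = torch.cuda.amp.GradScaler()

# typical scaled optimizer step:
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()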
@@ -130,10 +126,11 @@ def train(
                 continue

         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
-        # stop when the maximum number of training steps is reached
+        if train_config.max_train_step > 0:
+            logger.log_rank_zero(f"Max train steps: {train_config.max_train_step}")
         if max_steps_reached:
             break
+
         epoch_start_time = time.perf_counter()
         model.train()

@@ -165,7 +162,6 @@ def train(
                     continue
             total_train_steps += 1

-            # stop when the maximum number of training steps is reached
             if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                 max_steps_reached = True
                 break
@@ -223,7 +219,7 @@ def train(
                     step_metric_val = float(torch.exp(loss.detach().float()))
                 train_step_metric.append(step_metric_val)

-            # Accumalate gradients
+            # Accumulate gradients
             complete_accum_steps = (
                 len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps
             )
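For reference, the complete_accum_steps expression above rounds the dataloader length down to the nearest multiple of the accumulation factor, i.e. the number of batches that fill whole accumulation groups; a quick worked check with assumed numbers:

# 103 batches with gradient_accumulation_steps = 4: the last 3 batches
# do not fill a complete group and are handled separately.
num_batches = 103
gradient_accumulation_steps = 4
complete_accum_steps = num_batches - num_batches % gradient_accumulation_steps
assert complete_accum_steps == 100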
@@ -291,25 +287,6 @@ def train(
                     if total_loss == 0.0
                     else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
                 )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
-                )
-        else:
-            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
-                )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
-                )
         if train_config.task_type == "seq_classification":
             metric_val = acc_helper.compute()
             acc_helper.reset()
@@ -322,17 +299,6 @@ def train(
         # Update the learning rate as needed
         lr_scheduler.step()

-        if train_config.run_validation:
-            eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                model, train_config, eval_dataloader, device
-            )
-            if is_rank_zero():
-                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-            if train_config.save_metrics:
-                val_step_loss.extend(temp_val_loss)
-                val_step_metric.extend(temp_step_metric)
-
         # saving the adapters after completion of each epoch
         if train_config.save_model:
             if train_config.enable_ddp:
@@ -342,19 +308,24 @@ def train(
                 model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

         if train_config.run_validation:
+            eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation(
+                model, train_config, eval_dataloader, device
+            )
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+
+            if is_rank_zero():
+                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_metric.extend(temp_step_metric)
             val_loss.append(float(eval_epoch_loss))
             val_metric.append(float(eval_metric))
-        if train_config.task_type == "seq_classification":
-            logger.log_rank_zero(
-                f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-            )
-        else:
-            logger.log_rank_zero(
-                f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-            )
+        logger.log_rank_zero(
+            f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+        )

         # Saving the results every epoch to plot later
         if train_config.save_metrics:
@@ -389,7 +360,7 @@ def train(
     return results


-def evaluation_helper(model, train_config, eval_dataloader, device):
+def evaluation(model, train_config, eval_dataloader, device):
     """
     Evaluates the model on the given dataloader

@@ -474,60 +445,3 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

     return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
-
-
-def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
-    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
-    lengths = [len(d["input_ids"]) for d in data]
-    longest_seq_length = max(lengths)
-    longest_seq_ix = lengths.index(longest_seq_length)
-    return longest_seq_length, longest_seq_ix
-
-
-def print_model_size(model) -> None:
-    """
-    Print model name, the number of trainable parameters and initialization time.
-
-    Args:
-        model: PyTorch model.
-    """
-    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
-
-
-def print_trainable_parameters(model) -> None:
-    """
-    Print the number of trainable parameters, all params and percentage of trainablke params.
-
-    Args:
-        model: The PyTorch model.
-    """
-    trainable_params, all_param = model.get_nb_trainable_parameters()
-    logger.log_rank_zero(
-        f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
-    )
-
-
-def save_to_json(
-    output_filename,
-    train_step_loss,
-    train_epoch_loss,
-    train_step_metric,
-    train_epoch_metric,
-    val_step_loss,
-    val_epoch_loss,
-    val_step_metric,
-    val_epoch_metric,
-):
-    metrics_data = {
-        "train_step_loss": train_step_loss,
-        "train_epoch_loss": train_epoch_loss,
-        "train_step_metric": train_step_metric,
-        "train_epoch_metric": train_epoch_metric,
-        "val_step_loss": val_step_loss,
-        "val_epoch_loss": val_epoch_loss,
-        "val_step_metric": val_step_metric,
-        "val_epoch_metric": val_epoch_metric,
-    }
-    with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
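The helpers deleted above are dropped from this module, and the import hunk at the top now pulls save_to_json from QEfficient.finetune.utils.helper instead. Assuming the relocated helper keeps the signature of the deleted copy, a call site inside train() would look roughly like this (the file name and the mapping of the epoch-level lists are assumptions):

save_to_json(
    output_filename="metrics.json",  # hypothetical output path
    train_step_loss=train_step_loss,
    train_epoch_loss=train_loss,
    train_step_metric=train_step_metric,
    train_epoch_metric=train_metric,
    val_step_loss=val_step_loss,
    val_epoch_loss=val_loss,
    val_step_metric=val_step_metric,
    val_epoch_metric=val_metric,
)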