#
# -----------------------------------------------------------------------------

- import json
import os
import time
from datetime import datetime
from functools import partial
- from typing import Dict, List, Tuple

import torch
import torch.distributed as dist
from tqdm import tqdm

from QEfficient.finetune.configs.training import TrainConfig
- from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
+ from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero, save_to_json
from QEfficient.finetune.utils.logging_utils import logger

try:
@@ -42,24 +40,24 @@ def train(
    optimizer,
    lr_scheduler,
    train_config: TrainConfig,
-     local_rank=None,
):
    """
    Trains the model on the given dataloader

    Args:
        model: The model to be trained
-         tokenizer: tokenizer used in the eval for decoding the predicitons
+         tokenizer: tokenizer used in the eval for decoding the predictions
        train_dataloader: The dataloader containing the training data
        eval_dataloader: The dataloader containing the eval data
        optimizer: The optimizer used for training
        lr_scheduler: The learning rate scheduler
        train_config: The training configuration
-         local_rank: The rank of the current node in a distributed setting

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    device = train_config.device
+     device_type = torch.device(device).type
+     local_rank = int(os.getenv("LOCAL_RANK", 0))

    train_metric = []
    train_loss = []
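Note: with this hunk, local_rank is no longer an argument to train(); it is read inside the function from the LOCAL_RANK environment variable, which per-process launchers such as torchrun export (single-process runs fall back to 0). A minimal caller-side sketch under that assumption (illustrative only, not part of the diff; the argument order follows the docstring above):

    # Launch under torchrun (or another launcher that sets LOCAL_RANK) so each
    # worker picks up its own rank; train() no longer needs it passed in.
    results = train(
        model,
        tokenizer,
        train_dataloader,
        eval_dataloader,
        optimizer,
        lr_scheduler,
        train_config,
    )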
@@ -89,8 +87,6 @@ def train(
    tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)

-     device_type = torch.device(device).type
-
    if train_config.grad_scaler:
        if device.startswith("qaic"):
            scaler = QAicGradScaler()
@@ -130,10 +126,11 @@ def train(
            continue

        logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-         logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
-         # stop when the maximum number of training steps is reached
+         if train_config.max_train_step > 0:
+             logger.log_rank_zero(f"Max train steps: {train_config.max_train_step}")
        if max_steps_reached:
            break
+
        epoch_start_time = time.perf_counter()
        model.train()

@@ -165,7 +162,6 @@ def train(
                continue
            total_train_steps += 1

-             # stop when the maximum number of training steps is reached
            if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                max_steps_reached = True
                break
@@ -223,7 +219,7 @@ def train(
                step_metric_val = float(torch.exp(loss.detach().float()))
            train_step_metric.append(step_metric_val)

-             # Accumalate gradients
+             # Accumulate gradients
            complete_accum_steps = (
                len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps
            )
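Note: complete_accum_steps is the largest multiple of gradient_accumulation_steps that fits into the dataloader length, i.e. the steps that form full accumulation groups; the leftover steps at the end of the epoch presumably form one shorter final update. A worked example with assumed numbers (not taken from the diff):

    # Assume 103 batches per epoch and gradient_accumulation_steps = 4.
    num_batches = 103
    gradient_accumulation_steps = 4
    complete_accum_steps = num_batches - num_batches % gradient_accumulation_steps
    # complete_accum_steps == 100: steps 0-99 form 25 full groups of 4 batches,
    # leaving the last 3 batches (steps 100-102) for a smaller final group.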
@@ -297,19 +293,6 @@ def train(
                if total_loss == 0.0
                else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
            )
-         else:
-             if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                 train_epoch_loss = (
-                     0.0
-                     if total_loss == 0.0
-                     else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
-                 )
-             else:
-                 train_epoch_loss = (
-                     0.0
-                     if total_loss == 0.0
-                     else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
-                 )
        if train_config.task_type == "seq_classification":
            metric_val = acc_helper.compute()
            acc_helper.reset()
@@ -322,17 +305,6 @@ def train(
        # Update the learning rate as needed
        lr_scheduler.step()

-         if train_config.run_validation:
-             eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                 model, train_config, eval_dataloader, device
-             )
-             if is_rank_zero():
-                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-             if train_config.save_metrics:
-                 val_step_loss.extend(temp_val_loss)
-                 val_step_metric.extend(temp_step_metric)
-
        # saving the adapters after completion of each epoch
        if train_config.save_model:
            if train_config.enable_ddp:
@@ -342,19 +314,24 @@ def train(
                model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

        if train_config.run_validation:
+             eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation(
+                 model, train_config, eval_dataloader, device
+             )
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+
+             if is_rank_zero():
+                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+
+             if train_config.save_metrics:
+                 val_step_loss.extend(temp_val_loss)
+                 val_step_metric.extend(temp_step_metric)
            val_loss.append(float(eval_epoch_loss))
            val_metric.append(float(eval_metric))
-         if train_config.task_type == "seq_classification":
-             logger.log_rank_zero(
-                 f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-             )
-         else:
-             logger.log_rank_zero(
-                 f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-             )
+         logger.log_rank_zero(
+             f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+         )

        # Saving the results every epoch to plot later
        if train_config.save_metrics:
@@ -389,7 +366,7 @@ def train(
    return results


- def evaluation_helper(model, train_config, eval_dataloader, device):
+ def evaluation(model, train_config, eval_dataloader, device):
    """
    Evaluates the model on the given dataloader

@@ -474,60 +451,3 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
-
-
- def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
-     # find out the minimum max_seq_length required during fine-tuning (saves memory!)
-     lengths = [len(d["input_ids"]) for d in data]
-     longest_seq_length = max(lengths)
-     longest_seq_ix = lengths.index(longest_seq_length)
-     return longest_seq_length, longest_seq_ix
-
-
- def print_model_size(model) -> None:
-     """
-     Print model name, the number of trainable parameters and initialization time.
-
-     Args:
-         model: PyTorch model.
-     """
-     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-     logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
-
-
- def print_trainable_parameters(model) -> None:
-     """
-     Print the number of trainable parameters, all params and percentage of trainablke params.
-
-     Args:
-         model: The PyTorch model.
-     """
-     trainable_params, all_param = model.get_nb_trainable_parameters()
-     logger.log_rank_zero(
-         f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
-     )
-
-
- def save_to_json(
-     output_filename,
-     train_step_loss,
-     train_epoch_loss,
-     train_step_metric,
-     train_epoch_metric,
-     val_step_loss,
-     val_epoch_loss,
-     val_step_metric,
-     val_epoch_metric,
- ):
-     metrics_data = {
-         "train_step_loss": train_step_loss,
-         "train_epoch_loss": train_epoch_loss,
-         "train_step_metric": train_step_metric,
-         "train_epoch_metric": train_epoch_metric,
-         "val_step_loss": val_step_loss,
-         "val_epoch_loss": val_epoch_loss,
-         "val_step_metric": val_step_metric,
-         "val_epoch_metric": val_epoch_metric,
-     }
-     with open(output_filename, "w") as f:
-         json.dump(metrics_data, f)
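Note: the save_to_json helper deleted above (together with the now-unused import json) corresponds to the save_to_json that the first hunk starts importing from QEfficient.finetune.utils.helper. A sketch of the relocated helper, assuming it was moved verbatim:

    import json

    def save_to_json(
        output_filename,
        train_step_loss,
        train_epoch_loss,
        train_step_metric,
        train_epoch_metric,
        val_step_loss,
        val_epoch_loss,
        val_step_metric,
        val_epoch_metric,
    ):
        # Gather every tracked train/val series into one dict and write it as
        # JSON so the curves can be plotted later.
        metrics_data = {
            "train_step_loss": train_step_loss,
            "train_epoch_loss": train_epoch_loss,
            "train_step_metric": train_step_metric,
            "train_epoch_metric": train_epoch_metric,
            "val_step_loss": val_step_loss,
            "val_epoch_loss": val_epoch_loss,
            "val_step_metric": val_step_metric,
            "val_epoch_metric": val_epoch_metric,
        }
        with open(output_filename, "w") as f:
            json.dump(metrics_data, f)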