#
# -----------------------------------------------------------------------------

- import json
import os
import time
from datetime import datetime
from functools import partial
- from typing import Dict, List, Tuple

import torch
import torch.distributed as dist
from tqdm import tqdm

from QEfficient.finetune.configs.training import TrainConfig
- from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
+ from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero, save_to_json
from QEfficient.finetune.utils.logging_utils import logger

try:
@@ -42,24 +40,24 @@ def train(
    optimizer,
    lr_scheduler,
    train_config: TrainConfig,
-     local_rank=None,
):
    """
    Trains the model on the given dataloader

    Args:
        model: The model to be trained
-         tokenizer: tokenizer used in the eval for decoding the predicitons
+         tokenizer: tokenizer used in the eval for decoding the predictions
        train_dataloader: The dataloader containing the training data
        eval_dataloader: The dataloader containing the eval data
        optimizer: The optimizer used for training
        lr_scheduler: The learning rate scheduler
        train_config: The training configuration
-         local_rank: The rank of the current node in a distributed setting

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    device = train_config.device
+     device_type = torch.device(device).type
+     local_rank = int(os.getenv("LOCAL_RANK", 0))

    train_metric = []
    train_loss = []
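
The `local_rank` parameter disappears here; the hunk above now reads the rank from the `LOCAL_RANK` environment variable, which launchers such as `torchrun` export for every spawned process. A minimal, self-contained sketch of that pattern (the CUDA fallback below is illustrative and not part of this file):

    import os

    import torch

    # torchrun / torch.distributed launchers set LOCAL_RANK (and RANK, WORLD_SIZE)
    # per process; a plain single-process run falls back to rank 0.
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    print(f"process local rank: {local_rank}, device: {device}")
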
@@ -89,8 +87,6 @@ def train(
    tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)

-     device_type = torch.device(device).type
-
    if train_config.grad_scaler:
        if device.startswith("qaic"):
            scaler = QAicGradScaler()
@@ -130,10 +126,11 @@ def train(
            continue

        logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-         logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
-         # stop when the maximum number of training steps is reached
+         if train_config.max_train_step > 0:
+             logger.log_rank_zero(f"Max train steps: {train_config.max_train_step}")
        if max_steps_reached:
            break
+
        epoch_start_time = time.perf_counter()
        model.train()

@@ -165,7 +162,6 @@ def train(
                continue
            total_train_steps += 1

-             # stop when the maximum number of training steps is reached
            if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                max_steps_reached = True
                break
@@ -223,7 +219,7 @@ def train(
                step_metric_val = float(torch.exp(loss.detach().float()))
            train_step_metric.append(step_metric_val)

-             # Accumalate gradients
+             # Accumulate gradients
            complete_accum_steps = (
                len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps
            )
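
The `complete_accum_steps` value above marks how many leading batches form full gradient-accumulation windows; the trailing `len(train_dataloader) % gradient_accumulation_steps` batches presumably form one smaller final window whose loss is scaled by its own size. A small standalone sketch of that bookkeeping (variable names here are illustrative, not taken from the file):

    gradient_accumulation_steps = 4
    num_batches = 10  # stands in for len(train_dataloader)

    # batches 0-7 form two full windows of 4; batches 8-9 form a final window of 2
    complete_accum_steps = num_batches - num_batches % gradient_accumulation_steps

    for step in range(num_batches):
        if step < complete_accum_steps:
            window_size = gradient_accumulation_steps
        else:
            window_size = num_batches % gradient_accumulation_steps
        # dividing the step loss by window_size keeps gradients comparable across windows
        print(step, window_size)
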
@@ -291,25 +287,6 @@ def train(
                    if total_loss == 0.0
                    else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
                )
-             else:
-                 train_epoch_loss = (
-                     0.0
-                     if total_loss == 0.0
-                     else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
-                 )
-         else:
-             if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                 train_epoch_loss = (
-                     0.0
-                     if total_loss == 0.0
-                     else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
-                 )
-             else:
-                 train_epoch_loss = (
-                     0.0
-                     if total_loss == 0.0
-                     else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
-                 )

        if train_config.task_type == "seq_classification":
            metric_val = acc_helper.compute()
            acc_helper.reset()
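
The removed branches all compute the same kind of normalization: `total_loss` is divided by the number of counted steps, with padded dummy samples subtracted out and, on resume from a PEFT checkpoint, the steps skipped before `intermediate_step` removed as well. A toy calculation with made-up numbers for the plain `step + 1` variant seen in the deleted lines:

    # Illustrative values only; none of these come from the diff itself.
    total_loss = 42.0          # sum of per-step losses over the epoch
    step = 99                  # zero-based index of the last step, so step + 1 = 100 steps
    num_dummy_samples = 8      # padding samples that should not count toward the average
    train_batch_size = 4

    train_epoch_loss = 0.0 if total_loss == 0.0 else total_loss / (step + 1 - num_dummy_samples / train_batch_size)
    print(train_epoch_loss)    # 42.0 / 98.0 ≈ 0.4286
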
@@ -322,17 +299,6 @@ def train(
        # Update the learning rate as needed
        lr_scheduler.step()

-         if train_config.run_validation:
-             eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                 model, train_config, eval_dataloader, device
-             )
-             if is_rank_zero():
-                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-             if train_config.save_metrics:
-                 val_step_loss.extend(temp_val_loss)
-                 val_step_metric.extend(temp_step_metric)
-
        # saving the adapters after completion of each epoch
        if train_config.save_model:
            if train_config.enable_ddp:
@@ -342,19 +308,24 @@ def train(
            model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

        if train_config.run_validation:
+             eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation(
+                 model, train_config, eval_dataloader, device
+             )
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+
+             if is_rank_zero():
+                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+
+             if train_config.save_metrics:
+                 val_step_loss.extend(temp_val_loss)
+                 val_step_metric.extend(temp_step_metric)
            val_loss.append(float(eval_epoch_loss))
            val_metric.append(float(eval_metric))
-         if train_config.task_type == "seq_classification":
-             logger.log_rank_zero(
-                 f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-             )
-         else:
-             logger.log_rank_zero(
-                 f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-             )
+         logger.log_rank_zero(
+             f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+         )

        # Saving the results every epoch to plot later
        if train_config.save_metrics:
@@ -389,7 +360,7 @@ def train(
    return results


- def evaluation_helper(model, train_config, eval_dataloader, device):
+ def evaluation(model, train_config, eval_dataloader, device):
    """
    Evaluates the model on the given dataloader

@@ -474,60 +445,3 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
-
-
- def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
-     # find out the minimum max_seq_length required during fine-tuning (saves memory!)
-     lengths = [len(d["input_ids"]) for d in data]
-     longest_seq_length = max(lengths)
-     longest_seq_ix = lengths.index(longest_seq_length)
-     return longest_seq_length, longest_seq_ix
-
-
- def print_model_size(model) -> None:
-     """
-     Print model name, the number of trainable parameters and initialization time.
-
-     Args:
-         model: PyTorch model.
-     """
-     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-     logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
-
-
- def print_trainable_parameters(model) -> None:
-     """
-     Print the number of trainable parameters, all params and percentage of trainablke params.
-
-     Args:
-         model: The PyTorch model.
-     """
-     trainable_params, all_param = model.get_nb_trainable_parameters()
-     logger.log_rank_zero(
-         f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
-     )
-
-
- def save_to_json(
-     output_filename,
-     train_step_loss,
-     train_epoch_loss,
-     train_step_metric,
-     train_epoch_metric,
-     val_step_loss,
-     val_epoch_loss,
-     val_step_metric,
-     val_epoch_metric,
- ):
-     metrics_data = {
-         "train_step_loss": train_step_loss,
-         "train_epoch_loss": train_epoch_loss,
-         "train_step_metric": train_step_metric,
-         "train_epoch_metric": train_epoch_metric,
-         "val_step_loss": val_step_loss,
-         "val_epoch_loss": val_epoch_loss,
-         "val_step_metric": val_step_metric,
-         "val_epoch_metric": val_epoch_metric,
-     }
-     with open(output_filename, "w") as f:
-         json.dump(metrics_data, f)
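
The local `save_to_json` definition is deleted because the import hunk at the top of the diff now pulls `save_to_json` from `QEfficient.finetune.utils.helper`; presumably the helper keeps the shape of the function removed here. A hedged usage sketch based on that removed signature (the argument order should be confirmed against QEfficient/finetune/utils/helper.py):

    from QEfficient.finetune.utils.helper import save_to_json

    # Dummy metric histories, standing in for the lists accumulated during train().
    train_step_loss, train_epoch_loss = [1.2, 1.1], [1.15]
    train_step_metric, train_epoch_metric = [3.3, 3.0], [3.15]
    val_step_loss, val_epoch_loss = [1.0], [1.0]
    val_step_metric, val_epoch_metric = [2.7], [2.7]

    # Argument order mirrors the deleted local function above.
    save_to_json(
        "metrics.json",
        train_step_loss, train_epoch_loss, train_step_metric, train_epoch_metric,
        val_step_loss, val_epoch_loss, val_step_metric, val_epoch_metric,
    )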