@@ -20,6 +20,7 @@
 from QEfficient.finetune.utils.helper import (
     Task_Mode,
     get_autocast_ctx,
+    get_num_ddp_devices,
     get_op_verifier_ctx,
     is_rank_zero,
     save_to_json,
@@ -67,8 +68,8 @@ def train(

     train_metric = []
     train_loss = []
-    val_metric = []
-    val_loss = []
+    eval_metric = []
+    eval_loss = []

     if train_config.save_metrics:
         if not os.path.exists(train_config.output_dir):
@@ -78,13 +79,13 @@ def train(
         )
         train_step_metric = []
         train_step_loss = []
-        val_step_loss = []
-        val_step_metric = []
+        eval_step_loss = []
+        eval_step_metric = []

     epoch_times = []
     checkpoint_times = []
     results = {}
-    best_val_loss = float("inf")
+    best_eval_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached

@@ -132,8 +133,7 @@ def train(
             continue

         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        if train_config.max_train_step > 0:
-            logger.log_rank_zero(f"Max train steps: {train_config.max_train_step}")
+
         if max_steps_reached:
             break

@@ -170,6 +170,11 @@ def train(

             if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                 max_steps_reached = True
+                logger.log_rank_zero(
+                    "Maximum training steps reached "
+                    f"(max_train_step={train_config.max_train_step}). Stopping "
+                    "the training process."
+                )
                 break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device

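Editor's note: the added log message relies on Python's implicit concatenation of adjacent string literals, which lets a long message span several lines without `+` or backslashes; only the piece with a placeholder needs the `f` prefix. A minimal sketch (the config value is hypothetical):

```python
max_train_step = 500  # hypothetical config value
msg = (
    "Maximum training steps reached "
    f"(max_train_step={max_train_step}). Stopping "
    "the training process."
)
# -> "Maximum training steps reached (max_train_step=500). Stopping the training process."
```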
@@ -207,6 +212,7 @@ def train(
                     logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())

             total_loss += loss.detach().float()
+
             if is_rank_zero():
                 tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
                 if loss <= train_config.convergence_loss:
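Editor's note: `add_scalars` groups several named curves under one parent tag, so the `train` loss logged here and the `eval` loss logged after validation land on the same TensorBoard chart. A standalone sketch of the call, assuming `tensorboard_updates` is a `torch.utils.tensorboard.SummaryWriter` (log directory and loss values are illustrative):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/demo")  # illustrative log dir
for global_step, loss in enumerate([1.2, 0.9, 0.7]):  # hypothetical per-step losses
    # One parent tag ("loss"), multiple named curves on the same chart.
    writer.add_scalars("loss", {"train": loss}, global_step)
writer.close()
```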
@@ -219,10 +225,10 @@ def train(
             if train_config.save_metrics:
                 train_step_loss.append(loss.detach().float().item())
                 if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
-                    step_metric_val = float(acc_helper.compute())
+                    step_metric_value = float(acc_helper.compute())
                 else:
-                    step_metric_val = float(torch.exp(loss.detach().float()))
-                train_step_metric.append(step_metric_val)
+                    step_metric_value = float(torch.exp(loss.detach().float()))
+                train_step_metric.append(step_metric_value)

             # Accumulate gradients
             complete_accum_steps = (
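Editor's note: for the non-classification task modes, the step metric computed above is perplexity, the exponential of the cross-entropy loss. A one-liner illustrating the relationship (the loss value is hypothetical):

```python
import torch

loss = torch.tensor(2.0)  # hypothetical mean cross-entropy for one step
perplexity = float(torch.exp(loss))  # exp(2.0) ≈ 7.39; lower loss -> lower perplexity
```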
@@ -250,15 +256,17 @@ def train(
             pbar.update(1)

             # Save the trained checkpoints for every given steps
-            if step % train_config.intermediate_step_save == 0:
+            if (step + 1) % train_config.intermediate_step_save == 0:
                 qaic_profile.stop_profiling(device) if train_config.use_profiler else None
                 if train_config.enable_ddp:
                     if dist.get_rank() == 0:
                         model.module.save_pretrained(
-                            train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}"
+                            train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step + 1}"
                         )
                 else:
-                    model.save_pretrained(train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}")
+                    model.save_pretrained(
+                        train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step + 1}"
+                    )

             pbar.set_description(
                 f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
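Editor's note: the `step` → `step + 1` change fixes an off-by-one. `enumerate` is zero-based, so the old condition fired on the very first batch (step 0), while the new one saves after every `intermediate_step_save` completed steps. A quick sketch (the interval value is arbitrary):

```python
N = 3  # hypothetical intermediate_step_save interval
old_saves = [step for step in range(10) if step % N == 0]        # [0, 3, 6, 9] -- includes step 0
new_saves = [step for step in range(10) if (step + 1) % N == 0]  # [2, 5, 8] -- after every 3rd step
```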
@@ -270,10 +278,10 @@ def train(
                     train_loss,
                     train_step_metric,
                     train_metric,
-                    val_step_loss,
-                    val_loss,
-                    val_step_metric,
-                    val_metric,
+                    eval_step_loss,
+                    eval_loss,
+                    eval_step_metric,
+                    eval_metric,
                 )
             if loss_0_counter.item() == train_config.convergence_counter:
                 logger.log_rank_zero(
@@ -285,44 +293,64 @@ def train(
         epoch_end_time = time.perf_counter() - epoch_start_time
         epoch_times.append(epoch_end_time)

-        if loss_0_counter.item() == train_config.convergence_counter:
-            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
-                )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
-                )
+        if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+            train_epoch_loss = (
+                0.0
+                if total_loss == 0.0
+                else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
+            )
         else:
-            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
-                )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
-                )
+            train_epoch_loss = (
+                0.0
+                if total_loss == 0.0
+                else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
+            )
+
         if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
-            metric_val = acc_helper.compute()
+            train_epoch_metric = acc_helper.compute()
             acc_helper.reset()
         else:
-            metric_val = torch.exp(train_epoch_loss)
+            train_epoch_metric = torch.exp(train_epoch_loss)

-        train_metric.append(float(metric_val))
+        train_metric.append(float(train_epoch_metric))
         train_loss.append(float(train_epoch_loss))

+        if train_config.enable_ddp:
+            dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
+            train_epoch_loss /= get_num_ddp_devices()
+            dist.all_reduce(train_epoch_metric, op=dist.ReduceOp.SUM)
+            train_epoch_metric /= get_num_ddp_devices()
+
         # Update the learning rate as needed
         lr_scheduler.step()

+        if train_config.run_validation:
+            eval_epoch_loss, eval_epoch_metric, step_loss, step_metric = evaluation(
+                model, train_config, eval_dataloader, device
+            )
+
+            if is_rank_zero():
+                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+            if train_config.save_metrics:
+                eval_step_loss.extend(step_loss)
+                eval_step_metric.extend(step_metric)
+                eval_loss.append(float(eval_epoch_loss))
+                eval_metric.append(float(eval_epoch_metric))
+
+            if train_config.enable_ddp:
+                dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
+                eval_epoch_loss /= get_num_ddp_devices()
+                dist.all_reduce(eval_epoch_metric, op=dist.ReduceOp.SUM)
+                eval_epoch_metric /= get_num_ddp_devices()
+
+            if eval_epoch_loss < best_eval_loss:
+                best_eval_loss = eval_epoch_loss
+                logger.log_rank_zero(f"Best eval loss on epoch {epoch + 1} is {best_eval_loss:.4f}")
+
+            logger.log_rank_zero(
+                f"Epoch {epoch + 1}: Eval Loss: {eval_epoch_loss.detach().cpu():.4f}, Eval metric: {eval_epoch_metric.detach().cpu():.4f}"
+            )
+
         # saving the adapters after completion of each epoch
         if train_config.save_model:
             if train_config.enable_ddp:
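Editor's note: the new `all_reduce`/`get_num_ddp_devices` pairs implement a global mean: each rank contributes its local epoch loss and metric via a SUM reduction, and the result is divided by the number of participating devices. A standalone sketch of the same pattern, assuming a `torch.distributed` process group is already initialized (using `dist.get_world_size()` in place of the repository's `get_num_ddp_devices` helper):

```python
import torch
import torch.distributed as dist

def ddp_mean(value: torch.Tensor) -> torch.Tensor:
    """Average a scalar tensor across all DDP ranks (in place).

    Assumes dist.init_process_group(...) has already run and `value`
    lives on this rank's device.
    """
    dist.all_reduce(value, op=dist.ReduceOp.SUM)  # every rank now holds the global sum
    value /= dist.get_world_size()                # turn the sum into a mean over ranks
    return value
```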
@@ -331,25 +359,10 @@ def train(
             else:
                 model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

-        if train_config.run_validation:
-            eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation(
-                model, train_config, eval_dataloader, device
-            )
-            if eval_epoch_loss < best_val_loss:
-                best_val_loss = eval_epoch_loss
-                logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
-
-            if is_rank_zero():
-                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-            if train_config.save_metrics:
-                val_step_loss.extend(temp_val_loss)
-                val_step_metric.extend(temp_step_metric)
-                val_loss.append(float(eval_epoch_loss))
-                val_metric.append(float(eval_metric))
         logger.log_rank_zero(
-            f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            f"Epoch {epoch + 1}: Train epoch loss: {train_epoch_loss:.4f}, Train metric: {train_epoch_metric:.4f}, Epoch time {epoch_end_time:.2f} sec"
         )
+
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
             save_to_json(
@@ -358,24 +371,19 @@ def train(
                 train_loss,
                 train_step_metric,
                 train_metric,
-                val_step_loss,
-                val_loss,
-                val_step_metric,
-                val_metric,
+                eval_step_loss,
+                eval_loss,
+                eval_step_metric,
+                eval_metric,
             )
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_metric = sum(train_metric) / len(train_metric)
-    avg_train_loss = sum(train_loss) / len(train_loss)
-    if train_config.run_validation:
-        avg_eval_metric = sum(val_metric) / len(val_metric)
-        avg_eval_loss = sum(val_loss) / len(val_loss)

-    results["avg_train_metric"] = avg_train_metric
-    results["avg_train_loss"] = avg_train_loss
+    results["last_epoch_train_loss"] = train_epoch_loss.cpu()
+    results["last_epoch_train_metric"] = train_epoch_metric.cpu()
     if train_config.run_validation:
-        results["avg_eval_metric"] = avg_eval_metric
-        results["avg_eval_loss"] = avg_eval_loss
+        results["last_epoch_eval_loss"] = eval_epoch_loss.cpu()
+        results["last_epoch_eval_metric"] = eval_epoch_metric.cpu()
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
@@ -391,7 +399,7 @@ def evaluation(model, train_config, eval_dataloader, device):
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data

-    Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
+    Returns: eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric
     """
     if train_config.enable_ddp:
         dist.barrier()
@@ -408,17 +416,17 @@ def evaluation(model, train_config, eval_dataloader, device):
     # special handling for qaic device and dtype
     # model.to(device)

-    val_step_loss = []
-    val_step_metric = []
+    eval_step_loss = []
+    eval_step_metric = []

-    eval_loss = 0.0  # Initialize evaluation loss
+    eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)  # Initialize evaluation loss
     device_type = torch.device(device).type

     num_dummy_samples = 0
     autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
-        if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
+        if train_config.max_eval_step > 0 and step >= train_config.max_eval_step:
             break
         for key in batch.keys():
             batch[key] = batch[key].to(device)
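Editor's note: seeding `eval_loss` as a 0-dim tensor rather than the Python float `0.0` keeps the accumulator on-device and compatible with the `dist.all_reduce` calls added in `train`, which accept tensors only. A minimal sketch of the pattern (device and losses are illustrative):

```python
import torch

device = "cpu"  # illustrative; the trainer passes its qaic device here
eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
for step_loss in (0.9, 0.7, 0.5):  # hypothetical per-step losses
    eval_loss += torch.tensor(step_loss, device=device)  # stays a tensor throughout
mean_loss = eval_loss / 3  # still a tensor, so it can be all_reduced later
```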
@@ -445,29 +453,27 @@ def evaluation(model, train_config, eval_dataloader, device):
                 logits = outputs.logits
                 labels = batch["labels"][:, 0]
                 preds = torch.nn.functional.softmax(logits, dim=-1)
-                val_acc = acc_helper.forward(preds, labels)
-                metric_val = val_acc.detach().float().item()
+                eval_acc = acc_helper.forward(preds, labels)
+                metric_value = eval_acc.detach().float().item()
             else:
-                metric_val = float(torch.exp(loss.detach().float()))
+                metric_value = float(torch.exp(loss.detach().float()))

         if train_config.save_metrics:
-            val_step_loss.append(loss.detach().float().item())
-            val_step_metric.append(metric_val)
+            eval_step_loss.append(loss.detach().float().item())
+            eval_step_metric.append(metric_value)

         eval_loss += loss.detach().float()
+
    # Compute average loss and metric
    eval_epoch_loss = (
        0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
    )
    if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
-        eval_metric = acc_helper.compute()
+        eval_epoch_metric = acc_helper.compute()
    else:
-        eval_metric = torch.exp(eval_epoch_loss)
-
-    # Print evaluation metrics
-    logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+        eval_epoch_metric = torch.exp(eval_epoch_loss)

-    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
+    return eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric


 def print_model_size(model) -> None: