@@ -8,8 +8,8 @@
 import json
 import os
 import time
-from contextlib import nullcontext
 from datetime import datetime
+from functools import partial
 from typing import Dict, List, Tuple
 
 import torch
@@ -19,7 +19,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.finetune.utils.helper import is_rank_zero
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
 from QEfficient.finetune.utils.logging_utils import logger
 
 try:
@@ -85,8 +85,8 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached
 
     tensorboard_updates = None
-    tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
     if is_rank_zero():
+        tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
         tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)
 
     device_type = torch.device(device).type
@@ -110,6 +110,9 @@ def train(
             num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+    op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.output_dir)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
@@ -168,60 +171,38 @@ def train(
                 break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device
 
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
-                # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                if train_config.opByOpVerifier:
-                    with qaic_debug.OpByOpVerifierMode(
-                        ref_device="cpu",
-                        ref_dtype=torch.float32,
-                        # adjust atol & rtol this as required
-                        atol=1e-1,
-                        use_ref_output_on_mismatch=True,
-                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.output_dir + "/mismatches/step_" + str(step),
-                    ) as verifier:
-                        model_outputs = model(**batch)
-                        loss = model_outputs.loss  # Forward call
-                        if (batch["labels"] != -100).sum() == 0:
-                            loss = loss.nan_to_num(nan=0.0)
-                            num_dummy_samples += train_config.train_batch_size
-                        else:
-                            num_dummy_samples_per_batch = (
-                                (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                            )
-                            if num_dummy_samples_per_batch > 0:
-                                num_dummy_samples += num_dummy_samples_per_batch
-                                loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
-
-                        if train_config.task_type == "seq_classification":
-                            logits = model_outputs.logits
-                            labels = batch["labels"][:, 0]
-                            preds = torch.nn.functional.softmax(logits, dim=-1)
-                            acc_helper.forward(preds, labels)
-                    logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())
+            is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                train_dataloader
+            ) - 1
+            if train_config.enable_ddp:
+                # Below block derived from: https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                # in DDP training we only need to sync gradients at the last micro step.
+                # the official way to do this is with model.no_sync() context manager, but
+                # using too many context managers may bloat the code and forces us to repeat code
+                # looking at the source of that context manager, it just toggles this variable
+                model.require_backward_grad_sync = is_optimizer_step
+
+            with autocast_ctx, op_verifier_ctx(step) as verifier:
+                model_outputs = model(**batch)
+                loss = model_outputs.loss  # Forward call
+                if (batch["labels"] != -100).sum() == 0:
+                    loss = loss.nan_to_num(nan=0.0)
+                    num_dummy_samples += train_config.train_batch_size
                 else:
-                    model_outputs = model(**batch)
-                    loss = model_outputs.loss  # Forward call
-                    if (batch["labels"] != -100).sum() == 0:
-                        loss = loss.nan_to_num(nan=0.0)
-                        num_dummy_samples += train_config.train_batch_size
-                    else:
-                        num_dummy_samples_per_batch = (
-                            (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                        )
-                        if num_dummy_samples_per_batch > 0:
-                            num_dummy_samples += num_dummy_samples_per_batch
-                            loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+                    num_dummy_samples_per_batch = (
+                        (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+                    )
+                    if num_dummy_samples_per_batch > 0:
+                        num_dummy_samples += num_dummy_samples_per_batch
+                        loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
 
-                    if train_config.task_type == "seq_classification":
-                        logits = model_outputs.logits
-                        labels = batch["labels"][:, 0]
-                        preds = torch.nn.functional.softmax(logits, dim=-1)
-                        acc_helper.forward(preds, labels)
+                if train_config.task_type == "seq_classification":
+                    logits = model_outputs.logits
+                    labels = batch["labels"][:, 0]
+                    preds = torch.nn.functional.softmax(logits, dim=-1)
+                    acc_helper.forward(preds, labels)
+                if train_config.opByOpVerifier:
+                    logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())
 
             total_loss += loss.detach().float()
             if is_rank_zero():
@@ -258,7 +239,7 @@ def train(
             else:
                 loss.backward()  # backward pass
 
-            if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+            if is_optimizer_step:
                 if train_config.grad_scaler:
                     scaler.step(optimizer)
                     scaler.update()
@@ -440,6 +421,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     device_type = torch.device(device).type
 
     num_dummy_samples = 0
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
         if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -450,11 +432,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            with autocast_ctx:
                 outputs = model(**batch)
                 loss = outputs.loss
 
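
Note: get_autocast_ctx and get_op_verifier_ctx are imported from QEfficient.finetune.utils.helper, which is not part of this diff. A minimal sketch of what they are expected to return, inferred from the inline blocks they replace; the signatures, defaults, and bodies below are assumptions and the real helpers may differ:

# Sketch only -- inferred from the removed inline code above, not copied from helper.py.
from contextlib import nullcontext

import torch


def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
    # Same behaviour as the removed inline expression: autocast when enabled, no-op otherwise.
    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()


def get_op_verifier_ctx(use_op_by_op_verifier, device, dump_dir, step, atol=1e-1):
    # Return a no-op context when the verifier is disabled, so the call site can always write
    # `with autocast_ctx, op_verifier_ctx(step) as verifier:` regardless of the config.
    if not use_op_by_op_verifier:
        return nullcontext()
    import torch_qaic.debug as qaic_debug  # QAIC-only dependency, imported lazily in this sketch

    return qaic_debug.OpByOpVerifierMode(
        ref_device="cpu",
        ref_dtype=torch.float32,
        atol=atol,
        use_ref_output_on_mismatch=True,
        filter_config=qaic_debug.DispatchFilterConfig.default(device),
        dump_root_dir=dump_dir + "/mismatches/step_" + str(step),
    )

With op_verifier_ctx bound via partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.output_dir), each training step only supplies the step index, which keeps the per-step call site in train() to a single line.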
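
For context on the require_backward_grad_sync toggle: under gradient accumulation, DDP would otherwise all-reduce gradients on every micro-step. The nanoGPT trick referenced in the comments sets the flag that model.no_sync() flips internally, so the all-reduce only happens on the micro-step that precedes optimizer.step(). A self-contained toy (not from this PR; single process, gloo backend) showing the pattern in isolation:

# Toy illustration of the require_backward_grad_sync pattern used in this diff.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum = 4  # gradient-accumulation micro-steps per optimizer step

for micro_step in range(accum):
    is_optimizer_step = micro_step == accum - 1
    # no_sync() just toggles this flag; setting it directly avoids nesting
    # another context manager around every forward/backward call.
    model.require_backward_grad_sync = is_optimizer_step

    loss = model(torch.randn(8, 4)).sum() / accum
    loss.backward()  # gradients are all-reduced only when the flag is True
    if is_optimizer_step:
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()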