 import json
 import os
 import time
-from contextlib import nullcontext
 from datetime import datetime
+from functools import partial
 from typing import Dict, List, Tuple
 
 import torch
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx
 
 try:
     import torch_qaic  # noqa: F401
@@ -110,6 +111,9 @@ def train(
             num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+    op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
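The diff only shows the call sites of `get_autocast_ctx` and `get_op_verifier_ctx`; their implementations live in `QEfficient.finetune.utils.helper` and are not part of this view. Below is a minimal sketch of what they plausibly look like, reconstructed from the call sites above and from the inline code they replace further down in the diff; the exact signatures, defaults, and the `torch_qaic.debug` import alias are assumptions.

```python
# Minimal sketch of the two helpers, reconstructed from the call sites above and
# from the inline code they replace further down in this diff. The real
# implementations live in QEfficient/finetune/utils/helper.py and may differ;
# the torch_qaic.debug import alias below is an assumption.
from contextlib import nullcontext

import torch


def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
    # Reusable autocast context, or a no-op context when autocast is disabled.
    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()


def get_op_verifier_ctx(use_op_by_op_verifier, device, dump_root_dir, step, atol=1e-1):
    # Per-step op-by-op verification against a CPU/float32 reference, mirroring
    # the qaic_debug.OpByOpVerifierMode block removed in the hunk below.
    if not use_op_by_op_verifier:
        return nullcontext()
    import torch_qaic.debug as qaic_debug  # assumed alias; the file imports torch_qaic under try/except

    return qaic_debug.OpByOpVerifierMode(
        ref_device="cpu",
        ref_dtype=torch.float32,
        atol=atol,
        use_ref_output_on_mismatch=True,
        filter_config=qaic_debug.DispatchFilterConfig.default(device),
        dump_root_dir=dump_root_dir + str(step),
    )
```

Binding the fixed arguments with `functools.partial` in `train()` leaves only the per-step argument open, which is why the training loop below can simply write `op_verifier_ctx(step)`.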
@@ -174,60 +178,38 @@ def train(
                     break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device
 
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
-                # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                if train_config.opByOpVerifier:
-                    with qaic_debug.OpByOpVerifierMode(
-                        ref_device="cpu",
-                        ref_dtype=torch.float32,
-                        # adjust atol & rtol as required
-                        atol=1e-1,
-                        use_ref_output_on_mismatch=True,
-                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.dump_root_dir + str(step),
-                    ) as verifier:
-                        model_outputs = model(**batch)
-                        loss = model_outputs.loss  # Forward call
-                        if (batch["labels"] != -100).sum() == 0:
-                            loss = loss.nan_to_num(nan=0.0)
-                            num_dummy_samples += train_config.train_batch_size
-                        else:
-                            num_dummy_samples_per_batch = (
-                                (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                            )
-                            if num_dummy_samples_per_batch > 0:
-                                num_dummy_samples += num_dummy_samples_per_batch
-                                loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
-
-                        if train_config.task_type == "seq_classification":
-                            logits = model_outputs.logits
-                            labels = batch["labels"][:, 0]
-                            preds = torch.nn.functional.softmax(logits, dim=-1)
-                            acc_helper.forward(preds, labels)
-                    print("Mismatches detected:", verifier.get_perop_mismatch_count())
+            is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                train_dataloader
+            ) - 1
+            if train_config.enable_ddp:
+                # Below block derived from: https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                # in DDP training we only need to sync gradients at the last micro step.
+                # the official way to do this is with model.no_sync() context manager, but
+                # using too many context managers may bloat the code and forces us to repeat code
+                # looking at the source of that context manager, it just toggles this variable
+                model.require_backward_grad_sync = is_optimizer_step
+
+            with autocast_ctx, op_verifier_ctx(step) as verifier:
+                model_outputs = model(**batch)
+                loss = model_outputs.loss  # Forward call
+                if (batch["labels"] != -100).sum() == 0:
+                    loss = loss.nan_to_num(nan=0.0)
+                    num_dummy_samples += train_config.train_batch_size
                 else:
-                    model_outputs = model(**batch)
-                    loss = model_outputs.loss  # Forward call
-                    if (batch["labels"] != -100).sum() == 0:
-                        loss = loss.nan_to_num(nan=0.0)
-                        num_dummy_samples += train_config.train_batch_size
-                    else:
-                        num_dummy_samples_per_batch = (
-                            (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                        )
-                        if num_dummy_samples_per_batch > 0:
-                            num_dummy_samples += num_dummy_samples_per_batch
-                            loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+                    num_dummy_samples_per_batch = (
+                        (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+                    )
+                    if num_dummy_samples_per_batch > 0:
+                        num_dummy_samples += num_dummy_samples_per_batch
+                        loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
 
-                    if train_config.task_type == "seq_classification":
-                        logits = model_outputs.logits
-                        labels = batch["labels"][:, 0]
-                        preds = torch.nn.functional.softmax(logits, dim=-1)
-                        acc_helper.forward(preds, labels)
+                if train_config.task_type == "seq_classification":
+                    logits = model_outputs.logits
+                    labels = batch["labels"][:, 0]
+                    preds = torch.nn.functional.softmax(logits, dim=-1)
+                    acc_helper.forward(preds, labels)
+            if train_config.opByOpVerifier:
+                print("Mismatches detected:", verifier.get_perop_mismatch_count())
 
             total_loss += loss.detach().float()
 
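The new `is_optimizer_step` flag does double duty: it gates the optimizer step and, under DDP, toggles `require_backward_grad_sync` so the gradient all-reduce only runs on the micro-step that actually steps the optimizer (the same flag `model.no_sync()` flips internally). Here is a standalone sketch of that pattern; the toy model, single-process gloo group, and loss normalization are illustrative assumptions, not code from this repository.

```python
# Standalone sketch of the accumulation pattern above: gradients are only
# all-reduced on the micro-step that also steps the optimizer. Toy model,
# single-process gloo group, and loss normalization are illustrative only.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
grad_accum_steps = 4
data = [(torch.randn(4, 8), torch.randn(4, 2)) for _ in range(8)]

for step, (x, y) in enumerate(data):
    is_optimizer_step = (step + 1) % grad_accum_steps == 0 or step == len(data) - 1
    # Same effect as wrapping non-final micro-steps in model.no_sync(): the
    # context manager just toggles this flag on the DDP wrapper.
    model.require_backward_grad_sync = is_optimizer_step
    loss = torch.nn.functional.mse_loss(model(x), y) / grad_accum_steps
    loss.backward()
    if is_optimizer_step:
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()
```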
@@ -274,7 +256,7 @@ def train(
             else:
                 loss.backward()  # backward pass
 
-            if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+            if is_optimizer_step:
                 if train_config.grad_scaler:
                     scaler.step(optimizer)
                     scaler.update()
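When `grad_scaler` is enabled, the backward pass above runs on the scaled loss and the gated step unscales the gradients, skips the update on inf/NaN, and re-tunes the loss scale. A compact sketch of that scaler path driven by the same accumulation flag; the toy CUDA model and the stock `torch.amp.GradScaler` are assumptions for illustration (the finetune code may construct a device-specific scaler).

```python
# Sketch of the grad-scaler path gated by the same flag: scaled backward on every
# micro-step, unscale/step/update only when is_optimizer_step is True. Requires a
# CUDA device; the stock torch.amp.GradScaler stands in for whatever scaler the
# finetune code constructs.
import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.amp.GradScaler("cuda")
grad_accum_steps = 4
num_steps = 16

for step in range(num_steps):
    x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 2, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y) / grad_accum_steps
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    is_optimizer_step = (step + 1) % grad_accum_steps == 0 or step == num_steps - 1
    if is_optimizer_step:
        scaler.step(optimizer)  # unscales grads and skips the step on inf/NaN
        scaler.update()         # re-tunes the loss scale for the next interval
        optimizer.zero_grad()
```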
@@ -468,6 +450,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     device_type = torch.device(device).type
 
     num_dummy_samples = 0
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
         if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -478,11 +461,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            with autocast_ctx:
                 outputs = model(**batch)
                 loss = outputs.loss
 
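The evaluation helper now builds its autocast context once, before the loop, and re-enters the same instance for every batch under `torch.no_grad()`. A self-contained sketch of that loop shape; the toy model, the CPU/bfloat16 autocast, and the loss averaging are illustrative assumptions, not the library's code.

```python
# Self-contained sketch of the refactored evaluation loop shape: one autocast
# context object, created before the loop and re-entered per batch inside
# torch.no_grad(). CPU/bfloat16 and the toy model are illustrative assumptions.
import torch
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(8, 2).eval()
eval_dataloader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 2)), batch_size=8)
autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16)  # built once, reused per step

total_loss, num_batches = 0.0, 0
for x, y in eval_dataloader:
    with torch.no_grad():
        with autocast_ctx:
            out = model(x)
        loss = torch.nn.functional.mse_loss(out.float(), y)
    total_loss += loss.item()
    num_batches += 1
print("mean eval loss:", total_loss / num_batches)
```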