Skip to content

Commit 89e0ee6 — Merge branch 'main' into pp_ddp
Authored and signed off by Mamta Singh <[email protected]>
2 parents: 4d8d470 + 5fb7532

File tree: 3 files changed (+145, −132 lines)

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# -----------------------------------------------------------------------------
77

88
from typing import Dict, List, Tuple
9+
import logging
910

1011
import datasets
1112
import torch
@@ -69,6 +70,11 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
6970

7071

7172
def padding_dataset(train_config, dataset, batch_size):
73+
num_replicas = get_num_ddp_devices()
74+
remainder = len(dataset) % (num_replicas * batch_size)
75+
if remainder == 0:
76+
return dataset
77+
7278
if train_config.enable_ddp and train_config.enable_sorting_for_ddp:
7379
if isinstance(dataset, datasets.Dataset):
7480
# Hugging Face Dataset transformation
@@ -80,24 +86,26 @@ def padding_dataset(train_config, dataset, batch_size):
8086

8187
dummy_row = next(iter(dataset))
8288
dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
83-
padding_size = 0
84-
num_replicas = get_num_ddp_devices()
85-
remainder = len(dataset) % (num_replicas * batch_size)
86-
padding_size = (num_replicas * batch_size) - remainder
8789

90+
padding_size = (num_replicas * batch_size) - remainder
8891
dummy_data = [dummy_row.copy() for _ in range(padding_size)]
8992
dummy_dataset = datasets.Dataset.from_list(dummy_data)
9093
if isinstance(dataset, datasets.Dataset):
9194
combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
9295
else:
9396
combined_dataset = dataset + list(dummy_dataset)
97+
98+
logger.log_rank_zero("Padding dataset to make it divisible by batch_size * num_devices.", logging.DEBUG)
99+
logger.log_rank_zero(f"Length of dataset before padding: {len(dataset)}", logging.DEBUG)
100+
logger.log_rank_zero(f"Length of dataset after padding: {len(combined_dataset)}", logging.DEBUG)
94101
return combined_dataset
95102

96103

97104
def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
98105
dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)
99106

100107
batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
108+
101109
dataset = padding_dataset(train_config, dataset, batch_size)
102110

103111
dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

QEfficient/finetune/utils/train_utils.py

Lines changed: 97 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from QEfficient.finetune.utils.helper import (
2121
Task_Mode,
2222
get_autocast_ctx,
23+
get_num_ddp_devices,
2324
get_op_verifier_ctx,
2425
is_rank_zero,
2526
save_to_json,
@@ -67,8 +68,8 @@ def train(
6768

6869
train_metric = []
6970
train_loss = []
70-
val_metric = []
71-
val_loss = []
71+
eval_metric = []
72+
eval_loss = []
7273

7374
if train_config.save_metrics:
7475
if not os.path.exists(train_config.output_dir):
@@ -78,13 +79,13 @@ def train(
7879
)
7980
train_step_metric = []
8081
train_step_loss = []
81-
val_step_loss = []
82-
val_step_metric = []
82+
eval_step_loss = []
83+
eval_step_metric = []
8384

8485
epoch_times = []
8586
checkpoint_times = []
8687
results = {}
87-
best_val_loss = float("inf")
88+
best_eval_loss = float("inf")
8889
total_train_steps = 0
8990
max_steps_reached = False # Flag to indicate max training steps reached
9091

@@ -132,8 +133,7 @@ def train(
132133
continue
133134

134135
logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
135-
if train_config.max_train_step > 0:
136-
logger.log_rank_zero(f"Max train steps : {train_config.max_train_step}")
136+
137137
if max_steps_reached:
138138
break
139139

@@ -170,6 +170,11 @@ def train(
170170

171171
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
172172
max_steps_reached = True
173+
logger.log_rank_zero(
174+
"Maximum training steps reached "
175+
f"(max_train_step={train_config.max_train_step}). Stopping "
176+
"the training process."
177+
)
173178
break
174179
batch = {k: v.to(device) for k, v in batch.items()} # move the batch elements to qaic device
175180

@@ -207,6 +212,7 @@ def train(
207212
logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())
208213

209214
total_loss += loss.detach().float()
215+
210216
if is_rank_zero():
211217
tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
212218
if loss <= train_config.convergence_loss:
@@ -219,10 +225,10 @@ def train(
219225
if train_config.save_metrics:
220226
train_step_loss.append(loss.detach().float().item())
221227
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
222-
step_metric_val = float(acc_helper.compute())
228+
step_metric_value = float(acc_helper.compute())
223229
else:
224-
step_metric_val = float(torch.exp(loss.detach().float()))
225-
train_step_metric.append(step_metric_val)
230+
step_metric_value = float(torch.exp(loss.detach().float()))
231+
train_step_metric.append(step_metric_value)
226232

227233
# Accumulate gradients
228234
complete_accum_steps = (
@@ -250,15 +256,17 @@ def train(
250256
pbar.update(1)
251257

252258
# Save the trained checkpoints for every given steps
253-
if step % train_config.intermediate_step_save == 0:
259+
if (step + 1) % train_config.intermediate_step_save == 0:
254260
qaic_profile.stop_profiling(device) if train_config.use_profiler else None
255261
if train_config.enable_ddp:
256262
if dist.get_rank() == 0:
257263
model.module.save_pretrained(
258-
train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}"
264+
train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step + 1}"
259265
)
260266
else:
261-
model.save_pretrained(train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}")
267+
model.save_pretrained(
268+
train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step + 1}"
269+
)
262270

263271
pbar.set_description(
264272
f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
@@ -270,10 +278,10 @@ def train(
270278
train_loss,
271279
train_step_metric,
272280
train_metric,
273-
val_step_loss,
274-
val_loss,
275-
val_step_metric,
276-
val_metric,
281+
eval_step_loss,
282+
eval_loss,
283+
eval_step_metric,
284+
eval_metric,
277285
)
278286
if loss_0_counter.item() == train_config.convergence_counter:
279287
logger.log_rank_zero(
@@ -285,44 +293,64 @@ def train(
285293
epoch_end_time = time.perf_counter() - epoch_start_time
286294
epoch_times.append(epoch_end_time)
287295

288-
if loss_0_counter.item() == train_config.convergence_counter:
289-
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
290-
train_epoch_loss = (
291-
0.0
292-
if total_loss == 0.0
293-
else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
294-
)
295-
else:
296-
train_epoch_loss = (
297-
0.0
298-
if total_loss == 0.0
299-
else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
300-
)
296+
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
297+
train_epoch_loss = (
298+
0.0
299+
if total_loss == 0.0
300+
else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
301+
)
301302
else:
302-
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
303-
train_epoch_loss = (
304-
0.0
305-
if total_loss == 0.0
306-
else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
307-
)
308-
else:
309-
train_epoch_loss = (
310-
0.0
311-
if total_loss == 0.0
312-
else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
313-
)
303+
train_epoch_loss = (
304+
0.0
305+
if total_loss == 0.0
306+
else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
307+
)
308+
314309
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
315-
metric_val = acc_helper.compute()
310+
train_epoch_metric = acc_helper.compute()
316311
acc_helper.reset()
317312
else:
318-
metric_val = torch.exp(train_epoch_loss)
313+
train_epoch_metric = torch.exp(train_epoch_loss)
319314

320-
train_metric.append(float(metric_val))
315+
train_metric.append(float(train_epoch_metric))
321316
train_loss.append(float(train_epoch_loss))
322317

318+
if train_config.enable_ddp:
319+
dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
320+
train_epoch_loss /= get_num_ddp_devices()
321+
dist.all_reduce(train_epoch_metric, op=dist.ReduceOp.SUM)
322+
train_epoch_metric /= get_num_ddp_devices()
323+
323324
# Update the learning rate as needed
324325
lr_scheduler.step()
325326

327+
if train_config.run_validation:
328+
eval_epoch_loss, eval_epoch_metric, step_loss, step_metric = evaluation(
329+
model, train_config, eval_dataloader, device
330+
)
331+
332+
if is_rank_zero():
333+
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
334+
if train_config.save_metrics:
335+
eval_step_loss.extend(step_loss)
336+
eval_step_metric.extend(step_metric)
337+
eval_loss.append(float(eval_epoch_loss))
338+
eval_metric.append(float(eval_epoch_metric))
339+
340+
if train_config.enable_ddp:
341+
dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
342+
eval_epoch_loss /= get_num_ddp_devices()
343+
dist.all_reduce(eval_epoch_metric, op=dist.ReduceOp.SUM)
344+
eval_epoch_metric /= get_num_ddp_devices()
345+
346+
if eval_epoch_loss < best_eval_loss:
347+
best_eval_loss = eval_epoch_loss
348+
logger.log_rank_zero(f"Best eval loss on epoch {epoch + 1} is {best_eval_loss:.4f}")
349+
350+
logger.log_rank_zero(
351+
f"Epoch {epoch + 1}: Eval Loss: {eval_epoch_loss.detach().cpu():.4f}, Eval metric: {eval_epoch_metric.detach().cpu():.4f}"
352+
)
353+
326354
# saving the adapters after completion of each epoch
327355
if train_config.save_model:
328356
if train_config.enable_ddp:
@@ -331,25 +359,10 @@ def train(
331359
else:
332360
model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
333361

334-
if train_config.run_validation:
335-
eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation(
336-
model, train_config, eval_dataloader, device
337-
)
338-
if eval_epoch_loss < best_val_loss:
339-
best_val_loss = eval_epoch_loss
340-
logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
341-
342-
if is_rank_zero():
343-
tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
344-
345-
if train_config.save_metrics:
346-
val_step_loss.extend(temp_val_loss)
347-
val_step_metric.extend(temp_step_metric)
348-
val_loss.append(float(eval_epoch_loss))
349-
val_metric.append(float(eval_metric))
350362
logger.log_rank_zero(
351-
f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
363+
f"Epoch {epoch + 1}: Train epoch loss: {train_epoch_loss:.4f}, Train metric: {train_epoch_metric:.4f}, Epoch time {epoch_end_time:.2f} sec"
352364
)
365+
353366
# Saving the results every epoch to plot later
354367
if train_config.save_metrics:
355368
save_to_json(
@@ -358,24 +371,19 @@ def train(
358371
train_loss,
359372
train_step_metric,
360373
train_metric,
361-
val_step_loss,
362-
val_loss,
363-
val_step_metric,
364-
val_metric,
374+
eval_step_loss,
375+
eval_loss,
376+
eval_step_metric,
377+
eval_metric,
365378
)
366379
avg_epoch_time = sum(epoch_times) / len(epoch_times)
367380
avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
368-
avg_train_metric = sum(train_metric) / len(train_metric)
369-
avg_train_loss = sum(train_loss) / len(train_loss)
370-
if train_config.run_validation:
371-
avg_eval_metric = sum(val_metric) / len(val_metric)
372-
avg_eval_loss = sum(val_loss) / len(val_loss)
373381

374-
results["avg_train_metric"] = avg_train_metric
375-
results["avg_train_loss"] = avg_train_loss
382+
results["last_epoch_train_loss"] = train_epoch_loss.cpu()
383+
results["last_epoch_train_metric"] = train_epoch_metric.cpu()
376384
if train_config.run_validation:
377-
results["avg_eval_metric"] = avg_eval_metric
378-
results["avg_eval_loss"] = avg_eval_loss
385+
results["last_epoch_eval_loss"] = eval_epoch_loss.cpu()
386+
results["last_epoch_eval_metric"] = eval_epoch_metric.cpu()
379387
results["avg_epoch_time"] = avg_epoch_time
380388
results["avg_checkpoint_time"] = avg_checkpoint_time
381389
if train_config.save_metrics:
@@ -391,7 +399,7 @@ def evaluation(model, train_config, eval_dataloader, device):
391399
model: The model to evaluate
392400
eval_dataloader: The dataloader containing the evaluation data
393401
394-
Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
402+
Returns: eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric
395403
"""
396404
if train_config.enable_ddp:
397405
dist.barrier()
@@ -408,17 +416,17 @@ def evaluation(model, train_config, eval_dataloader, device):
408416
# special handling for qaic device and dtype
409417
# model.to(device)
410418

411-
val_step_loss = []
412-
val_step_metric = []
419+
eval_step_loss = []
420+
eval_step_metric = []
413421

414-
eval_loss = 0.0 # Initialize evaluation loss
422+
eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device) # Initialize evaluation loss
415423
device_type = torch.device(device).type
416424

417425
num_dummy_samples = 0
418426
autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
419427
for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
420428
# stop when the maximum number of eval steps is reached
421-
if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
429+
if train_config.max_eval_step > 0 and step >= train_config.max_eval_step:
422430
break
423431
for key in batch.keys():
424432
batch[key] = batch[key].to(device)
@@ -445,29 +453,27 @@ def evaluation(model, train_config, eval_dataloader, device):
445453
logits = outputs.logits
446454
labels = batch["labels"][:, 0]
447455
preds = torch.nn.functional.softmax(logits, dim=-1)
448-
val_acc = acc_helper.forward(preds, labels)
449-
metric_val = val_acc.detach().float().item()
456+
eval_acc = acc_helper.forward(preds, labels)
457+
metric_value = eval_acc.detach().float().item()
450458
else:
451-
metric_val = float(torch.exp(loss.detach().float()))
459+
metric_value = float(torch.exp(loss.detach().float()))
452460

453461
if train_config.save_metrics:
454-
val_step_loss.append(loss.detach().float().item())
455-
val_step_metric.append(metric_val)
462+
eval_step_loss.append(loss.detach().float().item())
463+
eval_step_metric.append(metric_value)
456464

457465
eval_loss += loss.detach().float()
466+
458467
# Compute average loss and metric
459468
eval_epoch_loss = (
460469
0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
461470
)
462471
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
463-
eval_metric = acc_helper.compute()
472+
eval_epoch_metric = acc_helper.compute()
464473
else:
465-
eval_metric = torch.exp(eval_epoch_loss)
466-
467-
# Print evaluation metrics
468-
logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
474+
eval_epoch_metric = torch.exp(eval_epoch_loss)
469475

470-
return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
476+
return eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric
471477

472478

473479
def print_model_size(model) -> None:

Comments: 0