
Commit 95efa6e

mamtsingquic-mamta authored and committed
minor refactoring
Signed-off-by: Mamta Singh <[email protected]>
1 parent 76b953a commit 95efa6e

File tree

4 files changed (+87, -110 lines)

QEfficient/cloud/finetune.py

Lines changed: 2 additions & 7 deletions
@@ -28,14 +28,10 @@
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
 from QEfficient.finetune.utils.device_map import get_device_map
+from QEfficient.finetune.utils.helper import get_longest_seq_length, print_model_size, print_trainable_parameters
 from QEfficient.finetune.utils.logging_utils import logger
 from QEfficient.finetune.utils.parser import get_finetune_parser
-from QEfficient.finetune.utils.train_utils import (
-    get_longest_seq_length,
-    print_model_size,
-    print_trainable_parameters,
-    train,
-)
+from QEfficient.finetune.utils.train_utils import train
 from QEfficient.utils._utils import hf_download
 
 # Try importing QAIC-specific module, proceed without it if unavailable
@@ -338,7 +334,6 @@ def main(peft_config_file: str = None, **kwargs) -> None:
         optimizer,
         scheduler,
         train_config,
-        dist.get_rank() if train_config.enable_ddp else None,
     )
     if train_config.enable_ddp:
         dist.destroy_process_group()
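
Note: the import block in finetune.py now pulls the sizing/printing helpers from utils.helper and only train from utils.train_utils, and the call to train() drops the explicit rank argument. Below is a minimal, hypothetical call-site sketch (not the actual finetune.py code); run_finetuning and the objects it receives (model, tokenizer, dataloaders, optimizer, scheduler, train_config) are assumed to be built elsewhere, as in the rest of finetune.py.

from QEfficient.finetune.utils.helper import (
    get_longest_seq_length,
    print_model_size,
    print_trainable_parameters,
)
from QEfficient.finetune.utils.train_utils import train


def run_finetuning(model, tokenizer, train_dataloader, eval_dataloader, optimizer, scheduler, train_config):
    # These helpers moved from utils.train_utils to utils.helper in this commit.
    print_model_size(model)
    print_trainable_parameters(model)
    # Assumes the dataloader's dataset is a list of dicts carrying "input_ids".
    longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)

    # local_rank is no longer passed in; train() now derives it internally.
    return train(
        model,
        tokenizer,
        train_dataloader,
        eval_dataloader,
        optimizer,
        scheduler,
        train_config,
    )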

QEfficient/finetune/utils/helper.py

Lines changed: 62 additions & 0 deletions
@@ -4,11 +4,16 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+import json
 import os
 from contextlib import nullcontext
+from typing import Dict, List, Tuple
 
 import torch
 
+from QEfficient.finetune.utils.logging_utils import logger
+
 try:
     import torch_qaic.debug as qaic_debug  # noqa: F401
 except ImportError as e:
@@ -58,3 +63,60 @@ def get_op_verifier_ctx(
         filter_config=filter_config,
         dump_root_dir=dump_dir,
     )
+
+
+def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
+    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
+    lengths = [len(d["input_ids"]) for d in data]
+    longest_seq_length = max(lengths)
+    longest_seq_ix = lengths.index(longest_seq_length)
+    return longest_seq_length, longest_seq_ix
+
+
+def print_model_size(model) -> None:
+    """
+    Print the number of trainable parameters.
+
+    Args:
+        model: PyTorch model.
+    """
+    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
+
+
+def print_trainable_parameters(model) -> None:
+    """
+    Print the number of trainable parameters, all params and percentage of trainable params.
+
+    Args:
+        model: The PyTorch model.
+    """
+    trainable_params, all_param = model.get_nb_trainable_parameters()
+    logger.log_rank_zero(
+        f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
+    )
+
+
+def save_to_json(
+    output_filename,
+    train_step_loss,
+    train_epoch_loss,
+    train_step_metric,
+    train_epoch_metric,
+    val_step_loss,
+    val_epoch_loss,
+    val_step_metric,
+    val_epoch_metric,
+):
+    metrics_data = {
+        "train_step_loss": train_step_loss,
+        "train_epoch_loss": train_epoch_loss,
+        "train_step_metric": train_step_metric,
+        "train_epoch_metric": train_epoch_metric,
+        "val_step_loss": val_step_loss,
+        "val_epoch_loss": val_epoch_loss,
+        "val_step_metric": val_step_metric,
+        "val_epoch_metric": val_epoch_metric,
+    }
+    with open(output_filename, "w") as f:
+        json.dump(metrics_data, f)
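
Note: these helpers are verbatim moves out of train_utils.py (with small docstring cleanups), so they can be exercised on their own. A minimal, self-contained usage sketch follows; the sample records and metric values are made up, and print_trainable_parameters is omitted because it relies on PEFT's get_nb_trainable_parameters(), i.e. it needs a PeftModel.

from QEfficient.finetune.utils.helper import get_longest_seq_length, save_to_json

# get_longest_seq_length expects a list of dicts carrying "input_ids".
data = [
    {"input_ids": [1, 2, 3]},
    {"input_ids": [1, 2, 3, 4, 5]},  # longest sample, index 1
    {"input_ids": [7]},
]
longest_len, longest_ix = get_longest_seq_length(data)
print(longest_len, longest_ix)  # -> 5 1

# save_to_json simply serializes the collected per-step / per-epoch metrics.
save_to_json(
    "metrics.json",
    train_step_loss=[0.9, 0.7],
    train_epoch_loss=[0.8],
    train_step_metric=[2.5, 2.0],
    train_epoch_metric=[2.2],
    val_step_loss=[],
    val_epoch_loss=[],
    val_step_metric=[],
    val_epoch_metric=[],
)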

QEfficient/finetune/utils/parser.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def get_finetune_parser():
         # This is for debugging purpose only.
         # Enables operation-by-operation verification w.r.t reference device(cpu).
         # It is a context manager interface that captures and verifies each operator against reference device.
-        # In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.
+        # In case results of test & reference do not match under given tolerances, a standalone unittest is generated at output_dir/mismatches.
     )
     parser.add_argument(
         "--dump_logs",

QEfficient/finetune/utils/train_utils.py

Lines changed: 22 additions & 102 deletions
@@ -5,12 +5,10 @@
 #
 # -----------------------------------------------------------------------------
 
-import json
 import os
 import time
 from datetime import datetime
 from functools import partial
-from typing import Dict, List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -19,7 +17,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero, save_to_json
 from QEfficient.finetune.utils.logging_utils import logger
 
 try:
@@ -42,24 +40,24 @@ def train(
     optimizer,
     lr_scheduler,
     train_config: TrainConfig,
-    local_rank=None,
 ):
     """
     Trains the model on the given dataloader
 
     Args:
         model: The model to be trained
-        tokenizer: tokenizer used in the eval for decoding the predicitons
+        tokenizer: tokenizer used in the eval for decoding the predictions
         train_dataloader: The dataloader containing the training data
         eval_dataloader: The dataloader containing the eval data
         optimizer: The optimizer used for training
         lr_scheduler: The learning rate scheduler
         train_config: The training configuration
-        local_rank: The rank of the current node in a distributed setting
 
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     device = train_config.device
+    device_type = torch.device(device).type
+    local_rank = int(os.getenv("LOCAL_RANK", 0))
 
     train_metric = []
     train_loss = []
@@ -89,8 +87,6 @@ def train(
         tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
         tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)
 
-    device_type = torch.device(device).type
-
     if train_config.grad_scaler:
         if device.startswith("qaic"):
             scaler = QAicGradScaler()
@@ -130,10 +126,11 @@ def train(
             continue
 
         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
-        # stop when the maximum number of training steps is reached
+        if train_config.max_train_step > 0:
+            logger.log_rank_zero(f"Max train steps : {train_config.max_train_step}")
         if max_steps_reached:
             break
+
         epoch_start_time = time.perf_counter()
         model.train()
 
@@ -165,7 +162,6 @@ def train(
                 continue
             total_train_steps += 1
 
-            # stop when the maximum number of training steps is reached
            if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                max_steps_reached = True
                break
@@ -223,7 +219,7 @@ def train(
                 step_metric_val = float(torch.exp(loss.detach().float()))
             train_step_metric.append(step_metric_val)
 
-            # Accumalate gradients
+            # Accumulate gradients
             complete_accum_steps = (
                 len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps
             )
@@ -297,19 +293,6 @@ def train(
                 if total_loss == 0.0
                 else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
             )
-        else:
-            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
-                )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
-                )
         if train_config.task_type == "seq_classification":
             metric_val = acc_helper.compute()
             acc_helper.reset()
@@ -322,17 +305,6 @@ def train(
         # Update the learning rate as needed
         lr_scheduler.step()
 
-        if train_config.run_validation:
-            eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                model, train_config, eval_dataloader, device
-            )
-            if is_rank_zero():
-                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-            if train_config.save_metrics:
-                val_step_loss.extend(temp_val_loss)
-                val_step_metric.extend(temp_step_metric)
-
         # saving the adapters after completion of each epoch
         if train_config.save_model:
             if train_config.enable_ddp:
@@ -342,19 +314,24 @@ def train(
                 model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
 
         if train_config.run_validation:
+            eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation(
+                model, train_config, eval_dataloader, device
+            )
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+
+            if is_rank_zero():
+                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_metric.extend(temp_step_metric)
             val_loss.append(float(eval_epoch_loss))
            val_metric.append(float(eval_metric))
-        if train_config.task_type == "seq_classification":
-            logger.log_rank_zero(
-                f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-            )
-        else:
-            logger.log_rank_zero(
-                f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-            )
+        logger.log_rank_zero(
+            f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+        )
 
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
@@ -389,7 +366,7 @@ def train(
     return results
 
 
-def evaluation_helper(model, train_config, eval_dataloader, device):
+def evaluation(model, train_config, eval_dataloader, device):
     """
     Evaluates the model on the given dataloader
 
@@ -474,60 +451,3 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
 
     return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
-
-
-def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
-    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
-    lengths = [len(d["input_ids"]) for d in data]
-    longest_seq_length = max(lengths)
-    longest_seq_ix = lengths.index(longest_seq_length)
-    return longest_seq_length, longest_seq_ix
-
-
-def print_model_size(model) -> None:
-    """
-    Print model name, the number of trainable parameters and initialization time.
-
-    Args:
-        model: PyTorch model.
-    """
-    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
-
-
-def print_trainable_parameters(model) -> None:
-    """
-    Print the number of trainable parameters, all params and percentage of trainablke params.
-
-    Args:
-        model: The PyTorch model.
-    """
-    trainable_params, all_param = model.get_nb_trainable_parameters()
-    logger.log_rank_zero(
-        f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
-    )
-
-
-def save_to_json(
-    output_filename,
-    train_step_loss,
-    train_epoch_loss,
-    train_step_metric,
-    train_epoch_metric,
-    val_step_loss,
-    val_epoch_loss,
-    val_step_metric,
-    val_epoch_metric,
-):
-    metrics_data = {
-        "train_step_loss": train_step_loss,
-        "train_epoch_loss": train_epoch_loss,
-        "train_step_metric": train_step_metric,
-        "train_epoch_metric": train_epoch_metric,
-        "val_step_loss": val_step_loss,
-        "val_epoch_loss": val_epoch_loss,
-        "val_step_metric": val_step_metric,
-        "val_epoch_metric": val_epoch_metric,
-    }
-    with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
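
Note: with this change, train() no longer receives local_rank from finetune.py; it reads the LOCAL_RANK environment variable that launchers such as torchrun set, and rank-zero-only work (logging, TensorBoard writes) is gated by is_rank_zero(). Below is a standalone sketch of that pattern, assuming a conventional is_rank_zero() built on torch.distributed; it is a stand-in for the repo's own QEfficient.finetune.utils.helper.is_rank_zero, not its actual implementation.

import os

import torch.distributed as dist


def is_rank_zero() -> bool:
    # Rank 0 when not running distributed, otherwise ask the process group.
    return not dist.is_initialized() or dist.get_rank() == 0


def resolve_local_rank() -> int:
    # Same idea as the new line in train(): default to 0 for single-device runs.
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    if is_rank_zero():
        print(f"running with local_rank={local_rank}")
    return local_rank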
