
Commit d405a80

mamtsingquic-mamta authored and committed

update branch

Signed-off-by: Mamta Singh <[email protected]>

1 parent 740f7c2 commit d405a80

File tree: 10 files changed (+165, -102 lines)

QEfficient/cloud/finetune.py

Lines changed: 27 additions & 18 deletions
@@ -17,7 +17,7 @@
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
@@ -26,18 +26,22 @@
     update_config,
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.logging_utils import logger
 from QEfficient.finetune.utils.parser import get_finetune_parser
-from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.finetune.utils.train_utils import (
+    get_longest_seq_length,
+    print_model_size,
+    print_trainable_parameters,
+    train,
+)
+from QEfficient.utils._utils import hf_download
 
 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    print(f"Warning: {e}. Proceeding without QAIC modules.")
-
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")
 
-from transformers import AutoModelForSequenceClassification
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -106,7 +110,8 @@ def load_model_and_tokenizer(
     - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
     - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
     """
-    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+    logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
+    pretrained_model_path = hf_download(train_config.model_name)
     if train_config.task_type == "seq_classification":
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_path,
@@ -116,7 +121,7 @@ def load_model_and_tokenizer(
         )
 
         if not hasattr(model, "base_model_prefix"):
-            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+            logger.raise_runtimeerror("Given huggingface model does not have 'base_model_prefix' attribute.")
 
         for param in getattr(model, model.base_model_prefix).parameters():
             param.requires_grad = False
@@ -141,11 +146,10 @@ def load_model_and_tokenizer(
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logger.WARNING)
         model.resize_token_embeddings(len(tokenizer))
 
-    # FIXME (Meet): Cover below line inside the logger once it is implemented.
-    print_model_size(model, train_config)
+    print_model_size(model)
 
     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -157,7 +161,9 @@ def load_model_and_tokenizer(
         if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
         else:
-            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+            logger.raise_runtimeerror(
+                "Given model doesn't support gradient checkpointing. Please disable it and run it."
+            )
 
     model = apply_peft(model, train_config, peft_config_file, **kwargs)
 
@@ -192,7 +198,7 @@ def apply_peft(
     else:
         peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
         model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        print_trainable_parameters(model)
 
     return model
 
@@ -217,25 +223,25 @@ def setup_dataloaders(
         - Length of longest sequence in the dataset.
 
     Raises:
-        ValueError: If validation is enabled but the validation set is too small.
+        RuntimeError: If validation is enabled but the validation set is too small.
 
     Notes:
         - Applies a custom data collator if provided by get_custom_data_collator.
         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """
 
     train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
-    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
-            raise ValueError(
+            logger.raise_runtimeerror(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")
 
     longest_seq_length, _ = get_longest_seq_length(
         torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -274,13 +280,16 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     dataset_config = generate_dataset_config(train_config.dataset)
     update_config(dataset_config, **kwargs)
 
+    logger.prepare_dump_logs(train_config.dump_logs)
+    logger.setLevel(train_config.log_level)
+
     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)
     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-    print(
+    logger.log_rank_zero(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
         f"{model.config.max_position_embeddings}"

QEfficient/finetune/configs/training.py

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import logging
 from dataclasses import dataclass
 
 
@@ -96,3 +97,6 @@ class TrainConfig:
 
     dump_root_dir: str = "mismatches/step_"
     opByOpVerifier: bool = False
+
+    dump_logs: bool = True
+    log_level: str = logging.INFO
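
These two new fields are consumed in main() of QEfficient/cloud/finetune.py (see the hunk above). A rough illustration of that wiring with defaults only and no CLI overrides, assuming the package is installed and ROOT_DIR is writable:

    import logging

    from QEfficient.finetune.configs.training import TrainConfig
    from QEfficient.finetune.utils.logging_utils import logger

    train_config = TrainConfig()                      # dump_logs=True, log_level=logging.INFO by default
    logger.prepare_dump_logs(train_config.dump_logs)  # attach a timestamped file handler under <ROOT_DIR>/logs
    logger.setLevel(train_config.log_level)           # standard logging levels, e.g. logging.DEBUG for more detail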

QEfficient/finetune/dataset/custom_dataset.py

Lines changed: 10 additions & 9 deletions
@@ -8,6 +8,8 @@
 import importlib
 from pathlib import Path
 
+from QEfficient.finetune.utils.logging_utils import logger
+
 
 def load_module_from_py_file(py_file: str) -> object:
     """
@@ -30,20 +32,19 @@ def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=None):
     module_path, func_name = dataset_config.file, "get_custom_dataset"
 
     if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
-    except AttributeError as e:
-        print(
+    except AttributeError:
+        logger.raise_runtimeerror(
             f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
         )
-        raise e
 
 
 def get_data_collator(dataset_processer, dataset_config):
@@ -53,16 +54,16 @@ def get_data_collator(dataset_processer, dataset_config):
     module_path, func_name = dataset_config.file, "get_data_collator"
 
     if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_processer)
     except AttributeError:
-        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
-        print("Using the default data_collator instead.")
+        logger.info(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        logger.info("Using the default data_collator instead.")
         return None
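
For context, the loader above only assumes that the user-supplied .py file exposes get_custom_dataset and, optionally, get_data_collator, with the call signatures shown in the try blocks. A hypothetical skeleton of such a file (the function bodies are illustrative, not from the repository):

    # my_custom_dataset.py -- hypothetical file referenced by dataset_config.file


    def get_custom_dataset(dataset_config, tokenizer, split, context_length=None):
        # Return any torch-compatible dataset for the requested split.
        texts = ["an example training sentence"] if split == "train" else ["an example eval sentence"]
        return [tokenizer(text, truncation=True, max_length=context_length) for text in texts]


    def get_data_collator(dataset_processer):
        # Optional hook: returning None (or omitting this function) keeps the default collator.
        return None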

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 7 additions & 5 deletions
@@ -10,6 +10,8 @@
 from datasets import load_dataset
 from torch.utils.data import Dataset
 
+from QEfficient.finetune.utils.logging_utils import logger
+
 
 class grammar(Dataset):
     def __init__(self, tokenizer, csv_name=None, context_length=None):
@@ -20,8 +22,8 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
                 delimiter=",",
             )
         except Exception as e:
-            print(
-                "Loading of grammar dataset failed! Please see [here](https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
+            logger.raise_runtimeerror(
+                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
             )
             raise e
 
@@ -36,7 +38,7 @@ def convert_to_features(self, example_batch):
         # Create prompt and tokenize contexts and questions
 
         if self.print_text:
-            print("Input Text: ", self.clean_text(example_batch["text"]))
+            logger.info("Input Text: ", self.clean_text(example_batch["text"]))
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
@@ -71,9 +73,9 @@ def get_dataset(dataset_config, tokenizer, csv_name=None, context_length=None):
     """cover function for handling loading the working dataset"""
     """dataset loading"""
     currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
-    print(f"Loading dataset {currPath}")
+    logger.info(f"Loading dataset {currPath}")
     csv_name = str(currPath)
-    print(csv_name)
+    logger.info(csv_name)
     dataset = grammar(tokenizer=tokenizer, csv_name=csv_name, context_length=context_length)
 
     return dataset

QEfficient/finetune/eval.py

Lines changed: 6 additions & 6 deletions
@@ -19,13 +19,14 @@
 from utils.train_utils import evaluation, print_model_size
 
 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.logging_utils import logger
 
 try:
     import torch_qaic  # noqa: F401
 
     device = "qaic:0"
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    logger.warning(f"{e}. Moving ahead without these qaic modules.")
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Suppress all warnings
@@ -77,25 +78,24 @@ def main(**kwargs):
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config)
+    print_model_size(model)
 
     if train_config.run_validation:
         # TODO: vbaddi enable packing later in entire infra.
         # if train_config.batching_strategy == "packing":
         #     dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")
-
-        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+        logger.log_rank_zero(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
     model.to(device)
     _ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)

QEfficient/finetune/utils/config_utils.py

Lines changed: 2 additions & 2 deletions
@@ -18,6 +18,7 @@
 from QEfficient.finetune.configs.peft_config import LoraConfig
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
+from QEfficient.finetune.utils.logging_utils import logger
 
 
 def update_config(config, **kwargs):
@@ -46,8 +47,7 @@ def update_config(config, **kwargs):
                        raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
            else:
                config_type = type(config).__name__
-                # FIXME (Meet): Once logger is available put this in debug level.
-                print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'")
+                logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
 
 
 def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any:

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 2 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
+from QEfficient.finetune.utils.logging_utils import logger
 
 
 def get_preprocessed_dataset(
@@ -72,7 +73,7 @@ def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
        print("custom_data_collator is used")
        dl_kwargs["collate_fn"] = custom_data_collator
 
-    print(f"length of dataset_{split}", len(dataset))
+    logger.log_rank_zero(f"Length of {split} dataset is {len(dataset)}")
 
     # Create data loader
     dataloader = torch.utils.data.DataLoader(

QEfficient/finetune/utils/logging_utils.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import logging
+import os
+from datetime import datetime
+
+import torch.distributed as dist
+
+from QEfficient.utils.constants import ROOT_DIR
+
+
+class FTLogger:
+    def __init__(self, level=logging.DEBUG):
+        self.logger = logging.getLogger("QEfficient")
+        if not getattr(self.logger, "_custom_methods_added", False):
+            self._bind_custom_methods()
+            self.logger._custom_methods_added = True  # Prevent adding handlers/methods twice
+
+    def _bind_custom_methods(self):
+        def raise_runtimeerror(message):
+            self.logger.error(message)
+            raise RuntimeError(message)
+
+        def log_rank_zero(msg: str, level: int = logging.INFO):
+            rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+            if rank != 0:
+                return
+            self.logger.log(level, msg, stacklevel=2)
+
+        def prepare_dump_logs(dump_logs=False, level=logging.INFO):
+            if dump_logs:
+                logs_path = os.path.join(ROOT_DIR, "logs")
+                if not os.path.exists(logs_path):
+                    os.makedirs(logs_path, exist_ok=True)
+                file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
+                log_file = os.path.join(logs_path, file_name)
+
+                fh = logging.FileHandler(log_file)
+                fh.setLevel(level)
+                formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s")
+                fh.setFormatter(formatter)
+                self.logger.addHandler(fh)
+
+        self.logger.raise_runtimeerror = raise_runtimeerror
+        self.logger.log_rank_zero = log_rank_zero
+        self.logger.prepare_dump_logs = prepare_dump_logs
+
+    def get_logger(self):
+        return self.logger
+
+
+logger = FTLogger().get_logger()
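
Because FTLogger binds its helpers onto the shared "QEfficient" logger, every module that imports logger (or fetches it by name) sees the same raise_runtimeerror, log_rank_zero, and prepare_dump_logs methods. A minimal usage sketch, assuming the package is importable and ROOT_DIR is writable (the messages are illustrative):

    import logging

    from QEfficient.finetune.utils.logging_utils import logger

    assert logging.getLogger("QEfficient") is logger      # same underlying logger everywhere

    logger.prepare_dump_logs(True)                         # writes <ROOT_DIR>/logs/log-file-<timestamp>.txt
    logger.log_rank_zero("starting fine-tuning")           # silently skipped on non-zero ranks
    logger.log_rank_zero("low eval batch count", logging.WARNING)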
