
Commit 4d46105

mamtsing authored and quic-mamta committed
modify error handling
Signed-off-by: Mamta Singh <[email protected]>
1 parent 1e1519b commit 4d46105

File tree

9 files changed: +56, -50 lines


QEfficient/cloud/finetune.py

Lines changed: 7 additions & 7 deletions
@@ -121,7 +121,7 @@ def load_model_and_tokenizer(
     )
 
     if not hasattr(model, "base_model_prefix"):
-        logger.raise_runtimeerror("Given huggingface model does not have 'base_model_prefix' attribute.")
+        logger.raise_error("Given huggingface model does not have 'base_model_prefix' attribute.", RuntimeError)
 
     for param in getattr(model, model.base_model_prefix).parameters():
         param.requires_grad = False
@@ -161,8 +161,8 @@ def load_model_and_tokenizer(
     if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
     else:
-        logger.raise_runtimeerror(
-            "Given model doesn't support gradient checkpointing. Please disable it and run it."
+        logger.raise_error(
+            "Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
         )
 
     model = apply_peft(model, train_config, peft_config_file, **kwargs)
@@ -237,8 +237,9 @@ def setup_dataloaders(
     if train_config.run_validation:
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
-            logger.raise_runtimeerror(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            logger.raise_error(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
+                ValueError,
             )
         else:
             logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")
@@ -280,8 +281,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     dataset_config = generate_dataset_config(train_config.dataset)
     update_config(dataset_config, **kwargs)
 
-    logger.prepare_dump_logs(train_config.dump_logs)
-    logger.setLevel(train_config.log_level)
+    logger.prepare_dump_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)
 
     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)

QEfficient/finetune/configs/training.py

Lines changed: 0 additions & 1 deletion
@@ -95,7 +95,6 @@ class TrainConfig:
     use_profiler: bool = False  # Enable pytorch profiler, can not be used with flop counter at the same time.
     # profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
 
-    dump_root_dir: str = "mismatches/step_"
     opByOpVerifier: bool = False
 
     dump_logs: bool = True

QEfficient/finetune/dataset/custom_dataset.py

Lines changed: 12 additions & 7 deletions
@@ -32,18 +32,21 @@ def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=Non
     module_path, func_name = dataset_config.file, "get_custom_dataset"
 
     if not module_path.endswith(".py"):
-        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_error(
+            f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+        )
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
-    except AttributeError:
-        logger.raise_runtimeerror(
-            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
+    except AttributeError as e:
+        logger.raise_error(
+            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).",
+            e,
         )
 
 
@@ -54,11 +57,13 @@ def get_data_collator(dataset_processer, dataset_config):
     module_path, func_name = dataset_config.file, "get_data_collator"
 
     if not module_path.endswith(".py"):
-        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_error(
+            f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+        )
 
     module = load_module_from_py_file(module_path.as_posix())
     try:

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 3 additions & 3 deletions
@@ -22,10 +22,10 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
                 delimiter=",",
             )
         except Exception as e:
-            logger.raise_runtimeerror(
-                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
+            logger.raise_error(
+                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset.",
+                e,
             )
-            raise e
 
         self.context_length = context_length
         self.tokenizer = tokenizer

QEfficient/finetune/utils/config_utils.py

Lines changed: 15 additions & 11 deletions
@@ -44,7 +44,9 @@ def update_config(config, **kwargs):
                 if hasattr(config, param_name):
                     setattr(config, param_name, v)
                 else:
-                    raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
+                    logger.raise_error(
+                        f"Config '{config_name}' does not have parameter: '{param_name}'", ValueError
+                    )
         else:
             config_type = type(config).__name__
             logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
@@ -70,7 +72,7 @@ def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None
     else:
         config_map = {"lora": (LoraConfig, PeftLoraConfig)}
         if train_config.peft_method not in config_map:
-            raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+            logger.raise_error(f"Peft config not found: {train_config.peft_method}", RuntimeError)
 
         config_cls, peft_config_cls = config_map[train_config.peft_method]
         if config_cls is None:
@@ -119,7 +121,7 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
        - Ensures types match expected values (int, float, list, etc.).
     """
     if config_type.lower() != "lora":
-        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
+        logger.raise_error(f"Unsupported config_type: {config_type}. Only 'lora' is supported.", ValueError)
 
     required_fields = {
         "r": int,
@@ -136,26 +138,28 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
     # Check for missing required fields
     missing_fields = [field for field in required_fields if field not in config_data]
     if missing_fields:
-        raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}")
+        logger.raise_error(f"Missing required fields in {config_type} config: {missing_fields}", ValueError)
 
     # Validate types of required fields
     for field, expected_type in required_fields.items():
         if not isinstance(config_data[field], expected_type):
-            raise ValueError(
+            logger.raise_error(
                 f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
-                f"got {type(config_data[field]).__name__}"
+                f"got {type(config_data[field]).__name__}",
+                ValueError,
             )
 
     # Validate target_modules contains strings
     if not all(isinstance(mod, str) for mod in config_data["target_modules"]):
-        raise ValueError("All elements in 'target_modules' must be strings")
+        logger.raise_error("All elements in 'target_modules' must be strings", ValueError)
 
     # Validate types of optional fields if present
     for field, expected_type in optional_fields.items():
         if field in config_data and not isinstance(config_data[field], expected_type):
-            raise ValueError(
+            logger.raise_error(
                 f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
-                f"got {type(config_data[field]).__name__}"
+                f"got {type(config_data[field]).__name__}",
+                ValueError,
            )
 
 
@@ -173,12 +177,12 @@ def load_config_file(config_path: str) -> Dict[str, Any]:
        ValueError: If the file format is unsupported.
     """
     if not os.path.exists(config_path):
-        raise FileNotFoundError(f"Config file not found: {config_path}")
+        logger.raise_error(f"Config file not found: {config_path}", FileNotFoundError)
 
     with open(config_path, "r") as f:
         if config_path.endswith(".yaml") or config_path.endswith(".yml"):
             return yaml.safe_load(f)
         elif config_path.endswith(".json"):
             return json.load(f)
         else:
-            raise ValueError("Unsupported config file format. Use .yaml, .yml, or .json")
+            logger.raise_error("Unsupported config file format. Use .yaml, .yml, or .json", ValueError)

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 4 additions & 3 deletions
@@ -18,7 +18,7 @@ def get_preprocessed_dataset(
     tokenizer, dataset_config, split: str = "train", context_length: int = None
 ) -> torch.utils.data.Dataset:
     if dataset_config.dataset not in DATASET_PREPROC:
-        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
+        logger.raise_error(f"{dataset_config.dataset} is not (yet) implemented", NotImplementedError)
 
     def get_split():
         return dataset_config.train_split if split == "train" else dataset_config.test_split
@@ -39,8 +39,9 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
     if train_config.enable_ddp:
         if train_config.enable_sorting_for_ddp:
             if train_config.context_length:
-                raise ValueError(
-                    "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding"
+                logger.raise_error(
+                    "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding",
+                    ValueError,
                 )
             else:
                 kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(

QEfficient/finetune/utils/logging_utils.py

Lines changed: 7 additions & 7 deletions
@@ -10,29 +10,29 @@
 from datetime import datetime
 
 from QEfficient.finetune.utils.helper import is_rank_zero
-from QEfficient.utils.constants import ROOT_DIR
 
 
 class FTLogger:
-    def __init__(self, level=logging.DEBUG):
+    def __init__(self):
         self.logger = logging.getLogger("QEfficient")
         if not getattr(self.logger, "_custom_methods_added", False):
             self._bind_custom_methods()
             self.logger._custom_methods_added = True  # Prevent adding handlers/methods twice
 
     def _bind_custom_methods(self):
-        def raise_runtimeerror(message):
+        def raise_error(message, errortype=RuntimeError):
             self.logger.error(message)
-            raise RuntimeError(message)
+            raise errortype(message)
 
         def log_rank_zero(msg: str, level: int = logging.INFO):
             if not is_rank_zero:
                 return
             self.logger.log(level, msg, stacklevel=2)
 
-        def prepare_dump_logs(dump_logs=False, level=logging.INFO):
+        def prepare_dump_logs(output_path, dump_logs=False, level=logging.INFO):
+            self.logger.setLevel(level)
             if dump_logs:
-                logs_path = os.path.join(ROOT_DIR, "logs")
+                logs_path = os.path.join(output_path, "logs")
                 if not os.path.exists(logs_path):
                     os.makedirs(logs_path, exist_ok=True)
                 file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
@@ -44,7 +44,7 @@ def prepare_dump_logs(dump_logs=False, level=logging.INFO):
             fh.setFormatter(formatter)
             self.logger.addHandler(fh)
 
-        self.logger.raise_runtimeerror = raise_runtimeerror
+        self.logger.raise_error = raise_error
         self.logger.log_rank_zero = log_rank_zero
         self.logger.prepare_dump_logs = prepare_dump_logs
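For context, a minimal standalone sketch of the behavior implied by the updated FTLogger above (illustrative only, not the project's actual module): raise_error logs the message through the shared "QEfficient" logger and then raises the caller-supplied exception type, while prepare_dump_logs now both sets the log level and, when dumping is enabled, writes the log file under <output_path>/logs.

    import logging

    logger = logging.getLogger("QEfficient")

    def raise_error(message, errortype=RuntimeError):
        # Log the failure first, then raise the requested exception type with the same message.
        logger.error(message)
        raise errortype(message)

    # Call sites choose the exception that fits the failure, for example:
    #   raise_error("Config file not found: lora.yaml", FileNotFoundError)   # hypothetical path
    #   raise_error("Unsupported config file format. Use .yaml, .yml, or .json", ValueError)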

QEfficient/finetune/utils/parser.py

Lines changed: 5 additions & 9 deletions
@@ -254,18 +254,14 @@ def get_finetune_parser():
         action="store_true",
         help="Enable distributed data parallel training. This will load the replicas of model on given number of devices and train the model. This should be used using torchrun interface. Please check docs for exact usage.",
     )
-    parser.add_argument(
-        "--dump_root_dir",
-        "--dump-root-dir",
-        required=False,
-        type=str,
-        default="mismatches/step_",
-        help="Directory for mismatch dumps by opByOpVerifier",
-    )
     parser.add_argument(
         "--opByOpVerifier",
         action="store_true",
-        help="Enable operation-by-operation verification w.r.t reference device(cpu). It is a context manager interface that captures and verifies each operator against reference device. In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.",
+        help=argparse.SUPPRESS,
+        # This is for debugging purpose only.
+        # Enables operation-by-operation verification w.r.t reference device(cpu).
+        # It is a context manager interface that captures and verifies each operator against reference device.
+        # In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.
    )
 
    return parser
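As background on the change above: passing argparse.SUPPRESS as the help value is standard argparse behavior that keeps an option functional while hiding it from --help output, which is why the old help text was demoted to source comments. A small self-contained sketch (the --verbose flag is made up for illustration):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true", help="Shown in --help output")
    parser.add_argument("--opByOpVerifier", action="store_true", help=argparse.SUPPRESS)  # hidden from --help

    args = parser.parse_args(["--opByOpVerifier"])
    print(args.opByOpVerifier)  # True: the flag still parses, it just is not advertised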

QEfficient/finetune/utils/train_utils.py

Lines changed: 3 additions & 2 deletions
@@ -85,8 +85,9 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached
 
     tensorboard_updates = None
+    tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
     if is_rank_zero():
-        tensorboard_updates = SummaryWriter()
+        tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)
 
     device_type = torch.device(device).type
 
@@ -181,7 +182,7 @@ def train(
                         atol=1e-1,
                         use_ref_output_on_mismatch=True,
                         filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.dump_root_dir + str(step),
+                        dump_root_dir=train_config.output_dir + "/mismatches/step_" + str(step),
                     ) as verifier:
                         model_outputs = model(**batch)
                         loss = model_outputs.loss  # Forward call
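Taken together with the logging and config changes above, every run artifact is now rooted at train_config.output_dir rather than the repository root or the current working directory. An illustrative snippet of the resulting locations (the output_dir value is hypothetical; the path fragments come from the diffs in this commit):

    import os
    from datetime import datetime

    output_dir = "finetune_results"  # hypothetical value of train_config.output_dir
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    log_file_dir = os.path.join(output_dir, "logs")            # text logs (logging_utils.py)
    tensorboard_dir = output_dir + "/runs/" + timestamp        # TensorBoard events (train_utils.py)
    mismatch_dir = output_dir + "/mismatches/step_" + str(0)   # opByOpVerifier dumps (train_utils.py)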
