
Commit 1e8039b

quic-mamta and mamtsing authored

[QEff Finetune]: Enable --help for finetune CLI (#392)

1. Enabled --help for finetune
2. Updated finetune docs
3. Removed unused flags

Signed-off-by: Mamta Singh <[email protected]>
Co-authored-by: Mamta Singh <[email protected]>

1 parent 2514c0b commit 1e8039b
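
Since the commit replaces fire with a dedicated argument parser, the finetune CLI now documents its flags. A hypothetical invocation for reference (module path taken from the diff below, flag names from the docstring example in main()):

    # assumed entry point; flags mirror the TrainConfig fields
    python -m QEfficient.cloud.finetune --help
    python -m QEfficient.cloud.finetune --model_name "meta-llama/Llama-3.2-1B" --lr 5e-4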

File tree

10 files changed: +376 -183 lines changed


QEfficient/cloud/finetune.py

Lines changed: 12 additions & 56 deletions
@@ -9,7 +9,6 @@
 import warnings
 from typing import Any, Dict, Optional, Union
 
-import fire
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -24,13 +23,10 @@
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
-    get_dataloader_kwargs,
     update_config,
 )
-from QEfficient.finetune.utils.dataset_utils import (
-    get_custom_data_collator,
-    get_preprocessed_dataset,
-)
+from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.parser import get_finetune_parser
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
 
@@ -68,7 +64,8 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
 
-    dist.init_process_group(backend=train_config.dist_backend)
+    dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
+    dist.init_process_group(backend=dist_backend_map[torch_device.type])
     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
     getattr(torch, torch_device.type).set_device(dist.get_rank())
 
@@ -180,7 +177,7 @@ def apply_peft(
         kwargs: Additional arguments to override PEFT config params.
 
     Returns:
-        Union[AutoModel, PeftModel]: If the use_peft in train_config is True
+        Union[AutoModel, PeftModel]: If use_peft in train_config is True
         then PeftModel object is returned else original model object
         (AutoModel) is returned.
     """
@@ -226,58 +223,13 @@ def setup_dataloaders(
     - Applies a custom data collator if provided by get_custom_data_collator.
     - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """
-    # Get the dataset utils
-    dataset_processer = tokenizer
-
-    # Load and preprocess the dataset for training and validation
-    dataset_train = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
-    )
-
-    dataset_val = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-    )
 
-    # TODO: vbaddi, check if its necessary to do this?
-    # dataset_train = ConcatDataset(
-    #     dataset_train, chunk_size=train_config.context_length
-    # )
-    ##
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    print("length of dataset_train", len(dataset_train))
-
-    # FIXME (Meet): Add custom data collator registration from the outside by the user.
-    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
-    if custom_data_collator:
-        print("custom_data_collator is used")
-        train_dl_kwargs["collate_fn"] = custom_data_collator
-
-    # Create DataLoaders for the training and validation dataset
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset_train,
-        num_workers=train_config.num_workers_dataloader,
-        pin_memory=True,
-        **train_dl_kwargs,
-    )
+    train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
     print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
-        # if train_config.batching_strategy == "packing":
-        #     dataset_val = ConcatDataset(
-        #         dataset_val, chunk_size=train_config.context_length
-        #     )
-
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
-        if custom_data_collator:
-            val_dl_kwargs["collate_fn"] = custom_data_collator
-
-        eval_dataloader = torch.utils.data.DataLoader(
-            dataset_val,
-            num_workers=train_config.num_workers_dataloader,
-            pin_memory=True,
-            **val_dl_kwargs,
-        )
+        eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
             raise ValueError(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
@@ -316,6 +268,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
        --model_name "meta-llama/Llama-3.2-1B" \\
        --lr 5e-4
     """
+    # TODO: Remove TrainConfig() and update_config() as all params are passed in kwargs by parser
     train_config = TrainConfig()
     update_config(train_config, **kwargs)
     dataset_config = generate_dataset_config(train_config.dataset)
@@ -354,4 +307,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
 
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    parser = get_finetune_parser()
+    args = parser.parse_args()
+    args_dict = vars(args)
+    main(**args_dict)
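
The new get_finetune_parser helper imported from QEfficient.finetune.utils.parser is not shown in this diff. As a rough sketch only, such a parser could be generated from the TrainConfig dataclass with argparse; everything below (flag derivation, the extra --peft_config_file option) is an assumption, not the actual implementation:

    # Illustrative sketch; the real parser lives in QEfficient/finetune/utils/parser.py.
    import argparse
    from dataclasses import fields

    from QEfficient.finetune.configs.training import TrainConfig


    def get_finetune_parser() -> argparse.ArgumentParser:
        parser = argparse.ArgumentParser(
            description="Fine-tune a Hugging Face model on QAIC/CUDA/CPU."
        )
        # One flag per TrainConfig field, so --help documents every tunable knob.
        for f in fields(TrainConfig):
            if f.type is bool:
                parser.add_argument(f"--{f.name}", action=argparse.BooleanOptionalAction, default=f.default)
            else:
                parser.add_argument(f"--{f.name}", type=f.type if callable(f.type) else str, default=f.default)
        # Hypothetical extra flag matching main()'s peft_config_file parameter.
        parser.add_argument("--peft_config_file", type=str, default=None)
        return parser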

QEfficient/finetune/configs/peft_config.py

Lines changed: 0 additions & 7 deletions
@@ -30,10 +30,3 @@ class LoraConfig:
     task_type: str = "CAUSAL_LM"
     lora_dropout: float = 0.05
     inference_mode: bool = False  # should be False for finetuning
-
-
-# CAUTION prefix tuning is currently not supported
-@dataclass
-class PrefixConfig:
-    num_virtual_tokens: int = 30
-    task_type: str = "CAUSAL_LM"
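
The surviving LoraConfig above is a plain dataclass container. As a hedged illustration of how a container like this is typically mapped onto peft's own LoraConfig (presumably what generate_peft_config does; the fields r, lora_alpha, target_modules, and bias are assumed, since this hunk only shows the last three fields):

    # Illustrative only; the real mapping is done by generate_peft_config() in config_utils.
    from dataclasses import asdict, dataclass, field
    from typing import List

    from peft import LoraConfig as PeftLoraConfig


    @dataclass
    class LoraConfig:
        r: int = 8                    # assumed default
        lora_alpha: int = 32          # assumed default
        target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])  # assumed
        bias: str = "none"            # assumed default
        task_type: str = "CAUSAL_LM"
        lora_dropout: float = 0.05
        inference_mode: bool = False  # should be False for finetuning


    peft_config = PeftLoraConfig(**asdict(LoraConfig()))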

QEfficient/finetune/configs/training.py

Lines changed: 17 additions & 26 deletions
@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
 from dataclasses import dataclass
 
 
@@ -16,10 +17,13 @@ class TrainConfig:
         model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B").
         tokenizer_name (str): Name of the tokenizer (defaults to model_name if None).
         run_validation (bool): Whether to run validation during training (default: True).
-        batch_size_training (int): Batch size for training (default: 1).
+        train_batch_size (int): Batch size for training (default: 1).
+        val_batch_size (int): Batch size for validation (default: 1).
         context_length (Optional[int]): Maximum sequence length for inputs (default: None).
         gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4).
         gradient checkpointing (bool): Enable gradient checkpointing to save the memory by compromising the speed. (default: False).
+        use_autocast (bool): Use autocast for mixed precision (default: True).
+        grad_scaler (bool): Use gradient scaler (default: True).
         num_epochs (int): Number of training epochs (default: 1).
         max_train_step (int): Maximum training steps (default: 0, unlimited if 0).
         max_eval_step (int): Maximum evaluation steps (default: 0, unlimited if 0).
@@ -29,17 +33,12 @@ class TrainConfig:
         weight_decay (float): Weight decay for optimizer (default: 0.0).
         gamma (float): Learning rate decay factor (default: 0.85).
         seed (int): Random seed for reproducibility (default: 42).
-        use_fp16 (bool): Use mixed precision training (default: True).
-        use_autocast (bool): Use autocast for mixed precision (default: True).
-        val_batch_size (int): Batch size for validation (default: 1).
         dataset (str): Dataset name for training (default: "samsum_dataset").
         task_type (str): Type of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation")
-        peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
         use_peft (bool): Whether to use PEFT (default: True).
+        peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
         from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
         output_dir (str): Directory to save outputs (default: "meta-llama-samsum").
-        num_freeze_layers (int): Number of layers to freeze (default: 1).
-        one_qaic (bool): Use single QAIC device (default: False).
         save_model (bool): Save the trained model (default: True).
         save_metrics (bool): Save training metrics (default: True).
         intermediate_step_save (int): Steps between intermediate saves (default: 1000).
@@ -49,19 +48,20 @@ class TrainConfig:
         convergence_loss (float): Loss threshold for convergence (default: 1e-4).
         use_profiler (bool): Enable profiling (default: False).
         enable_ddp (bool): Enable distributed data parallel (default: False).
-        dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo").
-        grad_scaler (bool): Use gradient scaler (default: True).
         dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_").
         opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
     """
 
     model_name: str = "meta-llama/Llama-3.2-1B"
     tokenizer_name: str = None  # if not passed as an argument, it uses the value of model_name
     run_validation: bool = True
-    batch_size_training: int = 1
+    train_batch_size: int = 1
+    val_batch_size: int = 1
     context_length: int = None
     gradient_accumulation_steps: int = 4
     gradient_checkpointing: bool = False
+    use_autocast: bool = True
+    grad_scaler: bool = True
     num_epochs: int = 1
     max_train_step: int = 0
     max_eval_step: int = 0
@@ -71,21 +71,17 @@ class TrainConfig:
     weight_decay: float = 0.0
     gamma: float = 0.85  # multiplicatively decay the learning rate by gamma after each epoch
     seed: int = 42
-    use_fp16: bool = True
-    use_autocast: bool = True
-    val_batch_size: int = 1
-    dataset = "samsum_dataset"
-    task_type = "generation"  # "generation" / "seq_classification"
+    dataset: str = "alpaca_dataset"
+    task_type: str = "generation"  # "generation" / "seq_classification"
+    use_peft: bool = True  # use parameter efficient finetuning
     peft_method: str = "lora"
-    use_peft: bool = True  # use parameter efficient fine tuning
-    from_peft_checkpoint: str = ""  # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
-    output_dir: str = "meta-llama-samsum"
-    num_freeze_layers: int = 1
-    one_qaic: bool = False
+    from_peft_checkpoint: str = ""  # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
+    output_dir: str = "training_results"
     save_model: bool = True
     save_metrics: bool = True  # saves training metrics to a json file for later plotting
     intermediate_step_save: int = 1000
     batching_strategy: str = "packing"
+    enable_ddp: bool = False
     enable_sorting_for_ddp: bool = True
     convergence_counter: int = 5  # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps
     convergence_loss: float = (
@@ -98,10 +94,5 @@ class TrainConfig:
     use_profiler: bool = False  # Enable pytorch profiler, can not be used with flop counter at the same time.
     # profiler_dir: str = "PATH/to/save/profiler/results"  # will be used if using profiler
 
-    # dist-related
-    enable_ddp: bool = False
-    dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"
-
-    grad_scaler: bool = True
-    dump_root_dir: str = "meta-llama-samsum-mismatches/step_"
+    dump_root_dir: str = "mismatches/step_"
     opByOpVerifier: bool = False
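
At runtime these defaults are overridden via update_config(train_config, **kwargs) before the dataset config is generated. A minimal sketch of that flow, with update_config stubbed out here as a simple attribute copy (the real helper lives in QEfficient.finetune.utils.config_utils and may do more validation):

    # Minimal stand-in showing how parsed CLI kwargs land on TrainConfig.
    from dataclasses import dataclass


    @dataclass
    class TrainConfig:
        model_name: str = "meta-llama/Llama-3.2-1B"
        train_batch_size: int = 1
        val_batch_size: int = 1
        output_dir: str = "training_results"


    def update_config(config, **kwargs):
        # Copy only keys the config actually defines; skip the rest.
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)


    train_config = TrainConfig()
    update_config(train_config, train_batch_size=4, output_dir="llama-alpaca-run")
    assert train_config.train_batch_size == 4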

QEfficient/finetune/eval.py

Lines changed: 13 additions & 32 deletions
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import os
 import random
 import warnings
 
@@ -13,15 +14,8 @@
 import torch
 from peft import AutoPeftModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils.config_utils import (
-    generate_dataset_config,
-    get_dataloader_kwargs,
-    update_config,
-)
-from utils.dataset_utils import (
-    get_custom_data_collator,
-    get_preprocessed_dataset,
-)
+from utils.config_utils import generate_dataset_config, update_config
+from utils.dataset_utils import get_dataloader
 from utils.train_utils import evaluation, print_model_size
 
 from QEfficient.finetune.configs.training import TrainConfig
@@ -42,18 +36,24 @@ def main(**kwargs):
     # update the configuration for the training process
     train_config = TrainConfig()
     update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config.dataset)
+    update_config(dataset_config, **kwargs)
 
     # Set the seeds for reproducibility
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
     np.random.seed(train_config.seed)
 
-    # Load the pre-trained model and setup its configuration
-    # config = AutoConfig.from_pretrained(train_config.model_name)
-    save_dir = "meta-llama-samsum/trained_weights/step_14000"
+    # Load the pre-trained model from latest checkpoint
+    trained_weights_path = os.path.join(train_config.output_dir, "trained_weights")
+    epoch_max_index = max([int(name.split("_")[-1]) for name in os.listdir(trained_weights_path)])
+    epochs_path = os.path.join(trained_weights_path, "epoch_" + str(epoch_max_index))
+    step_max_index = max([int(name.split("_")[-1]) for name in os.listdir(epochs_path)])
+    save_dir = os.path.join(epochs_path, "step_" + str(step_max_index))
 
     # Load PEFT model on CPU
     model_peft = AutoPeftModelForCausalLM.from_pretrained(save_dir)
+
     # Merge LoRA and base model and save
     merged_model = model_peft.merge_and_unload()
     merged_model.save_pretrained(train_config.output_dir, safe_serialization=True)
@@ -82,32 +82,13 @@ def main(**kwargs):
 
     print_model_size(model, train_config)
 
-    # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
-    dataset_processer = tokenizer
-
-    # Load and preprocess the dataset for training and validation
-    dataset_val = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-    )
-
-    eval_dataloader = None
-    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if train_config.run_validation:
         # TODO: vbaddi enable packing later in entire infra.
         # if train_config.batching_strategy == "packing":
         #     dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
-        if custom_data_collator:
-            val_dl_kwargs["collate_fn"] = custom_data_collator
+        eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")
 
-        eval_dataloader = torch.utils.data.DataLoader(
-            dataset_val,
-            num_workers=train_config.num_workers_dataloader,
-            pin_memory=True,
-            **val_dl_kwargs,
-        )
         print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError(
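
The checkpoint lookup added to eval.py assumes an output layout of <output_dir>/trained_weights/epoch_<i>/step_<j>/ produced by training. A self-contained sketch of the same latest-checkpoint scan (directory names are taken from the hunk above; the helper name and example paths are invented for illustration):

    # Reproduces the latest-checkpoint discovery shown in the diff above.
    import os


    def latest_checkpoint(output_dir: str) -> str:
        trained_weights_path = os.path.join(output_dir, "trained_weights")
        # Pick the highest epoch index first, then the highest step index inside it.
        epoch_max = max(int(name.split("_")[-1]) for name in os.listdir(trained_weights_path))
        epochs_path = os.path.join(trained_weights_path, f"epoch_{epoch_max}")
        step_max = max(int(name.split("_")[-1]) for name in os.listdir(epochs_path))
        return os.path.join(epochs_path, f"step_{step_max}")


    # e.g. latest_checkpoint("training_results")
    #   -> "training_results/trained_weights/epoch_2/step_1500"  (hypothetical)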
