
Commit 5818728

Updated handling of custom dataset in FT. Updated finetune.md readme accordingly.
Signed-off-by: meetkuma <[email protected]>
1 parent: eff9472

File tree: 7 files changed, +149 -64 lines

QEfficient/cloud/finetune.py
Lines changed: 13 additions & 11 deletions

```diff
@@ -20,6 +20,7 @@
 from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 
 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import parse_unk_args
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
@@ -85,7 +86,7 @@ def setup_seeds(seed: int) -> None:
 
 
 def load_model_and_tokenizer(
-    train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+    train_config: TrainConfig, dataset_config: Any, peft_config_file: Optional[str] = None, **kwargs
 ) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
     """Load the pre-trained model and tokenizer from Hugging Face.
 
@@ -111,7 +112,7 @@ def load_model_and_tokenizer(
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_path,
             num_labels=dataset_config.num_labels,
-            attn_implementation="sdpa",
+            attn_implementation="eager",
             torch_dtype=torch.float16,
         )
 
@@ -128,7 +129,7 @@ def load_model_and_tokenizer(
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_path,
             use_cache=False,
-            attn_implementation="sdpa",
+            attn_implementation="eager",
             torch_dtype=torch.float16,
         )
 
@@ -246,13 +247,13 @@ def setup_dataloaders(
     return train_dataloader, eval_dataloader, longest_seq_length
 
 
-def main(peft_config_file: str = None, **kwargs) -> None:
+def main(**kwargs) -> None:
     """
     Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
 
     Args:
-        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
-        kwargs: Additional arguments to override TrainConfig.
+        kwargs: Keyword arguments fetched from CLI to override train config,
+            dataset config and peft config params.
 
     Example:
         .. code-block:: bash
@@ -268,14 +269,14 @@ def main(peft_config_file: str = None, **kwargs) -> None:
            --model_name "meta-llama/Llama-3.2-1B" \\
            --lr 5e-4
     """
-    # TODO:Remove TrainConfig() and update_config() as all params are passed in kwargs by parser
     train_config = TrainConfig()
     update_config(train_config, **kwargs)
-    dataset_config = generate_dataset_config(train_config.dataset)
-    update_config(dataset_config, **kwargs)
+    dataset_config_file = kwargs.pop("dataset_config", None)
+    dataset_config = generate_dataset_config(train_config.dataset, dataset_config_file)
 
     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)
+    peft_config_file = kwargs.pop("peft_config_file", None)
     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
 
     # Create DataLoaders for the training and validation dataset
@@ -308,6 +309,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
 
 if __name__ == "__main__":
     parser = get_finetune_parser()
-    args = parser.parse_args()
+    args, unk_args = parser.parse_known_args()
+    unk_args_dict = parse_unk_args(unk_args)
     args_dict = vars(args)
-    main(**args_dict)
+    main(**args_dict, **unk_args_dict)
```

QEfficient/finetune/configs/dataset_config.py
Lines changed: 0 additions & 2 deletions

```diff
@@ -48,7 +48,5 @@ class imdb_dataset:
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
-    file: str = "dataset/custom_dataset.py"
     train_split: str = "train"
     test_split: str = "validation"
-    data_path: str = ""
```

QEfficient/finetune/dataset/custom_dataset.py
Lines changed: 36 additions & 13 deletions

```diff
@@ -24,45 +24,68 @@ def load_module_from_py_file(py_file: str) -> object:
 
 
 def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=None):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
+    if not hasattr(dataset_config, "preproc_file"):
+        raise RuntimeError("Can not find preproc_file key in dataset_config file.")
+
+    if ":" in dataset_config.preproc_file:
+        module_path, func_name = dataset_config.preproc_file.split(":")
     else:
-        module_path, func_name = dataset_config.file, "get_custom_dataset"
+        module_path, func_name = dataset_config.preproc_file, "get_custom_dataset"
+        print(
+            f"Using '{func_name}' function from "
+            f"{dataset_config.preproc_file} as preprocessing function in "
+            "dataset preprocessing."
+        )
 
     if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+        raise ValueError(f"Custom dataset preprocessing file {module_path} is not a .py file.")
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        raise FileNotFoundError(f"Custom dataset file {module_path.as_posix()} does not exist or is not a file.")
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
     except AttributeError as e:
         print(
-            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
+            f"For custom dataset preprocessing, the method ({func_name}) is not "
+            f"present in the file ({module_path.as_posix()})."
         )
         raise e
 
 
 def get_data_collator(dataset_processer, dataset_config):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
+    if not hasattr(dataset_config, "collate_file"):
+        print(
+            f"Can not find collate_file key in dataset_config file. Using the default data collator function instead."
+        )
+        return None
+
+    if ":" in dataset_config.collate_file:
+        module_path, func_name = dataset_config.collate_file.split(":")
     else:
-        module_path, func_name = dataset_config.file, "get_data_collator"
+        module_path, func_name = dataset_config.collate_file, "get_data_collator"
+        print(
+            f"Using '{func_name}' function from {dataset_config.collate_file} as collate_fn in dataset preprocessing."
+        )
 
     if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+        raise ValueError(f"Custom dataset collate file {module_path} is not a .py file.")
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        raise FileNotFoundError(
+            f"Custom dataset collate file {module_path.as_posix()} does not exist or is not a file."
+        )
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_processer)
     except AttributeError:
-        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
-        print("Using the default data_collator instead.")
+        print(
+            f"Can not find the function {func_name} in file "
+            f"({module_path.as_posix()}). Using the default data collator "
+            "function instead."
+        )
         return None
```
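
The `"file.py:function"` convention used by `preproc_file` and `collate_file` can be pictured in isolation. The sketch below is not from the commit; `resolve_callable_spec` and the file names are hypothetical, and it only mirrors the split-and-fallback logic shown in the diff above:

```python
# Illustrative helper (hypothetical): mirrors how get_custom_dataset() and
# get_data_collator() split a "path.py:func" spec and fall back to a default
# function name when no ":" is present.
def resolve_callable_spec(spec: str, default_func: str) -> tuple:
    if ":" in spec:
        module_path, func_name = spec.split(":")
        return module_path, func_name
    return spec, default_func


print(resolve_callable_spec("disc_preproc.py:get_preprocessed_disc", "get_custom_dataset"))
# ('disc_preproc.py', 'get_preprocessed_disc')
print(resolve_callable_spec("disc_preproc.py", "get_custom_dataset"))
# ('disc_preproc.py', 'get_custom_dataset')
```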

QEfficient/finetune/utils/config_utils.py
Lines changed: 16 additions & 4 deletions

```diff
@@ -9,12 +9,13 @@
 import json
 import os
 from dataclasses import asdict
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+from collections import namedtuple
 
 import yaml
 from peft import LoraConfig as PeftLoraConfig
 
-import QEfficient.finetune.configs.dataset_config as datasets
+import QEfficient.finetune.configs.dataset_config as qeff_datasets
 from QEfficient.finetune.configs.peft_config import LoraConfig
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
@@ -84,11 +85,14 @@ def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None
     return peft_config
 
 
-def generate_dataset_config(dataset_name: str) -> Any:
+def generate_dataset_config(dataset_name: str, custom_dataset_config: Optional[str] = None) -> Any:
     """Generate a dataset configuration based on the specified dataset.
 
     Args:
         dataset_name (str): Name of the dataset to be used for finetuning.
+        custom_dataset_config (str): Dataset config json file for custom dataset.
+            This file contains dataset specific arguments to be used in dataset
+            preprocessing step.
 
     Returns:
         Any: A dataset configuration object.
@@ -99,7 +103,15 @@ def generate_dataset_config(dataset_name: str) -> Any:
     supported_datasets = DATASET_PREPROC.keys()
     assert dataset_name in supported_datasets, f"Given dataset '{dataset_name}' is not supported."
     # FIXME (Meet): Replace below logic by creating using auto registry of datasets.
-    dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[dataset_name]()
+    dataset_config = {k: v for k, v in inspect.getmembers(qeff_datasets)}[dataset_name]()
+    if dataset_name == "custom_dataset":
+        custom_dataset_dict = asdict(dataset_config)
+        custom_dataset_dict_override = load_config_file(custom_dataset_config)
+        # Override existing and add new params to dataset_config.
+        custom_dataset_dict.update(custom_dataset_dict_override)
+
+        custom_dataset_class = namedtuple("custom_dataset", custom_dataset_dict.keys())
+        dataset_config = custom_dataset_class(**custom_dataset_dict)
     return dataset_config
```
QEfficient/finetune/utils/helper.py
Lines changed: 7 additions & 0 deletions

```diff
@@ -9,3 +9,10 @@
 PEFT_METHOD = ["lora"]
 DEVICE = ["qaic", "cpu", "cuda"]
 BATCHING_STRATEGY = ["padding", "packing"]
+
+
+def parse_unk_args(unk_args_str):
+    if len(unk_args_str) % 2 != 0:
+        raise RuntimeError("Unknown arguments must be in pairs")
+    unk_args_dict = {unk_args_str[i].replace("--", ""): unk_args_str[i + 1] for i in range(0, len(unk_args_str), 2)}
+    return unk_args_dict
```
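
A quick sketch of what `parse_unk_args` produces when fed the tokens left over by `parse_known_args`; the flag names and values below are made up for illustration:

```python
from QEfficient.finetune.utils.helper import parse_unk_args

# Tokens argparse did not recognize, e.g. dataset-specific overrides.
unk_args = ["--disc_style", "sarcasm_more", "--test_split_ratio", "0.15"]

overrides = parse_unk_args(unk_args)
print(overrides)  # {'disc_style': 'sarcasm_more', 'test_split_ratio': '0.15'}
# Values stay strings; an odd number of tokens raises RuntimeError.
```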

QEfficient/finetune/utils/parser.py
Lines changed: 14 additions & 0 deletions

```diff
@@ -42,6 +42,20 @@ def get_finetune_parser():
         default=None,
         help="Name of the tokenizer,if not passed as an argument, it uses the value of model_name",
     )
+    parser.add_argument(
+        "--peft_config_file",
+        "--peft-config-file",
+        type=str,
+        default=None,
+        help="Path of PEFT config json file to override the PEFT config params such as lora_r, lora_alpha etc.",
+    )
+    parser.add_argument(
+        "--custom_dataset_config",
+        "--custom-dataset-config",
+        type=str,
+        default=None,
+        help="Path of custom dataset config json file to override the custom dataset params such as test_split_ratio, test_split etc.",
+    )
     parser.add_argument(
         "--run_validation",
         "--run-validation",
```

docs/source/finetune.md
Lines changed: 63 additions & 34 deletions

```diff
@@ -75,38 +75,67 @@ tensorboard --logdir runs/<file> --bind_all
 1) Gradient accumulation: By default, gradient accumulation happens for 4 steps. To update this value, command line argument gradient_accumulation_steps has to be passed. (Example: '--gradient_accumulation_steps 8')
 2) Gradient Checkpointing: By default, gradient checkpointing is disabled. To enable it, command line argument gradient_accumulation_steps has to be passed.
 
-## Fine-Tuning on custom dataset
 
-To run fine tuning for any user specific dataset, prepare the dataset using the following steps:
-
-1. Create a directory named 'dataset' inside efficient-transformers.
-2. Inside this directory, create a file named 'custom_dataset.py'.
-3. Inside the newly created efficient-transformers/dataset/custom_dataset.py, define a function named 'get_custom_dataset'.
-4. get_custom_dataset() should have following 4 parameters: dataset_config, tokenizer, split, context_length.
-5. Inside get_custom_dataset(), user needs to apply prompt and tokenize the dataset accordingly. Please refer the below template on how to define get_custom_dataset().
-6. For examples, please refer python files present in [dataset](https://github.com/quic/efficient-transformers/tree/main/QEfficient/finetune/dataset). In case of Samsum dataset, get_preprocessed_samsum() of efficient-transformers/QEfficient/finetune/dataset/samsum_dataset.py is called.
-7. In [dataset_config.py](https://github.com/quic/efficient-transformers/blob/main/QEfficient/finetune/configs/dataset_config.py), for custom_dataset class, pass the appropriate value for train_split and test_split. As an alternative, these values can be passed as command line arguments as well with the finetune command. For example "--train_split train".
-8. While running fine tuning, pass argument "--dataset custom_dataset" to finetune on custom dataset.
-
-Template for get_custom_dataset() to be defined inside efficient-transformers/dataset/custom_dataset.py is as follows:
-
-```python
-def get_custom_dataset(dataset_config, tokenizer, split, context_length=None):
-
-    # load dataset
-    # based on split, retrieve only the specific portion of the dataset (train or eval) either here or at the last
-
-    def apply_prompt_template():
-        # transform the passed datapoint by applying the prompt on it
-
-    def tokenize():
-        # tokenize the passed datapoint
-
-    # define the prompt
-    # call apply_prompt_template() for each data point:
-    # dataset = dataset.map(apply_prompt_template, <other args>)
-    # call tokenize() for each data point:
-    # dataset = dataset.map(tokenize, <other args>)
-
-    return dataset
-```
+### 🔧 Steps to Fine-Tune with a Custom Dataset
+
+1. **Launching Fine-Tuning with a Custom Dataset**
+   Use the following command-line arguments to begin fine-tuning:
+   ```
+   --dataset custom_dataset --dataset_config data_config.json
+   ```
+   The `data_config.json` file contains essential parameters used during dataset preprocessing.
+
+2. **Specifying the Preprocessing Function**
+   - In `data_config.json`, include a `"preproc_file"` key to define the path to your preprocessing Python file.
+   - To specify a custom function within that file, use the format `"filename.py:function_name"`.
+     _Example:_
+     ```json
+     "preproc_file": "disc_preproc.py:get_preprocessed_disc"
+     ```
+   - Your preprocessing function must follow this structure:
+     ```python
+     def get_custom_dataset(dataset_config, tokenizer, split, context_length=None):
+         def apply_prompt_template():
+             # Apply prompt formatting to each datapoint
+
+         def tokenize():
+             # Tokenize the formatted datapoint
+
+         # Apply functions to dataset using map
+         dataset = dataset.map(apply_prompt_template, ...)
+         dataset = dataset.map(tokenize, ...)
+
+         return dataset
+     ```
+
+3. **Custom Collate Function for Batching**
+   - When using a batch size greater than 1, you may override the default collate behavior by including a `"collate_file"` key in `data_config.json`.
+   - Use the same `"file.py:function"` format. If omitted, the default Hugging Face `DataCollatorForSeq2Seq` is used, which pads sequences to the longest length in the batch.
+   - A custom collate function must have the following signature:
+     ```python
+     def get_data_collator(tokenizer):
+         # Define and return a custom collate_fn here
+     ```
+
+4. **Passing Additional Configuration Parameters**
+   You can add custom arguments in `data_config.json`, which will be accessible via the `dataset_config` argument inside your `get_custom_dataset()` function.
+
+5. **Example `data_config.json` File**
+   ```json
+   {
+       "train_split": "train",
+       "test_split": "test",
+       "test_split_ratio": 0.15,
+       "preproc_file": "disc_preproc.py:get_preprocessed_disc",
+       "collate_file": "disc_preproc.py:get_collate_fn_disc",
+       "disc_style": "sarcasm_more"
+   }
+   ```
+
+6. **Implementing Custom Preprocessing Logic**
+   Within your dataset loader function, define `apply_prompt_template()` to manipulate raw data into desired prompt format, and `tokenize()` to convert it into token IDs using the tokenizer.
+
+7. **Reference for Dataset Utilities**
+   You can refer to existing implementations in the [dataset directory of this repository](https://github.com/quic/efficient-transformers/tree/main/QEfficient/finetune/dataset).
+
+---
```
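
As a concrete illustration of the layout the new documentation describes, a minimal preprocessing module might look like the sketch below. This is not part of the commit: the dataset id, column names, and prompt wording are placeholders, and only the function signatures follow the documented contract (a `get_custom_dataset`-style preprocessing function plus an optional collate factory):

```python
# disc_preproc.py (illustrative sketch): referenced from data_config.json via
#   "preproc_file": "disc_preproc.py:get_preprocessed_disc"
#   "collate_file": "disc_preproc.py:get_collate_fn_disc"
from datasets import load_dataset
from transformers import DataCollatorForSeq2Seq


def get_preprocessed_disc(dataset_config, tokenizer, split, context_length=None):
    # Hypothetical dataset id and columns; replace with your own data source.
    dataset = load_dataset("user/disc-dataset", split=split)

    def apply_prompt_template(example):
        # Extra keys from data_config.json (e.g. disc_style) show up on dataset_config.
        example["text"] = f"[{dataset_config.disc_style}] {example['input']}\n{example['output']}"
        return example

    def tokenize(example):
        return tokenizer(example["text"], truncation=True, max_length=context_length)

    dataset = dataset.map(apply_prompt_template)
    dataset = dataset.map(tokenize, remove_columns=dataset.column_names)
    return dataset


def get_collate_fn_disc(tokenizer):
    # Same behaviour as the default: pad each batch to its longest sequence.
    return DataCollatorForSeq2Seq(tokenizer)
```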
