-
Notifications
You must be signed in to change notification settings - Fork 485
Expand file tree
/
Copy path clevr_count_70k_sft.py
More file actions
36 lines (28 loc) · 911 Bytes
/
clevr_count_70k_sft.py
File metadata and controls
36 lines (28 loc) · 911 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
from areal import SFTTrainer
from areal.api.cli_args import SFTConfig, load_expr_config
from areal.dataset import get_custom_dataset
from areal.utils.hf_utils import load_hf_processor_and_tokenizer
def main(args):
    """Parse the experiment config from ``args`` and run SFT training.

    Loads a HuggingFace processor/tokenizer pair from ``tokenizer_path`` in the
    parsed ``SFTConfig``. It then builds the ``"train"`` split from
    ``train_dataset`` and the ``"test"`` split from ``valid_dataset``. Finally it
    calls ``train()`` on an ``SFTTrainer`` used as a context manager.

    Args:
        args: Command-line argument strings, forwarded unchanged to
            ``load_expr_config``. The script entry point passes
            ``sys.argv[1:]``, i.e. without the program name.
    """
    cfg, _unused = load_expr_config(args, SFTConfig)
    processor, tokenizer = load_hf_processor_and_tokenizer(cfg.tokenizer_path)

    # Both dataset splits are built with the same tokenizer/processor pair.
    shared_kwargs = dict(tokenizer=tokenizer, processor=processor)
    train_ds = get_custom_dataset(
        split="train", dataset_config=cfg.train_dataset, **shared_kwargs
    )
    valid_ds = get_custom_dataset(
        split="test", dataset_config=cfg.valid_dataset, **shared_kwargs
    )

    # The trainer is entered as a context manager so its __exit__ runs when
    # training finishes or raises.
    trainer_ctx = SFTTrainer(cfg, train_dataset=train_ds, valid_dataset=valid_ds)
    with trainer_ctx as trainer:
        trainer.train()
if __name__ == "__main__":
    # sys.argv[0] is the script path; only the user-supplied arguments go to main().
    cli_args = sys.argv[1:]
    main(cli_args)