# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
+
from dataclasses import dataclass

@@ -16,10 +17,13 @@ class TrainConfig:
        model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B").
        tokenizer_name (str): Name of the tokenizer (defaults to model_name if None).
        run_validation (bool): Whether to run validation during training (default: True).
-        batch_size_training (int): Batch size for training (default: 1).
+        train_batch_size (int): Batch size for training (default: 1).
+        val_batch_size (int): Batch size for validation (default: 1).
        context_length (Optional[int]): Maximum sequence length for inputs (default: None).
        gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4).
        gradient_checkpointing (bool): Enable gradient checkpointing to save memory at the cost of speed (default: False).
+        use_autocast (bool): Use autocast for mixed precision (default: True).
+        grad_scaler (bool): Use gradient scaler (default: True).
        num_epochs (int): Number of training epochs (default: 1).
        max_train_step (int): Maximum training steps (default: 0, unlimited if 0).
        max_eval_step (int): Maximum evaluation steps (default: 0, unlimited if 0).
@@ -29,17 +33,12 @@ class TrainConfig:
        weight_decay (float): Weight decay for optimizer (default: 0.0).
        gamma (float): Learning rate decay factor (default: 0.85).
        seed (int): Random seed for reproducibility (default: 42).
-        use_fp16 (bool): Use mixed precision training (default: True).
-        use_autocast (bool): Use autocast for mixed precision (default: True).
-        val_batch_size (int): Batch size for validation (default: 1).
        dataset (str): Dataset name for training (default: "samsum_dataset").
        task_type (str): Type of task to fine-tune for. Options: "generation" and "seq_classification" (default: "generation").
-        peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
        use_peft (bool): Whether to use PEFT (default: True).
+        peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
        from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
        output_dir (str): Directory to save outputs (default: "meta-llama-samsum").
-        num_freeze_layers (int): Number of layers to freeze (default: 1).
-        one_qaic (bool): Use single QAIC device (default: False).
        save_model (bool): Save the trained model (default: True).
        save_metrics (bool): Save training metrics (default: True).
        intermediate_step_save (int): Steps between intermediate saves (default: 1000).
@@ -49,19 +48,20 @@ class TrainConfig:
        convergence_loss (float): Loss threshold for convergence (default: 1e-4).
        use_profiler (bool): Enable profiling (default: False).
        enable_ddp (bool): Enable distributed data parallel (default: False).
-        dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo").
-        grad_scaler (bool): Use gradient scaler (default: True).
        dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_").
        opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
    """

    model_name : str = "meta-llama/Llama-3.2-1B"
    tokenizer_name : str = None # if not passed as an argument, it uses the value of model_name
    run_validation : bool = True
-    batch_size_training : int = 1
+    train_batch_size : int = 1
+    val_batch_size : int = 1
    context_length : int = None
    gradient_accumulation_steps : int = 4
    gradient_checkpointing : bool = False
+    use_autocast : bool = True
+    grad_scaler : bool = True
    num_epochs : int = 1
    max_train_step : int = 0
    max_eval_step : int = 0
@@ -71,21 +71,17 @@ class TrainConfig:
    weight_decay : float = 0.0
    gamma : float = 0.85 # multiplicatively decay the learning rate by gamma after each epoch
    seed : int = 42
-    use_fp16 : bool = True
-    use_autocast : bool = True
-    val_batch_size : int = 1
-    dataset = "samsum_dataset"
-    task_type = "generation" # "generation" / "seq_classification"
+    dataset : str = "alpaca_dataset"
+    task_type : str = "generation" # "generation" / "seq_classification"
+    use_peft : bool = True # use parameter-efficient fine-tuning
    peft_method : str = "lora"
-    use_peft : bool = True # use parameter efficient fine tuning
-    from_peft_checkpoint : str = "" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
-    output_dir : str = "meta-llama-samsum"
-    num_freeze_layers : int = 1
-    one_qaic : bool = False
+    from_peft_checkpoint : str = "" # if not empty and peft_method='lora', load this PEFT checkpoint and resume fine-tuning from it
+    output_dir : str = "training_results"
    save_model : bool = True
    save_metrics : bool = True # saves training metrics to a JSON file for later plotting
    intermediate_step_save : int = 1000
    batching_strategy : str = "packing"
+    enable_ddp : bool = False
    enable_sorting_for_ddp : bool = True
    convergence_counter : int = 5 # should be >= 1; stop fine-tuning when loss <= convergence_loss (defined below) for convergence_counter consecutive steps
    convergence_loss : float = (
@@ -98,10 +94,5 @@ class TrainConfig:
    use_profiler : bool = False # Enable the PyTorch profiler; cannot be used together with the FLOP counter.
    # profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

-    # dist-related
-    enable_ddp : bool = False
-    dist_backend : str = "cpu:gloo,qaic:qccl,cuda:gloo"
-
-    grad_scaler : bool = True
-    dump_root_dir : str = "meta-llama-samsum-mismatches/step_"
+    dump_root_dir : str = "mismatches/step_"
    opByOpVerifier : bool = False
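
For quick orientation, the sketch below shows how the renamed and newly promoted fields (train_batch_size, gradient_accumulation_steps, use_autocast, grad_scaler) might be consumed in a training step. It is a minimal illustration, not the repository's fine-tuning loop: the stand-in config subset, toy model, AdamW optimizer, and the CUDA/CPU fallback are all assumptions made for the example.

# Illustrative sketch only -- a stand-in subset of TrainConfig and a toy model,
# not the repository's fine-tuning loop.
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class TrainConfigSketch:  # stand-in for the handful of fields this sketch touches
    train_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    use_autocast: bool = True
    grad_scaler: bool = True


cfg = TrainConfigSketch()
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = nn.Linear(16, 2).to(device)  # toy model standing in for the LLM
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()
# grad_scaler matters for fp16 on GPU; when disabled it degrades to plain optimizer steps
scaler = torch.cuda.amp.GradScaler(enabled=cfg.grad_scaler and device == "cuda")

dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 16), torch.randint(0, 2, (64,))
)
loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.train_batch_size)

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    x, y = x.to(device), y.to(device)
    # use_autocast toggles mixed-precision compute for the forward pass and loss
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=cfg.use_autocast):
        loss = loss_fn(model(x), y) / cfg.gradient_accumulation_steps
    scaler.scale(loss).backward()
    # step the optimizer only every gradient_accumulation_steps micro-batches
    if (step + 1) % cfg.gradient_accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

In the actual flow these flags would presumably gate the device-appropriate autocast and scaler setup (e.g. on QAIC devices); the CUDA/CPU branching above exists only to keep the example self-contained and runnable.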