-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs.py
68 lines (45 loc) · 2.56 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse

from configargparse import Parser
def parse_args(args=None):
    """Build the command-line / config-file parser and parse the arguments.

    The parser accepts an optional config file via ``-c``/``--cfg``/``--config``.
    It registers this project's options through ``add_input_args`` and the
    trainer options through ``Trainer.add_argparse_args``.

    Args:
        args: Optional list of argument strings to parse. ``None`` (the
            default) parses the process command line (``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args``. Options
        are accessed as attributes (e.g. ``ns.lr``); it is not a ``dict``.
    """
    parser = Parser(
        args_for_setting_config_path=["-c", "--cfg", "--config"],
        config_arg_is_required=False,
    )
    parser = add_input_args(parser)
    # NOTE(review): `Trainer` is not imported anywhere visible in this file;
    # presumably the Lightning Trainer class — confirm it is brought into scope.
    parser = Trainer.add_argparse_args(parser)
    return parser.parse_args(args)
def _str_to_bool(value):
    """Convert a textual truth value such as ``"yes"`` or ``"0"`` to a ``bool``.

    Accepts the same case-insensitive spellings as the removed
    ``distutils.util.strtobool``: y/yes/t/true/on/1 for True and
    n/no/f/false/off/0 for False. Surrounding whitespace is ignored, and an
    actual ``bool`` is passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
            argparse reports this message verbatim as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"invalid truth value {value!r}")


def add_input_args(parser):
    """Register this project's input options on ``parser`` and return it.

    Options are added with ``parser.add``, the configargparse alias of
    ``add_argument``. ``--batch_size``, ``--datamodule``, ``--lr`` and ``--model``
    have no default, so they are ``None`` unless supplied on the command line
    or in the config file.

    Args:
        parser: The configargparse parser to extend. It is mutated in place.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # NOTE(review): `valid_model_and_datamodule_names` is not defined or
    # imported anywhere visible in this file — confirm it is in scope. Its two
    # results become the allowed choices for --model and --datamodule below.
    model_names, datamodule_names = valid_model_and_datamodule_names()
    parser.add("--batch_size", type=int,
               help="The number of samples to include per batch.")
    parser.add("--data_dir", default="data",
               help="The name of the directory where data will be saved.")
    parser.add("--datamodule", choices=datamodule_names,
               help="The name of the DataModule to utilize.")
    # Use a named converter instead of distutils.util.strtobool. distutils was
    # removed in Python 3.12, and a named type gives a clearer parse error.
    parser.add("--eval_only", default=False, type=_str_to_bool,
               help="Whether to skip training and only evaluate the model.")
    parser.add("--loss", choices=["cross_entropy", "mse"], default="cross_entropy",
               help="The name of the loss function to utilize for optimization.")
    parser.add("--lr", type=float,
               help="The learning rate to utilize for optimization.")
    parser.add("--lr_drop", default=0.1, type=float,
               help="The factor by which to drop the LR when using the step scheduler.")
    parser.add("--lr_scheduler", choices=["cosine", "cosine_warmup", "linear", "step"], default="step",
               help="The name of the LR scheduler to utilize.")
    parser.add("--lr_warmup_epochs", default=0, type=int,
               help="The number of epochs to warm up using certain schedulers.")
    parser.add("--model", choices=model_names,
               help="The name of the Model to utilize.")
    parser.add("--momentum", default=0.9, type=float,
               help="The momentum value to utilize with the SGD optimizer.")
    parser.add("--num_workers", default=4, type=int,
               help="The number of sub-processes to use for data loading.")
    parser.add("--optimizer", choices=["adam", "adamw", "sgd"], default="sgd",
               help="The name of the optimizer to utilize.")
    parser.add("--seed", default=1, type=int,
               help="The random seed to utilize.")
    parser.add("--weight_decay", default=1e-4, type=float,
               help="The l2-norm regularization to utilize during optimization.")
    return parser