
Commit 501a896

Committed May 8, 2020
[refactor] Remove args totally from the picture (#41)
1 parent 99decc3 commit 501a896


7 files changed: +72 −376 lines changed


‎pythia/common/defaults/configs/base.yml

+55 −45
@@ -2,20 +2,65 @@
 training:
   # Name of the trainer class used to define the training/evalution loop
   trainer: 'base_trainer'
+  # Seed to be used for training. -1 means random seed between 1 and 100000.
+  # Either pass fixed through your config or command line arguments
+  seed: null
   # Name of the experiment, will be used while saving checkpoints
   # and generating reports
   experiment_name: run
   # Maximum number of iterations the training will run
   max_updates: 22000
-  # Maximum epochs in case you don't want to use iterations
+  # Maximum epochs in case you don't want to use max_updates
   # Can be mixed with max iterations, so it will stop whichever is
   # completed first. Default: null means epochs won't be used
   max_epochs: null
+
+  # Type of run, train+inference by default means both training and inference
+  # (test) stage will be run, if run_type contains 'val',
+  # inference will be run on val set also.
+  run_type: train_inference
+
+  # Directory for saving checkpoints and other metadata
+  save_dir: "./save"
+
   # After `log_interval` iterations, current iteration's training loss and
   # metrics will be reported. This will also report validation
   # loss and metrics on a single batch from validation set
   # to provide an estimate on validation side
   log_interval: 100
+  # Directory for saving logs, default is "logs" inside the save folder
+  # If log_dir is specifically passed, logs will be written inside that folder
+  log_dir: null
+  # Level of logging, only logs which are >= to current level will be logged
+  logger_level: info
+  # Log format: json, simple
+  log_format: simple
+  # Whether to log detailed final configuration parameters
+  log_detailed_config: false
+  # Whether Pythia should log or not, Default: False, which means
+  # pythia will log by default
+  should_not_log: false
+
+  # Tensorboard control, by default tensorboard is disabled
+  tensorboard: false
+  # Log directory for tensorboard, default points to same as logs
+  tensorboard_logdir: null
+
+  # Size of each batch. If distributed or data_parallel
+  # is used, this will be divided equally among GPUs
+  batch_size: 512
+  # Number of workers to be used in dataloaders
+  num_workers: 4
+  # Some datasets allow fast reading by loading everything in the memory
+  # Use this to enable it
+  fast_read: false
+  # Whether JSON files for evalai evaluation should be generated
+  evalai_inference: false
+  # Use in multi-tasking, when you want to sample tasks proportional to their sizes
+  dataset_size_proportional_sampling: true
+  # Whether to pin memory in dataloader
+  pin_memory: false
+
   # After `checkpoint_interval` iterations, pythia will make a snapshot
   # which will involve creating a checkpoint for current training scenarios
   checkpoint_interval: 1000
@@ -25,18 +70,6 @@ training:
   clip_gradients: false
   # Mode for clip norm
   clip_norm_mode: all
-  # Tensorboard control
-  tensorboard: false
-  tensorboard_logdir: null
-
-  # Seed to be used for training. -1 means random seed.
-  # Either pass fixed through your config or command line arguments
-  seed: null
-  # Size of each batch. If distributed or data_parallel
-  # is used, this will be divided equally among GPUs
-  batch_size: 512
-  # Number of workers to be used in dataloaders
-  num_workers: 4

   # Whether to use early stopping, (Default: false)
   should_early_stop: false
@@ -66,44 +99,20 @@ training:
   # Iteration until which warnup should be done
   warmup_iterations: 1000

-  # Type of run, train+inference by default means both training and inference
-  # (test) stage will be run, if run_type contains 'val',
-  # inference will be run on val set also.
-  run_type: train+inference
-  # Level of logging, only logs which are >= to current level will be logged
-  logger_level: info
-
-  device: cuda
-
   # Local rank of the GPU device
+  device: cuda
   local_rank: null

-  # Whether JSON files for evalai evaluation should be generated
-  evalai_inference: false
   # Use to load specific modules from checkpoint to your model,
   # this is helpful in finetuning. for e.g. you can specify
   # text_embeddings: text_embedding_pythia
   # for loading `text_embedding` module of your model
   # from `text_embedding_pythia`
   pretrained_mapping: {}
-  # Whether the above mentioned pretrained mapping should be loaded or not
+  # If using a pretrained model. Must be used with --resume_file parameter
+  # to specify pretrained model checkpoint. Will load only specific layers if
+  # pretrained mapping is specified in config
   load_pretrained: false
-
-  # Directory for saving checkpoints and other metadata
-  save_dir: "./save"
-  # Directory for saving logs
-  log_dir: "./logs"
-  # Log format: json, simple
-  log_format: simple
-  # Whether to log detailed final configuration parameters
-  log_detailed_config: false
-  # Whether Pythia should log or not, Default: False, which means
-  # pythia will log by default
-  should_not_log: false
-
-  # If verbose dump is active, pythia will dump dataset, model specific
-  # information which can be useful in debugging
-  verbose_dump: false
   # If resume is true, pythia will try to load automatically load
   # last of same parameters from save_dir
   resume: false
@@ -112,14 +121,15 @@ training:
   # `resume_best` will load the best checkpoint according to monitored metric instead of
   # the last saved ckpt
   resume_best: false
-  # Whether to pin memory in dataloader
-  pin_memory: false

-  # Use in multi-tasking, when you want to sample tasks proportional to their sizes
-  dataset_size_proportional_sampling: true
+  # If verbose dump is active, pythia will dump dataset, model specific
+  # information which can be useful in debugging
+  verbose_dump: false

+  # Turn on if you want to ignore unused parameters in case of DDP
   find_unused_parameters: false

+
   # Configuration for models, default configuration files for various models
   # included in pythia can be found under configs directory in root folder
   model_config: {}
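
The comments in this config describe two conventions worth spelling out: a seed of -1 is replaced by a random seed in [1, 100000], and training stops at whichever of max_updates or max_epochs is reached first. A minimal sketch of that logic against OmegaConf (the function and variable names are illustrative, not Pythia's actual trainer code):

import random

from omegaconf import OmegaConf

training = OmegaConf.create({"seed": -1, "max_updates": 22000, "max_epochs": None})

# "-1 means random seed between 1 and 100000"
seed = random.randint(1, 100000) if training.seed == -1 else training.seed

def should_stop(num_updates, current_epoch):
    # Stop at whichever of max_updates / max_epochs completes first;
    # max_epochs: null means epochs are not used as a stopping criterion.
    if num_updates >= training.max_updates:
        return True
    if training.max_epochs is not None and current_epoch >= training.max_epochs:
        return True
    return False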

‎pythia/datasets/multi_dataset.py

-34
@@ -209,40 +209,6 @@ def update_registry_for_model(self, config):
         for builder in self._builders:
             builder.update_registry_for_model(config)

-    def init_args(self, parser):
-        parser.add_argument_group("General MultiDataset Arguments")
-        parser.add_argument(
-            "-dsp",
-            "--dataset_size_proportional_sampling",
-            type=bool,
-            default=0,
-            help="Pass if you want to sample from"
-            " dataset according to its size. Default: Equal "
-            " weighted sampling",
-        )
-
-        # TODO: Figure out later if we want to init args from datasets
-        # self._init_args(parser)
-
-    def _init_args(self, parser):
-        """Override this function to add extra parameters to
-        parser in your child task class.
-
-        Parameters
-        ----------
-        parser : ArgumentParser
-            Original parser object passed from the higher level classes like
-            trainer
-
-        Returns
-        -------
-        type
-            Description of returned object.
-
-        """
-        for builder in self._builders:
-            builder.init_args(parser)
-
     def clean_config(self, config):
         """
         Override this in case you want to clean the config you updated earlier
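
With init_args gone, the -dsp/--dataset_size_proportional_sampling flag no longer exists; the equivalent switch is the training.dataset_size_proportional_sampling key added to base.yml above. A rough sketch of what proportional-versus-equal sampling means for a multi-dataset setup (illustrative only, not the MultiDataset implementation; dataset sizes below are made up):

import random

def pick_dataset_index(dataset_sizes, proportional_sampling=True, rng=random):
    """Pick which dataset to draw the next batch from."""
    if proportional_sampling:
        # Larger datasets are sampled more often, proportional to their size.
        weights = dataset_sizes
    else:
        # Equal weighted sampling across datasets.
        weights = [1] * len(dataset_sizes)
    return rng.choices(range(len(dataset_sizes)), weights=weights, k=1)[0]

# e.g. when config.training.dataset_size_proportional_sampling is True:
index = pick_dataset_index([443757, 25994])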

‎pythia/datasets/vqa/vqa2/builder.py

-16
@@ -49,22 +49,6 @@ def update_registry_for_model(self, config):
             self.dataset.answer_processor.get_vocab_size(),
         )

-    def init_args(self, parser):
-        parser.add_argument_group("VQA2 task specific arguments")
-        parser.add_argument(
-            "--data_root_dir",
-            type=str,
-            default="../data",
-            help="Root directory for data",
-        )
-        parser.add_argument(
-            "-nfr",
-            "--fast_read",
-            type=bool,
-            default=None,
-            help="Disable fast read and load features on fly",
-        )
-
     def set_dataset_class(self, cls):
         self.dataset_class = cls

‎pythia/models/base_model.py

-4
@@ -105,10 +105,6 @@ def init_losses_and_metrics(self):
     def config_path(cls):
         return None

-    @classmethod
-    def init_args(cls, parser):
-        return parser
-
     @classmethod
     def format_state_key(cls, key):
         """Can be implemented if something special needs to be done

‎pythia/utils/configuration.py

+15 −46
@@ -134,17 +134,14 @@ def __init__(self, args):
         user_config = self._build_user_config(opts_config)
         model_config = self._build_model_config(opts_config)
         dataset_config = self._build_dataset_config(opts_config)
-        args_overrides = self._build_args_overrides(args)
+        args_overrides = self._build_demjson_config(args.config_override)

         self._default_config = default_config
         self._user_config = user_config
         self.config = OmegaConf.merge(
             default_config, model_config, dataset_config, user_config, args_overrides
         )

-        # TODO: Remove in next iteration
-        self.config = self._update_with_args(self.config, args)
-
         self.config = self._merge_with_dotlist(self.config, args.opts)
         self._update_specific(self.config)

@@ -167,29 +164,24 @@ def _build_user_config(self, opts):

         return user_config

-    def _build_args_overrides(self, args):
-        # Update with demjson if passed
-        demjson_config = self._get_demjson_config(args.config_override)
-        # TODO: Remove in next iteration
-        args_config = self._get_args_config(args)
-        return OmegaConf.merge(demjson_config, args_config)
-
     def _build_model_config(self, config):
         model = config.model
         if model is None:
             raise KeyError("Required argument 'model' not passed")
         model_cls = registry.get_model_class(model)

         if model_cls is None:
-            warnings.warn("No model named '{}' has been registered".format(model))
+            warning = "No model named '{}' has been registered".format(model)
+            warnings.warn(warning)
             return OmegaConf.create()

         default_model_config_path = model_cls.config_path()

         if default_model_config_path is None:
-            warnings.warn(
-                "Model {}'s class has no default configuration provided".format(model)
+            warning = "Model {}'s class has no default configuration provided".format(
+                model
             )
+            warnings.warn(warning)
             return OmegaConf.create()

         return load_yaml(default_model_config_path)
@@ -213,17 +205,15 @@ def _build_dataset_config(self, config):
             builder_cls = registry.get_builder_class(dataset)

             if builder_cls is None:
-                warnings.warn(
-                    "No dataset named '{}' has been registered".format(dataset)
-                )
+                warning = "No dataset named '{}' has been registered".format(dataset)
+                warnings.warn(warning)
                 continue
             default_dataset_config_path = builder_cls.config_path()
             if default_dataset_config_path is None:
-                warnings.warn(
-                    "Dataset {}'s builder class has no default configuration provided".format(
-                        dataset
-                    )
+                warning = "Dataset {}'s builder class has no default configuration provided".format(
+                    dataset
                 )
+                warnings.warn(warning)
                 continue
             dataset_config = OmegaConf.merge(
                 dataset_config, load_yaml(default_dataset_config_path)
@@ -235,16 +225,7 @@ def get_config(self):
         self._register_resolvers()
         return self.config

-    def _update_with_args(self, config, args, force=False):
-        args_dict = vars(args)
-
-        self._update_key(config, args_dict)
-        if force is True:
-            config.update(args_dict)
-
-        return config
-
-    def _get_demjson_config(self, demjson_string):
+    def _build_demjson_config(self, demjson_string):
         if demjson_string is None:
             return OmegaConf.create()

@@ -350,6 +331,9 @@ def freeze(self):
         # self.config = ConfigNode(self.config)
         OmegaConf.set_struct(self.config, True)

+    def defrost(self):
+        OmegaConf.set_struct(self.config, False)
+
     def _convert_to_dot_list(self, opts):
         if opts is None:
             opts = []
@@ -414,21 +398,6 @@ def _get_default_config_path(self):
             directory, "..", "common", "defaults", "configs", "base.yml"
         )

-    def _update_key(self, dictionary, update_dict):
-        """
-        Takes a single depth dictionary update_dict and uses it to
-        update 'dictionary' whenever key in 'update_dict' is found at
-        any level in 'dictionary'
-        """
-        for key, value in dictionary.items():
-            if not isinstance(value, collections.abc.Mapping):
-                if key in update_dict and update_dict[key] is not None:
-                    dictionary[key] = update_dict[key]
-            else:
-                dictionary[key] = self._update_key(value, update_dict)
-
-        return dictionary
-
     def _update_specific(self, config):
         self.writer = registry.get("writer")
         tp = self.config.training
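
After this change, the only command-line override paths left are the demjson string passed via -co/--config_override and the trailing dotlist opts; the argparse-driven _update_with_args/_update_key path is gone. A minimal standalone sketch of how those two overrides compose with the defaults under OmegaConf (not the Configuration class itself; the values are made up):

import demjson
from omegaconf import OmegaConf

defaults = OmegaConf.create({"training": {"batch_size": 512, "num_workers": 4}})

# --config_override takes a demjson (lenient JSON) string, e.g. from the CLI.
override_string = "{training: {batch_size: 128}}"
demjson_overrides = OmegaConf.create(demjson.decode(override_string))

# Trailing opts such as `training.num_workers=8` are merged as a dotlist.
dotlist_overrides = OmegaConf.from_dotlist(["training.num_workers=8"])

config = OmegaConf.merge(defaults, demjson_overrides, dotlist_overrides)
assert config.training.batch_size == 128
assert config.training.num_workers == 8

# The new defrost() mirrors freeze(): struct mode toggled off allows edits again.
OmegaConf.set_struct(config, True)   # freeze
OmegaConf.set_struct(config, False)  # defrost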

‎pythia/utils/flags.py

-230
@@ -11,255 +11,25 @@ class Flags:
     def __init__(self):
         self.parser = argparse.ArgumentParser()
         self.add_core_args()
-        self.update_model_args()

     def get_parser(self):
         return self.parser

     def add_core_args(self):
-        # TODO: Update default values
         self.parser.add_argument_group("Core Arguments")
-
-        self.parser.add_argument(
-            "--config", type=str, default=None, required=False, help="config yaml file"
-        )
-
-        self.parser.add_argument(
-            "--tasks", type=str, default="", help="Tasks for training"
-        )
-        self.parser.add_argument(
-            "--datasets",
-            type=str,
-            required=False,
-            default="all",
-            help="Datasets to be used for required task",
-        )
-        # self.parser.add_argument(
-        #     "--model", type=str, default="", help="Model for training"
-        # )
-        self.parser.add_argument(
-            "--run_type",
-            type=str,
-            default=None,
-            help="Type of run. Default=train+predict",
-        )
-        self.parser.add_argument(
-            "-exp",
-            "--experiment_name",
-            type=str,
-            default=None,
-            help="Name of the experiment",
-        )
-
-        self.parser.add_argument(
-            "--seed",
-            type=int,
-            default=None,
-            help="random seed, default None, meaning nothing will be seeded"
-            " set seed to -1 if need a random seed"
-            " between 1 and 100000",
-        )
-        self.parser.add_argument(
-            "--config_overwrite",
-            type=str,
-            help="a json string to update yaml config file",
-            default=None,
-        )
-
-        self.parser.add_argument(
-            "--force_restart",
-            action="store_true",
-            help="flag to force clean previous result and restart training",
-        )
-        self.parser.add_argument(
-            "--log_interval",
-            type=int,
-            default=None,
-            help="Number of iterations after which we should log validation results",
-        )
-        self.parser.add_argument(
-            "--checkpoint_interval",
-            type=int,
-            default=None,
-            help="Number of iterations after which we should save snapshots",
-        )
-        self.parser.add_argument(
-            "--evaluation_interval",
-            type=int,
-            default=None,
-            help="Number of iterations after which we should save snapshots",
-        )
-        self.parser.add_argument(
-            "--max_updates",
-            type=int,
-            default=None,
-            help="Number of iterations after which we should stop training",
-        )
-        self.parser.add_argument(
-            "--max_epochs",
-            type=int,
-            default=None,
-            help="Number of epochs after which "
-            " we should stop training"
-            " (mutually exclusive with max_updates)",
-        )
-        self.parser.add_argument(
-            "--batch_size",
-            type=int,
-            default=None,
-            help="Batch size to be used for training "
-            "If not passed it will default to config one",
-        )
-        self.parser.add_argument(
-            "--save_dir",
-            type=str,
-            default="./save",
-            help="Location for saving model checkpoint",
-        )
-        self.parser.add_argument(
-            "--log_dir", type=str, default=None, help="Location for saving logs"
-        )
-        self.parser.add_argument(
-            "--logger_level", type=str, default=None, help="Level of logging"
-        )
-        self.parser.add_argument(
-            "--log_detailed_config",
-            type=int,
-            default=None,
-            help="Log detailed final configuration parameters",
-        )
-
-        self.parser.add_argument(
-            "--should_not_log",
-            action="store_true",
-            default=False,
-            help="Set when you don't want logging to happen",
-        )
         self.parser.add_argument(
             "-co",
             "--config_override",
             type=str,
             default=None,
             help="Use to override config from command line directly",
         )
-        self.parser.add_argument(
-            "--resume_file",
-            type=str,
-            default=None,
-            help="File from which to resume checkpoint",
-        )
-        self.parser.add_argument(
-            "--resume",
-            type=bool,
-            default=None,
-            help="Use when you want to restore from automatic checkpoint",
-        )
-        self.parser.add_argument(
-            "--resume_best",
-            type=bool,
-            default=None,
-            help="Use when you want to restore from last best checkpoint instead of last ckpt",
-        )
-        self.parser.add_argument(
-            "--evalai_inference",
-            type=bool,
-            default=None,
-            help="Whether predictions should be made for EvalAI.",
-        )
-        self.parser.add_argument(
-            "--verbose_dump",
-            type=bool,
-            default=None,
-            help="Whether to do verbose dump of dataset"
-            " samples, predictions and other things.",
-        )
-        self.parser.add_argument(
-            "--lr_scheduler",
-            type=bool,
-            default=None,
-            help="Use when you want to use lr scheduler",
-        )
-        self.parser.add_argument(
-            "--clip_gradients",
-            type=bool,
-            default=None,
-            help="Use when you want to clip gradients",
-        )
-        self.parser.add_argument(
-            "--tensorboard", type=bool, default=False, help="Enable tensorboard"
-        )
-        self.parser.add_argument(
-            "--tensorboard_logdir",
-            type=str,
-            default=None,
-            help="Default logdir for tensorboard",
-        )
-
-        self.parser.add_argument(
-            "-dev",
-            "--device_id",
-            type=str,
-            default=None,
-            help="Specify device to be used for training",
-        )
-        self.parser.add_argument(
-            "-p", "--patience", type=int, default=None, help="Patience for early stop"
-        )
-        self.parser.add_argument(
-            "-fr",
-            "--fast_read",
-            type=bool,
-            default=None,
-            help="If fast read should be activated",
-        )
-        self.parser.add_argument(
-            "-pt",
-            "--load_pretrained",
-            type=int,
-            default=None,
-            help="If using a pretrained model. "
-            "Must be used with --resume_file parameter "
-            "to specify pretrained model checkpoint. "
-            "Will load only specific layers if "
-            "pretrained mapping is specified in config",
-        )
-
-        self.parser.add_argument(
-            "-nw",
-            "--num_workers",
-            type=int,
-            default=None,
-            help="Number of workers for dataloaders",
-        )
-        self.parser.add_argument(
-            "-lr",
-            "--local_rank",
-            type=int,
-            default=None,
-            help="Local rank of the current node",
-        )
         self.parser.add_argument(
             "opts",
             default=None,
             nargs=argparse.REMAINDER,
             help="Modify config options from command line",
         )

-    def update_model_args(self):
-        args = sys.argv
-        model_name = None
-        for index, item in enumerate(args):
-            if item == "--model":
-                model_name = args[index + 1]
-
-        if model_name is None:
-            return
-
-        model_class = registry.get_model_class(model_name)
-        if model_class is None:
-            return
-
-        model_class.init_args(self.parser)
-

 flags = Flags()
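
With roughly 230 lines of per-option flags removed, the parser in this diff keeps only -co/--config_override plus the free-form opts remainder; everything that used to be a dedicated flag now travels as dotlist entries. A quick sketch of parsing with just those two arguments, reusing the add_argument calls kept above (the config path and values are made up for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-co",
    "--config_override",
    type=str,
    default=None,
    help="Use to override config from command line directly",
)
parser.add_argument(
    "opts",
    default=None,
    nargs=argparse.REMAINDER,
    help="Modify config options from command line",
)

args = parser.parse_args(
    ["config=configs/vqa/vqa2/pythia.yml", "model=pythia", "training.seed=1"]
)
print(args.opts)
# ['config=configs/vqa/vqa2/pythia.yml', 'model=pythia', 'training.seed=1']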

‎tests/datasets/test_base_dataset.py

+2 −1
@@ -18,7 +18,8 @@ def test_init_processors(self):
         args = dummy_args()
         args.opts.append("config={}".format(path))
         configuration = Configuration(args)
-        answer_processor = configuration.get_config()["dataset_attributes"]["vqa2"][
+        print(configuration.get_config())
+        answer_processor = configuration.get_config()["dataset_config"]["vqa2"][
             "processors"
         ]["answer_processor"]
         vocab_path = os.path.join(
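
The test update reflects a key rename: dataset-level options are now read from "dataset_config" rather than "dataset_attributes". A hedged illustration of the lookup with a made-up processor entry (only the nesting mirrors what the test accesses):

from omegaconf import OmegaConf

config = OmegaConf.create(
    {"dataset_config": {"vqa2": {"processors": {"answer_processor": {"type": "vqa_answer"}}}}}
)
answer_processor = config["dataset_config"]["vqa2"]["processors"]["answer_processor"]
print(answer_processor.type)  # vqa_answer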
