@@ -163,6 +163,116 @@ def get_mmf_env(key=None):
163
163
return config .env
164
164
165
165
166
+ def _merge_with_dotlist (
167
+ config : DictConfig ,
168
+ opts : List [str ],
169
+ skip_missing : bool = False ,
170
+ log_info : bool = True ,
171
+ ):
172
+ # TODO: To remove technical debt, a possible solution is to use
173
+ # struct mode to update with dotlist OmegaConf node. Look into this
174
+ # in next iteration
175
+ # TODO: Simplify this function
176
+ if opts is None :
177
+ opts = []
178
+
179
+ if len (opts ) == 0 :
180
+ return config
181
+
182
+ # Support equal e.g. model=visual_bert for better future hydra support
183
+ has_equal = opts [0 ].find ("=" ) != - 1
184
+ if has_equal :
185
+ opt_values = [opt .split ("=" , maxsplit = 1 ) for opt in opts ]
186
+ if not all (len (opt ) == 2 for opt in opt_values ):
187
+ for opt in opt_values :
188
+ assert len (opt ) == 2 , f"{ opt } has no value"
189
+ else :
190
+ assert len (opts ) % 2 == 0 , "Number of opts should be multiple of 2"
191
+ opt_values = zip (opts [0 ::2 ], opts [1 ::2 ])
192
+
193
+ for opt , value in opt_values :
194
+ if opt == "dataset" :
195
+ opt = "datasets"
196
+
197
+ splits = opt .split ("." )
198
+ current = config
199
+ for idx , field in enumerate (splits ):
200
+ array_index = - 1
201
+ if field .find ("[" ) != - 1 and field .find ("]" ) != - 1 :
202
+ stripped_field = field [: field .find ("[" )]
203
+ array_index = int (field [field .find ("[" ) + 1 : field .find ("]" )])
204
+ else :
205
+ stripped_field = field
206
+ if stripped_field not in current :
207
+ if skip_missing is True :
208
+ break
209
+ raise AttributeError (
210
+ "While updating configuration"
211
+ " option {} is missing from"
212
+ " configuration at field {}" .format (opt , stripped_field )
213
+ )
214
+ if isinstance (current [stripped_field ], collections .abc .Mapping ):
215
+ current = current [stripped_field ]
216
+ elif (
217
+ isinstance (current [stripped_field ], collections .abc .Sequence )
218
+ and array_index != - 1
219
+ ):
220
+ try :
221
+ current_value = current [stripped_field ][array_index ]
222
+ except OCErrors .ConfigIndexError :
223
+ if skip_missing :
224
+ break
225
+ raise
226
+
227
+ # Case where array element to be updated is last element
228
+ if (
229
+ not isinstance (
230
+ current_value ,
231
+ (collections .abc .Mapping , collections .abc .Sequence ),
232
+ )
233
+ or idx == len (splits ) - 1
234
+ ):
235
+ if log_info :
236
+ logger .info (f"Overriding option { opt } to { value } " )
237
+ current [stripped_field ][array_index ] = _decode_value (value )
238
+ else :
239
+ # Otherwise move on down the chain
240
+ current = current_value
241
+ else :
242
+ if idx == len (splits ) - 1 :
243
+ if log_info :
244
+ logger .info (f"Overriding option { opt } to { value } " )
245
+ current [stripped_field ] = _decode_value (value )
246
+ else :
247
+ if skip_missing :
248
+ break
249
+
250
+ raise AttributeError (
251
+ "While updating configuration" ,
252
+ "option {} is not present "
253
+ "after field {}" .format (opt , stripped_field ),
254
+ )
255
+
256
+ return config
257
+
258
+
259
+ def _decode_value (value ):
260
+ # https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L400
261
+ if not isinstance (value , str ):
262
+ return value
263
+
264
+ if value == "None" :
265
+ value = None
266
+
267
+ try :
268
+ value = literal_eval (value )
269
+ except ValueError :
270
+ pass
271
+ except SyntaxError :
272
+ pass
273
+ return value
274
+
275
+
166
276
def resolve_cache_dir (env_variable = "MMF_CACHE_DIR" , default = "mmf" ):
167
277
# Some of this follow what "transformers" does for there cache resolving
168
278
try :
@@ -217,7 +327,7 @@ def __init__(self, args=None, default_only=False):
217
327
218
328
# Initially, silently add opts so that some of the overrides for the defaults
219
329
# from command line required for setup can be honored
220
- self ._default_config = self . _merge_with_dotlist (
330
+ self ._default_config = _merge_with_dotlist (
221
331
self ._default_config , args .opts , skip_missing = True , log_info = False
222
332
)
223
333
# Register the config and configuration for setup
@@ -231,7 +341,7 @@ def __init__(self, args=None, default_only=False):
231
341
232
342
self .config = OmegaConf .merge (self ._default_config , other_configs )
233
343
234
- self .config = self . _merge_with_dotlist (self .config , args .opts )
344
+ self .config = _merge_with_dotlist (self .config , args .opts )
235
345
self ._update_specific (self .config )
236
346
self .upgrade (self .config )
237
347
# Resolve the config here itself after full creation so that spawned workers
@@ -382,115 +492,6 @@ def _register_resolvers(self):
382
492
OmegaConf .register_resolver ("resolve_cache_dir" , resolve_cache_dir )
383
493
OmegaConf .register_resolver ("resolve_dir" , resolve_dir )
384
494
385
- def _merge_with_dotlist (
386
- self ,
387
- config : DictConfig ,
388
- opts : List [str ],
389
- skip_missing : bool = False ,
390
- log_info : bool = True ,
391
- ):
392
- # TODO: To remove technical debt, a possible solution is to use
393
- # struct mode to update with dotlist OmegaConf node. Look into this
394
- # in next iteration
395
- # TODO: Simplify this function
396
- if opts is None :
397
- opts = []
398
-
399
- if len (opts ) == 0 :
400
- return config
401
-
402
- # Support equal e.g. model=visual_bert for better future hydra support
403
- has_equal = opts [0 ].find ("=" ) != - 1
404
- if has_equal :
405
- opt_values = [opt .split ("=" , maxsplit = 1 ) for opt in opts ]
406
- if not all (len (opt ) == 2 for opt in opt_values ):
407
- for opt in opt_values :
408
- assert len (opt ) == 2 , "{} has no value" .format (opt )
409
- else :
410
- assert len (opts ) % 2 == 0 , "Number of opts should be multiple of 2"
411
- opt_values = zip (opts [0 ::2 ], opts [1 ::2 ])
412
-
413
- for opt , value in opt_values :
414
- if opt == "dataset" :
415
- opt = "datasets"
416
-
417
- splits = opt .split ("." )
418
- current = config
419
- for idx , field in enumerate (splits ):
420
- array_index = - 1
421
- if field .find ("[" ) != - 1 and field .find ("]" ) != - 1 :
422
- stripped_field = field [: field .find ("[" )]
423
- array_index = int (field [field .find ("[" ) + 1 : field .find ("]" )])
424
- else :
425
- stripped_field = field
426
- if stripped_field not in current :
427
- if skip_missing is True :
428
- break
429
- raise AttributeError (
430
- "While updating configuration"
431
- " option {} is missing from"
432
- " configuration at field {}" .format (opt , stripped_field )
433
- )
434
- if isinstance (current [stripped_field ], collections .abc .Mapping ):
435
- current = current [stripped_field ]
436
- elif (
437
- isinstance (current [stripped_field ], collections .abc .Sequence )
438
- and array_index != - 1
439
- ):
440
- try :
441
- current_value = current [stripped_field ][array_index ]
442
- except OCErrors .ConfigIndexError :
443
- if skip_missing :
444
- break
445
- raise
446
-
447
- # Case where array element to be updated is last element
448
- if (
449
- not isinstance (
450
- current_value ,
451
- (collections .abc .Mapping , collections .abc .Sequence ),
452
- )
453
- or idx == len (splits ) - 1
454
- ):
455
- if log_info :
456
- logger .info (f"Overriding option { opt } to { value } " )
457
- current [stripped_field ][array_index ] = self ._decode_value (value )
458
- else :
459
- # Otherwise move on down the chain
460
- current = current_value
461
- else :
462
- if idx == len (splits ) - 1 :
463
- if log_info :
464
- logger .info (f"Overriding option { opt } to { value } " )
465
- current [stripped_field ] = self ._decode_value (value )
466
- else :
467
- if skip_missing :
468
- break
469
-
470
- raise AttributeError (
471
- "While updating configuration" ,
472
- "option {} is not present "
473
- "after field {}" .format (opt , stripped_field ),
474
- )
475
-
476
- return config
477
-
478
- def _decode_value (self , value ):
479
- # https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L400
480
- if not isinstance (value , str ):
481
- return value
482
-
483
- if value == "None" :
484
- value = None
485
-
486
- try :
487
- value = literal_eval (value )
488
- except ValueError :
489
- pass
490
- except SyntaxError :
491
- pass
492
- return value
493
-
494
495
def freeze (self ):
495
496
OmegaConf .set_struct (self .config , True )
496
497
0 commit comments