77import argparse
88import dataclasses
99import functools
10+ import inspect
1011import itertools
1112import shlex
1213import sys
1516from collections import defaultdict
1617from logging import getLogger
1718from pathlib import Path
18- from typing import Any , Callable , Sequence , Type , overload
19-
19+ from typing import Any , Callable , Mapping , Sequence , Type , overload
20+ from typing_extensions import TypeGuard
21+ import warnings
2022from simple_parsing .helpers .subgroups import SubgroupKey
23+ from simple_parsing .replace import SUBGROUP_KEY_FLAG
2124from simple_parsing .wrappers .dataclass_wrapper import DataclassWrapperType
2225
2326from . import utils
2427from .conflicts import ConflictResolution , ConflictResolver
2528from .help_formatter import SimpleHelpFormatter
26- from .helpers .serialization .serializable import read_file
29+ from .helpers .serialization .serializable import DC_TYPE_KEY , read_file
2730from .utils import (
31+ K ,
32+ V ,
2833 Dataclass ,
2934 DataclassT ,
35+ PossiblyNestedDict ,
3036 dict_union ,
3137 is_dataclass_instance ,
3238 is_dataclass_type ,
@@ -593,7 +599,7 @@ def _resolve_subgroups(
593599
594600 This modifies the wrappers in-place, by possibly adding children to the wrappers in the
595601 list.
596- Returns a list with the modified wrappers.
602+ Returns a list with the (now modified) wrappers.
597603
598604 Each round does the following:
599605 1. Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting
@@ -618,13 +624,7 @@ def _resolve_subgroups(
618624 # times.
619625 subgroup_choice_parser = argparse .ArgumentParser (
620626 add_help = False ,
621- # conflict_resolution=self.conflict_resolution,
622- # add_option_string_dash_variants=self.add_option_string_dash_variants,
623- # argument_generation_mode=self.argument_generation_mode,
624- # nested_mode=self.nested_mode,
625627 formatter_class = self .formatter_class ,
626- # add_config_path_arg=self.add_config_path_arg,
627- # config_path=self.config_path,
628628 # NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues
629629 # for example if you have —a_or_b and A has a field —a then it will error out if you
630630 # pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it
@@ -644,10 +644,27 @@ def _resolve_subgroups(
644644 flags = subgroup_field .option_strings
645645 argument_options = subgroup_field .arg_options
646646
647+ # Sanity checks:
647648 if subgroup_field .subgroup_default is dataclasses .MISSING :
648649 assert argument_options ["required" ]
650+ elif isinstance (argument_options ["default" ], dict ):
651+ # BUG #276: The default here is a dict because it came from a config file.
652+ # Here we want the subgroup field to have a 'str' default, because we just want
653+ # to be able to choose between the subgroup names.
654+ _default = argument_options ["default" ]
655+ _default_key = _infer_subgroup_key_to_use_from_config (
656+ default_in_config = _default ,
657+ # subgroup_default=subgroup_field.subgroup_default,
658+ subgroup_choices = subgroup_field .subgroup_choices ,
659+ )
660+ # We'd like this field to (at least temporarily) have a different default
661+ # value that is the subgroup key instead of the dictionary.
662+ argument_options ["default" ] = _default_key
663+
649664 else :
650- assert argument_options ["default" ] is subgroup_field .subgroup_default
665+ assert (
666+ argument_options ["default" ] is subgroup_field .subgroup_default
667+ ), argument_options ["default" ]
651668 assert not is_dataclass_instance (argument_options ["default" ])
652669
653670 # TODO: Do we really need to care about this "SUPPRESS" stuff here?
@@ -1177,3 +1194,146 @@ def _create_dataclass_instance(
11771194 return None
11781195 logger .debug (f"Calling constructor: { constructor } (**{ constructor_args } )" )
11791196 return constructor (** constructor_args )
1197+
1198+
1199+ def _has_values_of_type (
1200+ mapping : Mapping [K , Any ], value_type : type [V ] | tuple [type [V ], ...]
1201+ ) -> TypeGuard [Mapping [K , V ]]:
1202+ # Utility functions used to narrow the type of dictionaries.
1203+ return all (isinstance (v , value_type ) for v in mapping .values ())
1204+
1205+
1206+ def _has_keys_of_type (
1207+ mapping : Mapping [Any , V ], key_type : type [K ] | tuple [type [K ], ...]
1208+ ) -> TypeGuard [Mapping [K , V ]]:
1209+ # Utility functions used to narrow the type of dictionaries.
1210+ return all (isinstance (k , key_type ) for k in mapping .keys ())
1211+
1212+
def _has_items_of_type(
    mapping: Mapping[Any, Any],
    item_type: tuple[type[K] | tuple[type[K], ...], type[V] | tuple[type[V], ...]],
) -> TypeGuard[Mapping[K, V]]:
    """Type guard that narrows both the key and the value type of a mapping.

    `item_type` is a `(key_type, value_type)` pair; each element is a type or a tuple of types.
    Returns True only when all keys match the first element and all values match the second.
    """
    expected_key_type, expected_value_type = item_type
    # Keys are checked first; values are only inspected when every key passes.
    if not _has_keys_of_type(mapping, expected_key_type):
        return False
    return _has_values_of_type(mapping, expected_value_type)
1220+
1221+
def _infer_subgroup_key_to_use_from_config(
    default_in_config: dict[str, Any],
    subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]],
) -> SubgroupKey:
    """Infer which subgroup key a dict default (e.g. loaded from a config file) refers to.

    The following strategies are tried in order, and the first one that succeeds wins:

    1. If the dict has a `SUBGROUP_KEY_FLAG` entry, its value is returned as-is.
    2. If the dict compares equal to one of the subgroup choices, that choice's key is returned.
    3. If every subgroup choice is a dataclass type, the first key whose
       "<module>.<qualname>" equals the dict's `DC_TYPE_KEY` entry is returned.
    4. Otherwise, the key whose default constructor arguments share the most values with the
       dict is returned, and a `RuntimeWarning` is emitted because this is only a best guess.
       Ties are resolved in favour of the key that comes first in `subgroup_choices`.

    Args:
        default_in_config: The dict found where a subgroup default was expected.
        subgroup_choices: Mapping from subgroup key to the dataclass type (or partial) it selects.

    Returns:
        The subgroup key to use as the default for the subgroup field.

    Raises:
        AssertionError: If strategies 1 and 2 fail and the dict has no `DC_TYPE_KEY` entry, or if
            all choices are dataclass types and none of them matches the `DC_TYPE_KEY` entry.
    """
    if SUBGROUP_KEY_FLAG in default_in_config:
        return default_in_config[SUBGROUP_KEY_FLAG]

    for subgroup_key, subgroup_value in subgroup_choices.items():
        if default_in_config == subgroup_value:
            return subgroup_key

    assert (
        DC_TYPE_KEY in default_in_config
    ), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict."
    default_type_name: str = default_in_config[DC_TYPE_KEY]

    if _has_values_of_type(subgroup_choices, type) and all(
        dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values()
    ):
        # Simpler case: All the subgroup options are dataclass types. We just get the key that
        # matches the type that was saved in the config dict.
        keys_matching_type_name: list[SubgroupKey] = [
            key
            for key, option in subgroup_choices.items()
            if isinstance(option, type)
            and f"{option.__module__}.{option.__qualname__}" == default_type_name
        ]
        # NOTE: There could be duplicates, e.g. `subgroups({"a": A, "aa": A})`. The first key wins.
        assert keys_matching_type_name, (
            f"No subgroup choice has the type {default_type_name!r} found in the config."
        )
        return keys_matching_type_name[0]

    # Heuristic: pick the subgroup whose default constructor arguments have the most values in
    # common with the dict from the config.
    n_matching_values = {
        key: _num_matching_values(
            default_in_config, _default_constructor_argument_values(subgroup_value)
        )
        for key, subgroup_value in subgroup_choices.items()
    }
    # `max` returns the first maximal key in iteration order, which is the same element that a
    # stable descending sort would put first.
    closest_subgroup_key = max(subgroup_choices, key=n_matching_values.__getitem__)
    warnings.warn(
        # TODO: Return the dataclass type instead, and be done with it!
        RuntimeWarning(
            f"TODO: The config file contains a default value for a subgroup that isn't in the "
            f"dict of subgroup options. Because of how subgroups are currently implemented, we "
            f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most "
            f"closely matches the value {default_in_config}. "
            f"The current implementation tries to use the dataclass type of this closest match "
            f"to parse the additional values from the command-line. "
            f"Consider adding a '{SUBGROUP_KEY_FLAG}: <key of the subgroup to use>' entry to "
            f"the config."
        )
    )
    return closest_subgroup_key
1297+
1298+
def _default_constructor_argument_values(
    some_dataclass_type: type[Dataclass] | functools.partial[Dataclass],
) -> PossiblyNestedDict[str, Any]:
    """Collect the default constructor arguments of a dataclass type (or a partial of one).

    For a dataclass type, the result maps each field name to its `default`. A field without a
    plain default whose `default_factory` is a dataclass type (or a partial of one), or whose
    annotated type is a dataclass type, maps to a nested dict built recursively. Fields with no
    usable default are left out of the result.

    For a `functools.partial` of a dataclass type, the defaults of the wrapped type are computed
    first and then overridden by the arguments that are bound in the partial.

    Args:
        some_dataclass_type: A dataclass type, or a `functools.partial` wrapping one.

    Returns:
        A possibly nested dict from field name to default value.

    Raises:
        AssertionError: If the argument is neither a dataclass type nor a partial of one.
        TypeError: If the arguments stored in a partial cannot be bound to the constructor
            signature (raised by `inspect.Signature.bind_partial`).
    """
    if isinstance(some_dataclass_type, functools.partial) and is_dataclass_type(
        some_dataclass_type.func
    ):
        constructor_arguments = _default_constructor_argument_values(some_dataclass_type.func)
        # Map the positional/keyword arguments stored in the partial to parameter names, so that
        # they can override the defaults from the class definition.
        constructor_arguments_from_partial = (
            inspect.signature(some_dataclass_type.func)
            .bind_partial(*some_dataclass_type.args, **some_dataclass_type.keywords)
            .arguments
        )
        constructor_arguments.update(constructor_arguments_from_partial)
        return constructor_arguments

    assert is_dataclass_type(some_dataclass_type)
    result: dict[str, Any] = {}
    for field in dataclasses.fields(some_dataclass_type):
        if field.default is not dataclasses.MISSING:
            result[field.name] = field.default
            continue

        factory = field.default_factory
        if is_dataclass_type(factory) or (
            isinstance(factory, functools.partial) and is_dataclass_type(factory.func)
        ):
            # Recurse on the factory rather than on the annotation: the factory is what actually
            # produces the default, it carries any overrides bound in a partial, and the
            # annotation might not be a dataclass type at all (e.g. a base class or a string).
            result[field.name] = _default_constructor_argument_values(factory)
        elif is_dataclass_type(field.type):
            result[field.name] = _default_constructor_argument_values(field.type)
    return result
1329+
1330+
1331+ def _num_matching_values (subgroup_default : dict [str , Any ], subgroup_choice : dict [str , Any ]) -> int :
1332+ """Returns the number of matching entries in the subgroup dict w/ the default from the
1333+ config."""
1334+ return sum (
1335+ _num_matching_values (default_v , subgroup_choice [k ])
1336+ if isinstance (subgroup_choice .get (k ), dict ) and isinstance (default_v , dict )
1337+ else int (subgroup_choice .get (k ) == default_v )
1338+ for k , default_v in subgroup_default .items ()
1339+ )
0 commit comments