diff --git a/tap/tap.py b/tap/tap.py index b4417c1..3b01557 100644 --- a/tap/tap.py +++ b/tap/tap.py @@ -1,7 +1,9 @@ from argparse import ArgumentParser, ArgumentTypeError from collections import OrderedDict from copy import deepcopy +from enum import Enum from functools import partial +from inspect import isclass import json from pathlib import Path from pprint import pformat @@ -156,7 +158,7 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None: kwargs['help'] = '(' # Type - if variable in self._annotations: + if variable in self._annotations and not self._inherit_enum(variable): kwargs['help'] += type_to_str(self._annotations[variable]) + ', ' # Required/default @@ -270,6 +272,10 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None: elif kwargs.get('action') not in {'count', 'append_const'}: kwargs['type'] = var_type + # If inherited from Enum + if self._inherit_enum(variable): + kwargs['choices'] = [enum.value for enum in var_type] + if self._underscores_to_dashes: # Replace "_" with "-" for arguments that aren't positional name_or_flags = tuple(name_or_flag.replace('_', '-') if name_or_flag.startswith('-') else name_or_flag @@ -702,6 +708,17 @@ def _load_from_config_files(self, config_files: Optional[List[str]]) -> List[str return args_from_config + def _inherit_enum(self, variable: str) -> bool: + """Return if the variable inherit from an Enum or not + + :param variable: The name of the argument + :return: True if it inherit Enum, False otherwise""" + if variable not in self._annotations: + return False + + var_type = self._annotations[variable] + return isclass(var_type) and issubclass(var_type, Enum) + def __str__(self) -> str: """Returns a string representation of self.