diff --git a/src/azure-cli/azure/cli/command_modules/acs/_helpers.py b/src/azure-cli/azure/cli/command_modules/acs/_helpers.py index e4cc808b0a7..291bbc3e1f6 100644 --- a/src/azure-cli/azure/cli/command_modules/acs/_helpers.py +++ b/src/azure-cli/azure/cli/command_modules/acs/_helpers.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------------------------- import re +from typing import Any, List, TypeVar from azure.cli.command_modules.acs._client_factory import cf_snapshots, get_msi_client from azure.cli.core.azclierror import ( @@ -21,6 +22,69 @@ from azure.core.exceptions import AzureError, HttpResponseError, ServiceRequestError, ServiceResponseError from msrestazure.azure_exceptions import CloudError +# type variables +ManagedCluster = TypeVar("ManagedCluster") + + +def format_parameter_name_to_option_name(parameter_name: str) -> str: + """Convert a name in parameter format to option format. + + Underscores ("_") are used to connect the various parts of a parameter name, while hyphens ("-") are used to connect + each part of an option name. Besides, the option name starts with double hyphens ("--"). + + :return: str + """ + option_name = "--" + parameter_name.replace("_", "-") + return option_name + + +def safe_list_get(li: List, idx: int, default: Any = None) -> Any: + """Get an element from a list without raising IndexError. + + Attempt to get the element with index idx from a list-like object li, and if the index is invalid (such as out of + range), return default (whose default value is None). + + :return: an element of any type + """ + if isinstance(li, list): + try: + return li[idx] + except IndexError: + return default + return None + + +def safe_lower(obj: Any) -> Any: + """Return lowercase string if the provided obj is a string, otherwise return the object itself. + + :return: Any + """ + if isinstance(obj, str): + return obj.lower() + return obj + + +def check_is_msi_cluster(mc: ManagedCluster) -> bool: + """Check `mc` object to determine whether managed identity is enabled. + + :return: bool + """ + if mc and mc.identity and mc.identity.type is not None: + identity_type = mc.identity.type.casefold() + if identity_type in ("systemassigned", "userassigned"): + return True + return False + + +def check_is_private_cluster(mc: ManagedCluster) -> bool: + """Check `mc` object to determine whether private cluster is enabled. + + :return: bool + """ + if mc and mc.api_server_access_profile: + return bool(mc.api_server_access_profile.enable_private_cluster) + return False + # pylint: disable=too-many-return-statements def map_azure_error_to_cli_error(azure_error): diff --git a/src/azure-cli/azure/cli/command_modules/acs/agentpool_decorator.py b/src/azure-cli/azure/cli/command_modules/acs/agentpool_decorator.py index 40145a8c4d1..405fb4f353a 100644 --- a/src/azure-cli/azure/cli/command_modules/acs/agentpool_decorator.py +++ b/src/azure-cli/azure/cli/command_modules/acs/agentpool_decorator.py @@ -8,8 +8,7 @@ from azure.cli.command_modules.acs._client_factory import cf_agent_pools from azure.cli.command_modules.acs._consts import DecoratorMode from azure.cli.command_modules.acs._validators import extract_comma_separated_string -from azure.cli.command_modules.acs.decorator import validate_decorator_mode -from azure.cli.core import AzCommandsLoader +from azure.cli.command_modules.acs.base_decorator import BaseAKSContext, BaseAKSModels, BaseAKSParamDict from azure.cli.core.azclierror import CLIInternalError, InvalidArgumentValueError, RequiredArgumentMissingError from azure.cli.core.commands import AzCliCommand from azure.cli.core.profiles import ResourceType @@ -23,60 +22,34 @@ AgentPoolsOperations = TypeVar("AgentPoolsOperations") -# pylint: disable=too-many-instance-attributes, too-few-public-methods -class AKSAgentPoolModels: - """Store the models used in aks_agentpool_add and aks_agentpool_update. +# pylint: disable=too-few-public-methods +class AKSAgentPoolModels(BaseAKSModels): + """Store the models used in aks agentpool series of commands. The api version of the class corresponding to a model is determined by resource_type. """ - def __init__( - self, - cmd: AzCommandsLoader, - resource_type: ResourceType, - ): - self.__cmd = cmd - self.resource_type = resource_type - self.AgentPool = self.__cmd.get_models( - "AgentPool", - resource_type=self.resource_type, - operation_group="agent_pools", - ) - self.AgentPoolUpgradeSettings = self.__cmd.get_models( - "AgentPoolUpgradeSettings", - resource_type=self.resource_type, - operation_group="agent_pools", - ) + +# pylint: disable=too-few-public-methods +class AKSAgentPoolParamDict(BaseAKSParamDict): + """Store the original parameters passed in by aks agentpool series of commands as an internal dictionary. + + Only expose the "get" method externally to obtain parameter values, while recording usage. + """ # pylint: disable=too-many-public-methods -class AKSAgentPoolContext: +class AKSAgentPoolContext(BaseAKSContext): """Implement getter functions for all parameters in aks_agentpool_add and aks_agentpool_update. """ def __init__( self, cmd: AzCliCommand, - raw_parameters: Dict, + raw_parameters: AKSAgentPoolParamDict, models: AKSAgentPoolModels, decorator_mode: DecoratorMode, ): - if not isinstance(raw_parameters, dict): - raise CLIInternalError( - "Unexpected raw_parameters object with type '{}'.".format( - type(raw_parameters) - ) - ) - if not validate_decorator_mode(decorator_mode): - raise CLIInternalError( - "Unexpected decorator_mode '{}' with type '{}'.".format( - decorator_mode, type(decorator_mode) - ) - ) - self.cmd = cmd - self.raw_param = raw_parameters - self.models = models - self.decorator_mode = decorator_mode - self.intermediates = dict() + super().__init__(cmd, raw_parameters, models, decorator_mode) self.agentpool = None # pylint: disable=no-self-use @@ -389,7 +362,9 @@ def __init__( self.client = client self.models = AKSAgentPoolModels(cmd, resource_type) # store the context in the process of assemble the AgentPool object - self.context = AKSAgentPoolContext(cmd, raw_parameters, self.models, decorator_mode=DecoratorMode.CREATE) + self.context = AKSAgentPoolContext( + cmd, AKSAgentPoolParamDict(raw_parameters), self.models, decorator_mode=DecoratorMode.CREATE + ) def _ensure_agentpool(self, agentpool: AgentPool) -> None: """Internal function to ensure that the incoming `agentpool` object is valid and the same as the attached @@ -535,4 +510,6 @@ def __init__( self.client = client self.models = AKSAgentPoolModels(cmd, resource_type) # store the context in the process of assemble the AgentPool object - self.context = AKSAgentPoolContext(cmd, raw_parameters, self.models, decorator_mode=DecoratorMode.UPDATE) + self.context = AKSAgentPoolContext( + cmd, AKSAgentPoolParamDict(raw_parameters), self.models, decorator_mode=DecoratorMode.UPDATE + ) diff --git a/src/azure-cli/azure/cli/command_modules/acs/base_decorator.py b/src/azure-cli/azure/cli/command_modules/acs/base_decorator.py new file mode 100644 index 00000000000..37d936cb81f --- /dev/null +++ b/src/azure-cli/azure/cli/command_modules/acs/base_decorator.py @@ -0,0 +1,192 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from typing import Any + +from azure.cli.command_modules.acs._consts import DecoratorMode +from azure.cli.core import AzCommandsLoader +from azure.cli.core.azclierror import CLIInternalError +from azure.cli.core.commands import AzCliCommand +from azure.cli.core.profiles import ResourceType +from knack.log import get_logger + +logger = get_logger(__name__) + + +def validate_decorator_mode(decorator_mode) -> bool: + """Check if decorator_mode is a value of enum type DecoratorMode. + + :return: bool + """ + is_valid_decorator_mode = False + try: + is_valid_decorator_mode = decorator_mode in DecoratorMode + # will raise TypeError in Python >= 3.8 + except TypeError: + pass + + return is_valid_decorator_mode + + +class BaseAKSModels: + """A base class for storing the models used by aks commands. + + The api version of the class corresponding to a model is determined by resource_type. + """ + def __init__( + self, + cmd: AzCommandsLoader, + resource_type: ResourceType, + ): + self.__cmd = cmd + self.__raw_models = None + self.resource_type = resource_type + self.set_up_models() + + @property + def raw_models(self): + if self.__raw_models is None: + self.__raw_models = self.__cmd.get_models( + resource_type=self.resource_type, + operation_group="managed_clusters", + ).models + return self.__raw_models + + def set_up_models(self): + for model_name, model_class in vars(self.raw_models).items(): + if not model_name.startswith('_'): + setattr(self, model_name, model_class) + + +class BaseAKSParamDict: + """A base class for storing the original parameters passed in by the aks commands as an internal dictionary. + + Only expose the "get" method externally to obtain parameter values, while recording usage. + """ + def __init__(self, param_dict): + if not isinstance(param_dict, dict): + raise CLIInternalError( + "Unexpected param_dict object with type '{}'.".format( + type(param_dict) + ) + ) + self.__store = param_dict.copy() + self.__count = {} + + def __increase(self, key): + self.__count[key] = self.__count.get(key, 0) + 1 + + def get(self, key): + self.__increase(key) + return self.__store.get(key) + + def keys(self): + return self.__store.keys() + + def values(self): + return self.__store.values() + + def items(self): + return self.__store.items() + + def __format_count(self): + untouched_keys = [x for x in self.__store.keys() if x not in self.__count.keys()] + for k in untouched_keys: + self.__count[k] = 0 + + def print_usage_statistics(self): + self.__format_count() + print("\nParameter usage statistics:") + for k, v in self.__count.items(): + print(k, v) + print("Total: {}".format(len(self.__count.keys()))) + + +class BaseAKSContext: + """A base class for holding raw parameters, models and methods to get and store intermediates that will be used by + the decorators of aks commands. + + Note: This is a base class and should not be used directly, you need to implement getter functions in inherited + classes. + + Each getter function is responsible for obtaining the corresponding one or more parameter values, and perform + necessary parameter value completion or normalization and validation checks. + """ + def __init__( + self, cmd: AzCliCommand, raw_parameters: BaseAKSParamDict, models: BaseAKSModels, decorator_mode: DecoratorMode + ): + if not isinstance(raw_parameters, BaseAKSParamDict): + raise CLIInternalError( + "Unexpected raw_parameters object with type '{}'.".format( + type(raw_parameters) + ) + ) + if not validate_decorator_mode(decorator_mode): + raise CLIInternalError( + "Unexpected decorator_mode '{}' with type '{}'.".format( + decorator_mode, type(decorator_mode) + ) + ) + self.cmd = cmd + self.raw_param = raw_parameters + self.models = models + self.decorator_mode = decorator_mode + self.intermediates = dict() + + def get_intermediate(self, variable_name: str, default_value: Any = None) -> Any: + """Get the value of an intermediate by its name. + + Get the value from the intermediates dictionary with variable_name as the key. If variable_name does not exist, + default_value will be returned. + + :return: Any + """ + if variable_name not in self.intermediates: + logger.debug( + "The intermediate '%s' does not exist. Return default value '%s'.", + variable_name, + default_value, + ) + intermediate_value = self.intermediates.get(variable_name, default_value) + return intermediate_value + + def set_intermediate( + self, variable_name: str, value: Any, overwrite_exists: bool = False + ) -> None: + """Set the value of an intermediate by its name. + + In the case that the intermediate value already exists, if overwrite_exists is enabled, the value will be + overwritten and the log will be output at the debug level, otherwise the value will not be overwritten and + the log will be output at the warning level, which by default will be output to stderr and seen by user. + + :return: None + """ + if variable_name in self.intermediates: + if overwrite_exists: + msg = "The intermediate '{}' is overwritten. Original value: '{}', new value: '{}'.".format( + variable_name, self.intermediates.get(variable_name), value + ) + logger.debug(msg) + self.intermediates[variable_name] = value + elif self.intermediates.get(variable_name) != value: + msg = "The intermediate '{}' already exists, but overwrite is not enabled. " \ + "Original value: '{}', candidate value: '{}'.".format( + variable_name, + self.intermediates.get(variable_name), + value, + ) + # warning level log will be output to the console, which may cause confusion to users + logger.warning(msg) + else: + self.intermediates[variable_name] = value + + def remove_intermediate(self, variable_name: str) -> None: + """Remove the value of an intermediate by its name. + + No exception will be raised if the intermediate does not exist. + + :return: None + """ + self.intermediates.pop(variable_name, None) diff --git a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_agentpool_decorator.py b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_agentpool_decorator.py index f23d464b03e..e447f0e9711 100644 --- a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_agentpool_decorator.py +++ b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_agentpool_decorator.py @@ -3,18 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import importlib import unittest -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, patch from azure.cli.command_modules.acs._consts import DecoratorEarlyExitException, DecoratorMode from azure.cli.command_modules.acs.agentpool_decorator import ( + AKSAgentPoolModels, + AKSAgentPoolParamDict, AKSAgentPoolAddDecorator, AKSAgentPoolContext, - AKSAgentPoolModels, AKSAgentPoolUpdateDecorator, ) -from azure.cli.command_modules.acs.decorator import AKSParamDict from azure.cli.command_modules.acs.tests.latest.mocks import MockCLI, MockClient, MockCmd from azure.cli.core.azclierror import ( ArgumentUsageError, @@ -27,33 +26,6 @@ UnknownError, ) from azure.cli.core.profiles import ResourceType -from azure.core.exceptions import HttpResponseError -from knack.prompting import NoTTYException -from knack.util import CLIError -from msrestazure.azure_exceptions import CloudError - - -class AKSAgentPoolModelsTestCase(unittest.TestCase): - def setUp(self): - self.cli_ctx = MockCLI() - self.cmd = MockCmd(self.cli_ctx) - - def test_models(self): - models = AKSAgentPoolModels(self.cmd, ResourceType.MGMT_CONTAINERSERVICE) - - # load models directly (instead of through the `get_sdk` method provided by the cli component) - from azure.cli.core.profiles._shared import AZURE_API_PROFILES - - sdk_profile = AZURE_API_PROFILES["latest"][ResourceType.MGMT_CONTAINERSERVICE] - api_version = sdk_profile.default_api_version - module_name = "azure.mgmt.containerservice.v{}.models".format(api_version.replace("-", "_")) - module = importlib.import_module(module_name) - - self.assertEqual(models.AgentPool, getattr(module, "AgentPool")) - self.assertEqual( - models.AgentPoolUpgradeSettings, - getattr(module, "AgentPoolUpgradeSettings"), - ) class AKSAgentPoolContextTestCase(unittest.TestCase): @@ -68,10 +40,10 @@ def test__init__(self): AKSAgentPoolContext(self.cmd, [], self.models, decorator_mode=DecoratorMode.CREATE) # fail on not passing decorator_mode with Enum type DecoratorMode with self.assertRaises(CLIInternalError): - AKSAgentPoolContext(self.cmd, {}, self.models, decorator_mode=1) + AKSAgentPoolContext(self.cmd, AKSAgentPoolParamDict({}), self.models, decorator_mode=1) def test_attach_agentpool(self): - ctx_1 = AKSAgentPoolContext(self.cmd, {}, self.models, decorator_mode=DecoratorMode.CREATE) + ctx_1 = AKSAgentPoolContext(self.cmd, AKSAgentPoolParamDict({}), self.models, decorator_mode=DecoratorMode.CREATE) agentpool = self.models.AgentPool() ctx_1.attach_agentpool(agentpool) self.assertEqual(ctx_1.agentpool, agentpool) @@ -80,7 +52,7 @@ def test_attach_agentpool(self): ctx_1.attach_agentpool(agentpool) def test_validate_counts_in_autoscaler(self): - ctx = AKSAgentPoolContext(self.cmd, {}, self.models, decorator_mode=DecoratorMode.CREATE) + ctx = AKSAgentPoolContext(self.cmd, AKSAgentPoolParamDict({}), self.models, decorator_mode=DecoratorMode.CREATE) # default ctx._AKSAgentPoolContext__validate_counts_in_autoscaler(3, False, None, None, DecoratorMode.CREATE) @@ -111,7 +83,7 @@ def test_get_resource_group_name(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - {"resource_group_name": "test_rg_name"}, + AKSAgentPoolParamDict({"resource_group_name": "test_rg_name"}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -121,7 +93,7 @@ def test_get_cluster_name(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - {"cluster_name": "test_cluster_name"}, + AKSAgentPoolParamDict({"cluster_name": "test_cluster_name"}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -131,7 +103,7 @@ def test_get_nodepool_name(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - {"nodepool_name": "test_nodepool_name"}, + AKSAgentPoolParamDict({"nodepool_name": "test_nodepool_name"}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -153,7 +125,7 @@ def test_get_nodepool_name(self): # custom ctx_2 = AKSAgentPoolContext( self.cmd, - {"nodepool_name": "test_nodepool_name"}, + AKSAgentPoolParamDict({"nodepool_name": "test_nodepool_name"}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -171,9 +143,9 @@ def test_get_max_surge(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - { + AKSAgentPoolParamDict({ "max_surge": None, - }, + }), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -190,12 +162,12 @@ def test_get_node_count_and_enable_cluster_autoscaler_min_max_count( # default ctx_1 = AKSAgentPoolContext( self.cmd, - { + AKSAgentPoolParamDict({ "node_count": 3, "enable_cluster_autoscaler": False, "min_count": None, "max_count": None, - }, + }), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -219,7 +191,7 @@ def test_get_node_osdisk_size(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - {"node_osdisk_size": 0}, + AKSAgentPoolParamDict({"node_osdisk_size": 0}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -232,7 +204,7 @@ def test_get_node_osdisk_type(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - {"node_osdisk_type": None}, + AKSAgentPoolParamDict({"node_osdisk_type": None}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -245,9 +217,9 @@ def test_get_aks_custom_headers(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - { + AKSAgentPoolParamDict({ "aks_custom_headers": None, - }, + }), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -256,9 +228,9 @@ def test_get_aks_custom_headers(self): # custom value ctx_2 = AKSAgentPoolContext( self.cmd, - { + AKSAgentPoolParamDict({ "aks_custom_headers": "abc=def,xyz=123", - }, + }), self.models, decorator_mode=DecoratorMode.UPDATE, ) @@ -268,7 +240,7 @@ def test_get_no_wait(self): # default ctx_1 = AKSAgentPoolContext( self.cmd, - {"no_wait": False}, + AKSAgentPoolParamDict({"no_wait": False}), self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -359,7 +331,6 @@ def test_construct_default_agentpool(self): "nodepool_name": "test_nodepool_name", } raw_param_dict.update(optional_params) - raw_param_dict = AKSParamDict(raw_param_dict) # default value in `aks_create` dec_1 = AKSAgentPoolAddDecorator( @@ -381,7 +352,7 @@ def test_construct_default_agentpool(self): ) agentpool_1.name = "test_nodepool_name" self.assertEqual(dec_agentpool_1, agentpool_1) - raw_param_dict.print_usage_statistics() + dec_1.context.raw_param.print_usage_statistics() class AKSAgentPoolUpdateDecoratorTestCase(unittest.TestCase): diff --git a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_base_decorator.py b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_base_decorator.py new file mode 100644 index 00000000000..ae1921b691b --- /dev/null +++ b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_base_decorator.py @@ -0,0 +1,131 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import importlib +import unittest + +from azure.cli.command_modules.acs._consts import DecoratorMode +from azure.cli.command_modules.acs.base_decorator import ( + BaseAKSContext, + BaseAKSModels, + BaseAKSParamDict, + validate_decorator_mode, +) +from azure.cli.command_modules.acs.tests.latest.mocks import MockCLI, MockCmd +from azure.cli.core.azclierror import CLIInternalError +from azure.cli.core.profiles import ResourceType + + +class BaseDecoratorHelperFunctionsTestCase(unittest.TestCase): + def test_validate_decorator_mode(self): + self.assertEqual(validate_decorator_mode(DecoratorMode.CREATE), True) + self.assertEqual(validate_decorator_mode(DecoratorMode.UPDATE), True) + self.assertEqual(validate_decorator_mode(DecoratorMode), False) + self.assertEqual(validate_decorator_mode(1), False) + self.assertEqual(validate_decorator_mode("1"), False) + self.assertEqual(validate_decorator_mode(True), False) + self.assertEqual(validate_decorator_mode({}), False) + + +class BaseAKSModelsTestCase(unittest.TestCase): + def setUp(self): + self.cli_ctx = MockCLI() + self.cmd = MockCmd(self.cli_ctx) + + def test_models(self): + # load models directly (instead of through the `get_sdk` method provided by the cli component) + from azure.cli.core.profiles._shared import AZURE_API_PROFILES + + sdk_profile = AZURE_API_PROFILES["latest"][ResourceType.MGMT_CONTAINERSERVICE] + api_version = sdk_profile.default_api_version + module_name = "azure.mgmt.containerservice.v{}.models".format(api_version.replace("-", "_")) + module = importlib.import_module(module_name) + models = BaseAKSModels(self.cmd, ResourceType.MGMT_CONTAINERSERVICE) + self.assertEqual(models.raw_models, module) + + +class BaseAKSParamDictTestCase(unittest.TestCase): + def test__init__(self): + # fail on not passing dictionary-like parameters + with self.assertRaises(CLIInternalError): + BaseAKSParamDict([]) + + def test_get(self): + param_dict = BaseAKSParamDict({"abc": "xyz"}) + self.assertEqual(param_dict.get("abc"), "xyz") + + def test_keys(self): + param_dict = BaseAKSParamDict({"abc": "xyz"}) + self.assertEqual(list(param_dict.keys()), ["abc"]) + + def test_values(self): + param_dict = BaseAKSParamDict({"abc": "xyz"}) + self.assertEqual(list(param_dict.values()), ["xyz"]) + + def test_items(self): + param_dict = BaseAKSParamDict({"abc": "xyz"}) + self.assertEqual(list(param_dict.items()), [("abc", "xyz")]) + + def test_print_usage_statistics(self): + param_dict = BaseAKSParamDict({"abc": "xyz", "def": 100}) + param_dict.print_usage_statistics() + + +class BaseAKSContextTestCase(unittest.TestCase): + def setUp(self): + self.cli_ctx = MockCLI() + self.cmd = MockCmd(self.cli_ctx) + self.models = BaseAKSModels(self.cmd, ResourceType.MGMT_CONTAINERSERVICE) + + def test__init__(self): + # fail on not passing dictionary-like parameters + with self.assertRaises(CLIInternalError): + BaseAKSContext(self.cmd, [], self.models, decorator_mode=DecoratorMode.CREATE) + # fail on not passing decorator_mode with Enum type DecoratorMode + with self.assertRaises(CLIInternalError): + BaseAKSContext(self.cmd, BaseAKSParamDict({}), self.models, decorator_mode=1) + + def test_get_intermediate(self): + ctx_1 = BaseAKSContext(self.cmd, BaseAKSParamDict({}), self.models, decorator_mode=DecoratorMode.CREATE) + self.assertEqual( + ctx_1.get_intermediate("fake-intermediate", "not found"), + "not found", + ) + + def test_set_intermediate(self): + ctx_1 = BaseAKSContext(self.cmd, BaseAKSParamDict({}), self.models, decorator_mode=DecoratorMode.CREATE) + ctx_1.set_intermediate("test-intermediate", "test-intermediate-value") + self.assertEqual( + ctx_1.get_intermediate("test-intermediate"), + "test-intermediate-value", + ) + ctx_1.set_intermediate("test-intermediate", "new-test-intermediate-value") + self.assertEqual( + ctx_1.get_intermediate("test-intermediate"), + "test-intermediate-value", + ) + ctx_1.set_intermediate( + "test-intermediate", + "new-test-intermediate-value", + overwrite_exists=True, + ) + self.assertEqual( + ctx_1.get_intermediate("test-intermediate"), + "new-test-intermediate-value", + ) + + def test_remove_intermediate(self): + ctx_1 = BaseAKSContext(self.cmd, BaseAKSParamDict({}), self.models, decorator_mode=DecoratorMode.CREATE) + ctx_1.set_intermediate("test-intermediate", "test-intermediate-value") + self.assertEqual( + ctx_1.get_intermediate("test-intermediate"), + "test-intermediate-value", + ) + ctx_1.remove_intermediate("test-intermediate") + self.assertEqual(ctx_1.get_intermediate("test-intermediate"), None) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_helpers.py b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_helpers.py index 45b00c3e11a..15f50558061 100644 --- a/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_helpers.py +++ b/src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_helpers.py @@ -6,7 +6,20 @@ import unittest from unittest.mock import Mock, patch -from azure.cli.command_modules.acs import _helpers as helpers +from azure.cli.command_modules.acs._helpers import ( + check_is_msi_cluster, + check_is_private_cluster, + format_parameter_name_to_option_name, + get_snapshot, + get_snapshot_by_snapshot_id, + get_user_assigned_identity, + get_user_assigned_identity_by_resource_id, + map_azure_error_to_cli_error, + safe_list_get, + safe_lower, +) +from azure.cli.command_modules.acs.base_decorator import BaseAKSModels +from azure.cli.command_modules.acs.tests.latest.mocks import MockCLI, MockCmd from azure.cli.core.azclierror import ( AzureInternalError, AzureResponseError, @@ -19,10 +32,84 @@ UnauthorizedError, UnclassifiedUserFault, ) +from azure.cli.core.profiles import ResourceType from azure.core.exceptions import AzureError, HttpResponseError, ServiceRequestError, ServiceResponseError from msrestazure.azure_exceptions import CloudError +class DecoratorFunctionsTestCase(unittest.TestCase): + def setUp(self): + self.cli_ctx = MockCLI() + self.cmd = MockCmd(self.cli_ctx) + self.models = BaseAKSModels(self.cmd, ResourceType.MGMT_CONTAINERSERVICE) + + def test_format_parameter_name_to_option_name(self): + self.assertEqual(format_parameter_name_to_option_name("abc_xyz"), "--abc-xyz") + + def test_safe_list_get(self): + list_1 = [1, 2, 3] + self.assertEqual(safe_list_get(list_1, 0), 1) + self.assertEqual(safe_list_get(list_1, 10), None) + + tuple_1 = (1, 2, 3) + self.assertEqual(safe_list_get(tuple_1, 0), None) + + def test_safe_lower(self): + self.assertEqual(safe_lower(None), None) + self.assertEqual(safe_lower("ABC"), "abc") + + def test_check_is_msi_cluster(self): + self.assertEqual(check_is_msi_cluster(None), False) + + mc_1 = self.models.ManagedCluster( + location="test_location", + identity=self.models.ManagedClusterIdentity(type="SystemAssigned"), + ) + self.assertEqual(check_is_msi_cluster(mc_1), True) + + mc_2 = self.models.ManagedCluster( + location="test_location", + identity=self.models.ManagedClusterIdentity(type="UserAssigned"), + ) + self.assertEqual(check_is_msi_cluster(mc_2), True) + + mc_3 = self.models.ManagedCluster( + location="test_location", + identity=self.models.ManagedClusterIdentity(type="Test"), + ) + self.assertEqual(check_is_msi_cluster(mc_3), False) + + def test_check_is_private_cluster(self): + self.assertEqual(check_is_private_cluster(None), False) + + mc_1 = self.models.ManagedCluster( + location="test_location", + api_server_access_profile=self.models.ManagedClusterAPIServerAccessProfile( + enable_private_cluster=True, + ), + ) + self.assertEqual(check_is_private_cluster(mc_1), True) + + mc_2 = self.models.ManagedCluster( + location="test_location", + api_server_access_profile=self.models.ManagedClusterAPIServerAccessProfile( + enable_private_cluster=False, + ), + ) + self.assertEqual(check_is_private_cluster(mc_2), False) + + mc_3 = self.models.ManagedCluster( + location="test_location", + api_server_access_profile=self.models.ManagedClusterAPIServerAccessProfile(), + ) + self.assertEqual(check_is_private_cluster(mc_3), False) + + mc_4 = self.models.ManagedCluster( + location="test_location", + ) + self.assertEqual(check_is_private_cluster(mc_4), False) + + class ErrorMappingTestCase(unittest.TestCase): def check_error_equality(self, mapped_error, mock_error): self.assertEqual(type(mapped_error), type(mock_error)) @@ -47,7 +134,7 @@ def test_http_response_error(self): status_code = status_code_cli_error_pair[0] azure_error.status_code = status_code azure_error.message = f"error_msg_{idx}" - mapped_error = helpers.map_azure_error_to_cli_error(azure_error) + mapped_error = map_azure_error_to_cli_error(azure_error) # get mock error cli_error = status_code_cli_error_pair[1] mock_error = cli_error(f"error_msg_{idx}") @@ -55,19 +142,19 @@ def test_http_response_error(self): def test_service_request_error(self): azure_error = ServiceRequestError("test_error_msg") - cli_error = helpers.map_azure_error_to_cli_error(azure_error) + cli_error = map_azure_error_to_cli_error(azure_error) mock_error = ClientRequestError("test_error_msg") self.check_error_equality(cli_error, mock_error) def test_service_response_error(self): azure_error = ServiceResponseError("test_error_msg") - cli_error = helpers.map_azure_error_to_cli_error(azure_error) + cli_error = map_azure_error_to_cli_error(azure_error) mock_error = AzureResponseError("test_error_msg") self.check_error_equality(cli_error, mock_error) def test_azure_error(self): azure_error = AzureError("test_error_msg") - cli_error = helpers.map_azure_error_to_cli_error(azure_error) + cli_error = map_azure_error_to_cli_error(azure_error) mock_error = ServiceError("test_error_msg") self.check_error_equality(cli_error, mock_error) @@ -75,13 +162,13 @@ def test_azure_error(self): class GetSnapShotTestCase(unittest.TestCase): def test_get_snapshot_by_snapshot_id(self): with self.assertRaises(InvalidArgumentValueError): - helpers.get_snapshot_by_snapshot_id("mock_cli_ctx", "") + get_snapshot_by_snapshot_id("mock_cli_ctx", "") mock_snapshot = Mock() with patch( "azure.cli.command_modules.acs._helpers.get_snapshot", return_value=mock_snapshot ) as mock_get_snapshot: - snapshot = helpers.get_snapshot_by_snapshot_id( + snapshot = get_snapshot_by_snapshot_id( "mock_cli_ctx", "/subscriptions/test_sub/resourcegroups/test_rg/providers/microsoft.containerservice/snapshots/test_snapshot", ) @@ -92,14 +179,14 @@ def test_get_snapshot(self): mock_snapshot = Mock() mock_snapshot_operations = Mock(get=Mock(return_value=mock_snapshot)) with patch("azure.cli.command_modules.acs._helpers.cf_snapshots", return_value=mock_snapshot_operations): - snapshot = helpers.get_snapshot("mock_cli_ctx", "mock_rg", "mock_snapshot_name") + snapshot = get_snapshot("mock_cli_ctx", "mock_rg", "mock_snapshot_name") self.assertEqual(snapshot, mock_snapshot) mock_snapshot_operations_2 = Mock(get=Mock(side_effect=AzureError("mock snapshot was not found"))) with patch( "azure.cli.command_modules.acs._helpers.cf_snapshots", return_value=mock_snapshot_operations_2 ), self.assertRaises(ResourceNotFoundError): - helpers.get_snapshot("mock_cli_ctx", "mock_rg", "mock_snapshot_name") + get_snapshot("mock_cli_ctx", "mock_rg", "mock_snapshot_name") http_response_error = HttpResponseError() http_response_error.status_code = 400 @@ -108,20 +195,20 @@ def test_get_snapshot(self): with patch( "azure.cli.command_modules.acs._helpers.cf_snapshots", return_value=mock_snapshot_operations_3 ), self.assertRaises(BadRequestError): - helpers.get_snapshot("mock_cli_ctx", "mock_rg", "mock_snapshot_name") + get_snapshot("mock_cli_ctx", "mock_rg", "mock_snapshot_name") class GetUserAssignedIdentityTestCase(unittest.TestCase): def test_get_user_assigned_identity_by_resource_id(self): with self.assertRaises(InvalidArgumentValueError): - helpers.get_user_assigned_identity_by_resource_id("mock_cli_ctx", "") + get_user_assigned_identity_by_resource_id("mock_cli_ctx", "") mock_user_assigned_identity = Mock() with patch( "azure.cli.command_modules.acs._helpers.get_user_assigned_identity", return_value=mock_user_assigned_identity, ) as mock_get_user_assigned_identity: - user_assigned_identity = helpers.get_user_assigned_identity_by_resource_id( + user_assigned_identity = get_user_assigned_identity_by_resource_id( "mock_cli_ctx", "/subscriptions/test_sub/resourcegroups/test_rg/providers/microsoft.managedidentity/userassignedidentities/test_user_assigned_identity", ) @@ -138,7 +225,7 @@ def test_get_user_assigned_identity(self): with patch( "azure.cli.command_modules.acs._helpers.get_msi_client", return_value=mock_user_assigned_identity_operations ): - user_assigned_identity = helpers.get_user_assigned_identity( + user_assigned_identity = get_user_assigned_identity( "mock_cli_ctx", "mock_sub_id", "mock_rg", "mock_identity_name" ) self.assertEqual(user_assigned_identity, mock_user_assigned_identity) @@ -151,7 +238,7 @@ def test_get_user_assigned_identity(self): "azure.cli.command_modules.acs._helpers.get_msi_client", return_value=mock_user_assigned_identity_operations_2, ), self.assertRaises(ResourceNotFoundError): - helpers.get_user_assigned_identity("mock_cli_ctx", "mock_sub_id", "mock_rg", "mock_identity_name") + get_user_assigned_identity("mock_cli_ctx", "mock_sub_id", "mock_rg", "mock_identity_name") cloud_error_3 = CloudError(Mock(status_code="xxx"), "test_error_msg") mock_user_assigned_identity_operations_3 = Mock( @@ -161,7 +248,7 @@ def test_get_user_assigned_identity(self): "azure.cli.command_modules.acs._helpers.get_msi_client", return_value=mock_user_assigned_identity_operations_3, ), self.assertRaises(ServiceError): - helpers.get_user_assigned_identity("mock_cli_ctx", "mock_sub_id", "mock_rg", "mock_identity_name") + get_user_assigned_identity("mock_cli_ctx", "mock_sub_id", "mock_rg", "mock_identity_name") if __name__ == "__main__":