From bfaacc2407ac0f800e4b8140454a7377aa730b71 Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:03:40 +0800 Subject: [PATCH 1/7] expires_on --- src/azure-cli-core/azure/cli/core/_profile.py | 47 +++++++++---------- .../cli/core/auth/msal_authentication.py | 7 +++ .../cli/command_modules/profile/custom.py | 2 +- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index d6330e8874e..1df9e5383a3 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -361,37 +361,41 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No raise CLIError("Please specify only one of subscription and tenant, not both") account = self.get_subscription(subscription) - resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id + + # This is not used anyway, just a placeholder + mi_resoure = self.cli_ctx.cloud.endpoints.active_directory_resource_id identity_type, identity_id = Profile._try_parse_msi_account_name(account) if identity_type: - # MSI + # managed identity if tenant: - raise CLIError("Tenant shouldn't be specified for MSI account") - msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource) - msi_creds.set_token() - token_entry = msi_creds.token - creds = (token_entry['token_type'], token_entry['access_token'], token_entry) + raise CLIError("Tenant shouldn't be specified for managed identity account") + msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, mi_resoure) + sdk_token = msi_creds.get_token(*scopes) elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): # Cloud Shell if tenant: raise CLIError("Tenant shouldn't be specified for Cloud Shell account") - creds = self._get_token_from_cloud_shell(resource) + msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id, mi_resoure) + sdk_token = msi_creds.get_token(*scopes) else: credential = self._create_credential(account, tenant) - token = credential.get_token(*scopes) + sdk_token = credential.get_token(*scopes) - import datetime - expiresOn = datetime.datetime.fromtimestamp(token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f") + # Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility + # WARNING: expiresOn is deprecated and will be removed in future release. + import datetime + expiresOn = datetime.datetime.fromtimestamp(sdk_token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f") - token_entry = { - 'accessToken': token.token, - 'expires_on': token.expires_on, - 'expiresOn': expiresOn - } + token_entry = { + 'accessToken': sdk_token.token, + 'expires_on': sdk_token.expires_on, # epoch int, like 1605238724 + 'expiresOn': expiresOn # datetime string, like "2020-11-12 13:50:47.114324" + } + + # (tokenType, accessToken, tokenEntry) + creds = 'Bearer', sdk_token.token, token_entry - # (tokenType, accessToken, tokenEntry) - creds = 'Bearer', token.token, token_entry # (cred, subscription, tenant) return (creds, None if tenant else str(account[_SUBSCRIPTION_ID]), @@ -695,13 +699,6 @@ def get_installation_id(self): self._storage[_INSTALLATION_ID] = installation_id return installation_id - def _get_token_from_cloud_shell(self, resource): # pylint: disable=no-self-use - from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper - auth = MSIAuthenticationWrapper(resource=resource) - auth.set_token() - token_entry = auth.token - return (token_entry['token_type'], token_entry['access_token'], token_entry) - class MsiAccountTypes: # pylint: disable=no-method-argument,no-self-argument diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py index e9b92d1d3eb..49cca6e454e 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py @@ -134,6 +134,13 @@ def _build_sdk_access_token(token_entry): import time request_time = int(time.time()) + # MSAL token entry sample: + # { + # 'access_token': 'eyJ0eXAiOiJKV...', + # 'token_type': 'Bearer', + # 'expires_in': 1618 + # } + # Importing azure.core.credentials.AccessToken is expensive. # This can slow down commands that doesn't need azure.core, like `az account get-access-token`. # So We define our own AccessToken. diff --git a/src/azure-cli/azure/cli/command_modules/profile/custom.py b/src/azure-cli/azure/cli/command_modules/profile/custom.py index bafbdf24471..bfaebe18c2d 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/custom.py +++ b/src/azure-cli/azure/cli/command_modules/profile/custom.py @@ -78,7 +78,7 @@ def get_access_token(cmd, subscription=None, resource=None, scopes=None, resourc 'tokenType': creds[0], 'accessToken': creds[1], # 'expires_on': creds[2].get('expires_on', None), - 'expiresOn': creds[2].get('expiresOn', None), + 'expiresOn': creds[2]['expiresOn'], 'tenant': tenant } if subscription: From 5859d97110f30fa24570547c97c0fede78ba126a Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Fri, 5 Nov 2021 17:23:54 +0800 Subject: [PATCH 2/7] unittest --- .../cli/core/auth/adal_authentication.py | 3 +- .../cli/core/auth/msal_authentication.py | 4 +- .../azure/cli/core/auth/tests/test_util.py | 2 +- .../azure/cli/core/auth/util.py | 5 ++ .../azure/cli/core/tests/test_profile.py | 51 +++++++++++-------- 5 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py index f56741afef1..521ec42aeb8 100644 --- a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py @@ -4,11 +4,10 @@ # -------------------------------------------------------------------------------------------- import requests -from azure.core.credentials import AccessToken from knack.log import get_logger from msrestazure.azure_active_directory import MSIAuthentication -from .util import _normalize_scopes, scopes_to_resource +from .util import _normalize_scopes, scopes_to_resource, AccessToken logger = get_logger(__name__) diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py index 49cca6e454e..97cb30b2a8c 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py @@ -14,7 +14,7 @@ from knack.util import CLIError from msal import PublicClientApplication, ConfidentialClientApplication -from .util import check_result +from .util import check_result, AccessToken # OAuth 2.0 client credentials flow parameter # https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow @@ -144,6 +144,4 @@ def _build_sdk_access_token(token_entry): # Importing azure.core.credentials.AccessToken is expensive. # This can slow down commands that doesn't need azure.core, like `az account get-access-token`. # So We define our own AccessToken. - from collections import namedtuple - AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"]) diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py index f5db382d736..c96e5a446ed 100644 --- a/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py @@ -6,7 +6,7 @@ # pylint: disable=protected-access import unittest -from ..util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command +from azure.cli.core.auth.util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command class TestUtil(unittest.TestCase): diff --git a/src/azure-cli-core/azure/cli/core/auth/util.py b/src/azure-cli-core/azure/cli/core/auth/util.py index ff6af520938..0186cdc3ad2 100644 --- a/src/azure-cli-core/azure/cli/core/auth/util.py +++ b/src/azure-cli-core/azure/cli/core/auth/util.py @@ -4,11 +4,16 @@ # -------------------------------------------------------------------------------------------- import os +from collections import namedtuple + from knack.log import get_logger logger = get_logger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) + + def aad_error_handler(error, **kwargs): """ Handle the error from AAD server returned by ADAL or MSAL. """ diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 6c654dbf58b..a71c58dbfc4 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -5,31 +5,24 @@ # pylint: disable=protected-access import json -import os -import sys import unittest -from unittest import mock -import re -import datetime - from copy import deepcopy - -from azure.core.credentials import AccessToken +from unittest import mock from azure.cli.core._profile import (Profile, SubscriptionFinder, _attach_token_tenant, _transform_subscription_for_multiapi) - -from azure.mgmt.resource.subscriptions.models import \ - (Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant, TenantIdDescription) - +from azure.cli.core.auth.util import AccessToken from azure.cli.core.mock import DummyCli from azure.identity import AuthenticationRecord +from azure.mgmt.resource.subscriptions.models import \ + (Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant) from knack.util import CLIError - MOCK_ACCESS_TOKEN = "mock_access_token" -MOCK_EXPIRES_ON = 1630920323 +MOCK_EXPIRES_ON_STR = "1630920323" +MOCK_EXPIRES_ON_INT = 1630920323 +MOCK_EXPIRES_ON_DATETIME = '2021-09-06 17:25:23.000000' BEARER = 'Bearer' @@ -43,14 +36,15 @@ def get_token(self, *scopes, **kwargs): import time now = int(time.time()) # Mock sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py:230 - return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON) + return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON_INT) class MSRestAzureAuthStub: def __init__(self, *args, **kwargs): self._token = { 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token + 'access_token': TestProfile.test_msi_access_token, + 'expires_on': MOCK_EXPIRES_ON_STR } self.set_token_invoked_count = 0 self.token_read_count = 0 @@ -70,6 +64,9 @@ def token(self): def token(self, value): self._token = value + def get_token(self, *args, **kwargs): + return AccessToken(self.token['access_token'], int(self.token['expires_on'])) + class TestProfile(unittest.TestCase): @@ -1049,7 +1046,8 @@ def test_get_raw_token(self): self.assertEqual(creds[0], 'Bearer') self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription should be set self.assertEqual(sub, self.subscription1.subscription_id) @@ -1060,7 +1058,8 @@ def test_get_raw_token(self): self.assertEqual(creds[0], 'Bearer') self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription shouldn't be set self.assertIsNone(sub) @@ -1084,7 +1083,8 @@ def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) # the last in the tuple is the whole token entry which has several fields - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription should be set self.assertEqual(sub, self.subscription1.subscription_id) @@ -1095,7 +1095,8 @@ def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription shouldn't be set self.assertIsNone(sub) @@ -1124,11 +1125,15 @@ def test_get_raw_token_msi_system_assigned(self, mock_msi_auth): self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') self.assertEqual(cred[1], TestProfile.test_msi_access_token) + + # Make sure expires_on and expiresOn are set + self.assertEqual(cred[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(cred[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(tenant_id, test_tenant_id) # verify tenant shouldn't be specified for MSI account - with self.assertRaisesRegexp(CLIError, "MSI"): + with self.assertRaisesRegexp(CLIError, "Tenant shouldn't be specified"): cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) @mock.patch('azure.cli.core._profile.in_cloud_console', autospec=True) @@ -1157,6 +1162,10 @@ def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_conso self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') self.assertEqual(cred[1], TestProfile.test_msi_access_token) + + # Make sure expires_on and expiresOn are set + self.assertEqual(cred[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(cred[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(tenant_id, test_tenant_id) From 71d161e5b162256e58c2b06eedfb2dde08c10eeb Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Mon, 8 Nov 2021 10:55:58 +0800 Subject: [PATCH 3/7] datetime --- src/azure-cli-core/azure/cli/core/tests/test_profile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index a71c58dbfc4..8bde2197205 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -5,6 +5,7 @@ # pylint: disable=protected-access import json +import datetime import unittest from copy import deepcopy from unittest import mock @@ -22,7 +23,7 @@ MOCK_ACCESS_TOKEN = "mock_access_token" MOCK_EXPIRES_ON_STR = "1630920323" MOCK_EXPIRES_ON_INT = 1630920323 -MOCK_EXPIRES_ON_DATETIME = '2021-09-06 17:25:23.000000' +MOCK_EXPIRES_ON_DATETIME = datetime.datetime.fromtimestamp(MOCK_EXPIRES_ON_INT).strftime("%Y-%m-%d %H:%M:%S.%f") BEARER = 'Bearer' From f65d76ac06bf22b826e781b0d76fb8f8b16eebc4 Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Mon, 8 Nov 2021 14:46:02 +0800 Subject: [PATCH 4/7] resouce --- src/azure-cli-core/azure/cli/core/_profile.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index 1df9e5383a3..aca2318105f 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -362,21 +362,22 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No account = self.get_subscription(subscription) - # This is not used anyway, just a placeholder - mi_resoure = self.cli_ctx.cloud.endpoints.active_directory_resource_id - identity_type, identity_id = Profile._try_parse_msi_account_name(account) if identity_type: # managed identity if tenant: raise CLIError("Tenant shouldn't be specified for managed identity account") - msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, mi_resoure) + from .auth.util import scopes_to_resource + msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, + scopes_to_resource(scopes)) sdk_token = msi_creds.get_token(*scopes) elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): # Cloud Shell if tenant: raise CLIError("Tenant shouldn't be specified for Cloud Shell account") - msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id, mi_resoure) + from .auth.util import scopes_to_resource + msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id, + scopes_to_resource(scopes)) sdk_token = msi_creds.get_token(*scopes) else: credential = self._create_credential(account, tenant) From 4cf877e51ba991ed37c2fff01f866913c1beead3 Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Mon, 8 Nov 2021 15:41:51 +0800 Subject: [PATCH 5/7] Add tests --- .../azure/cli/core/tests/test_profile.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 8bde2197205..5445cde370f 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -30,17 +30,20 @@ class MockCredential: def __init__(self, *args, **kwargs): + self.get_token_scopes = None super().__init__() def get_token(self, *scopes, **kwargs): + self.get_token_scopes = scopes from azure.core.credentials import AccessToken - import time - now = int(time.time()) # Mock sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py:230 return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON_INT) class MSRestAzureAuthStub: + + return_value = None + def __init__(self, *args, **kwargs): self._token = { 'token_type': 'Bearer', @@ -49,9 +52,12 @@ def __init__(self, *args, **kwargs): } self.set_token_invoked_count = 0 self.token_read_count = 0 + self.get_token_scopes = None self.client_id = kwargs.get('client_id') self.object_id = kwargs.get('object_id') self.msi_res_id = kwargs.get('msi_res_id') + self.resource = kwargs.get('resource') + MSRestAzureAuthStub.return_value = self def set_token(self): self.set_token_invoked_count += 1 @@ -66,6 +72,7 @@ def token(self, value): self._token = value def get_token(self, *args, **kwargs): + self.get_token_scopes = args return AccessToken(self.token['access_token'], int(self.token['expires_on'])) @@ -261,7 +268,8 @@ def setUpClass(cls): 'authority_type': 'MSSTS' }] - cls.msal_scopes = ['https://foo/.default'] + cls.adal_resource = 'https://foo/' + cls.msal_scopes = ['https://foo//.default'] cls.service_principal_id = "00000001-0000-0000-0000-000000000000" cls.service_principal_secret = "test_secret" @@ -1037,7 +1045,7 @@ def test_get_raw_token(self): # action # Get token with ADAL-style resource - resource_result = profile.get_raw_token(resource='https://foo') + resource_result = profile.get_raw_token(resource=self.adal_resource) # Get token with MSAL-style scopes scopes_result = profile.get_raw_token(scopes=self.msal_scopes) @@ -1055,7 +1063,7 @@ def test_get_raw_token(self): self.assertEqual(tenant, self.tenant_id) # Test get_raw_token with tenant - creds, sub, tenant = profile.get_raw_token(resource='https://foo', tenant=self.tenant_id) + creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource, tenant=self.tenant_id) self.assertEqual(creds[0], 'Bearer') self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) @@ -1068,7 +1076,8 @@ def test_get_raw_token(self): @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential') def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): - get_service_principal_credential_mock.return_value = MockCredential() + credential_mock = MockCredential() + get_service_principal_credential_mock.return_value = credential_mock cli = DummyCli() # setup storage_mock = {'subscriptions': None} @@ -1078,9 +1087,11 @@ def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): True) profile._set_subscriptions(consolidated) # action - creds, sub, tenant = profile.get_raw_token(resource='https://foo') + creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource) # verify + assert list(credential_mock.get_token_scopes) == self.msal_scopes + self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) # the last in the tuple is the whole token entry which has several fields @@ -1092,7 +1103,7 @@ def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): self.assertEqual(tenant, self.tenant_id) # Test get_raw_token with tenant - creds, sub, tenant = profile.get_raw_token(resource='https://foo', tenant=self.tenant_id) + creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource, tenant=self.tenant_id) self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) @@ -1120,9 +1131,12 @@ def test_get_raw_token_msi_system_assigned(self, mock_msi_auth): mock_msi_auth.side_effect = MSRestAzureAuthStub # action - cred, subscription_id, tenant_id = profile.get_raw_token(resource='http://test_resource') + cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) + + # Make sure resource/scopes are passed to MSIAuthenticationWrapper + assert MSRestAzureAuthStub.return_value.resource == self.adal_resource + assert list(MSRestAzureAuthStub.return_value.get_token_scopes) == self.msal_scopes - # assert self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') self.assertEqual(cred[1], TestProfile.test_msi_access_token) @@ -1157,9 +1171,12 @@ def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_conso mock_msi_auth.side_effect = MSRestAzureAuthStub # action - cred, subscription_id, tenant_id = profile.get_raw_token(resource='http://test_resource') + cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) + + # Make sure resource/scopes are passed to MSIAuthenticationWrapper + assert MSRestAzureAuthStub.return_value.resource == self.adal_resource + assert list(MSRestAzureAuthStub.return_value.get_token_scopes) == self.msal_scopes - # assert self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') self.assertEqual(cred[1], TestProfile.test_msi_access_token) From 173d46910df981d831cd8ceef57b10e412fb374a Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Mon, 15 Nov 2021 18:07:01 +0800 Subject: [PATCH 6/7] comment --- src/azure-cli-core/azure/cli/core/_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index aca2318105f..28a7ba493e2 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -372,7 +372,7 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No scopes_to_resource(scopes)) sdk_token = msi_creds.get_token(*scopes) elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - # Cloud Shell + # Cloud Shell, which is just a system-assigned managed identity. if tenant: raise CLIError("Tenant shouldn't be specified for Cloud Shell account") from .auth.util import scopes_to_resource From f0746f0c0396df5eeb017d194a98ab94657f86b7 Mon Sep 17 00:00:00 2001 From: jiasli <4003950+jiasli@users.noreply.github.com> Date: Tue, 16 Nov 2021 21:45:50 +0800 Subject: [PATCH 7/7] refine tests --- .../azure/cli/core/tests/test_profile.py | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index f51cdc1f88b..df43c5f66da 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -30,13 +30,13 @@ class CredentialMock: def __init__(self, *args, **kwargs): + # If get_token_scopes is checked, make sure to create a new instance of CredentialMock + # to avoid interference from other tests. self.get_token_scopes = None super().__init__() def get_token(self, *scopes, **kwargs): self.get_token_scopes = scopes - from azure.core.credentials import AccessToken - # Mock sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py:230 return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON_INT) @@ -48,8 +48,6 @@ def get_token(self, *scopes, **kwargs): class MSRestAzureAuthStub: - return_value = None - def __init__(self, *args, **kwargs): self._token = { 'token_type': 'Bearer', @@ -63,7 +61,6 @@ def __init__(self, *args, **kwargs): self.object_id = kwargs.get('object_id') self.msi_res_id = kwargs.get('msi_res_id') self.resource = kwargs.get('resource') - MSRestAzureAuthStub.return_value = self def set_token(self): self.set_token_invoked_count += 1 @@ -1039,8 +1036,10 @@ def test_get_login_credentials_msi_user_assigned_with_res_id(self): self.assertTrue(cred.token_read_count) self.assertTrue(cred.msi_res_id, test_res_id) - @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', return_value=credential_mock) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential') def test_get_raw_token(self, get_user_credential_mock): + credential_mock_temp = CredentialMock() + get_user_credential_mock.return_value = credential_mock_temp cli = DummyCli() # setup storage_mock = {'subscriptions': None} @@ -1073,7 +1072,7 @@ def test_get_raw_token(self, get_user_credential_mock): creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource, tenant=self.tenant_id) # verify - assert list(credential_mock.get_token_scopes) == self.msal_scopes + assert list(credential_mock_temp.get_token_scopes) == self.msal_scopes self.assertEqual(creds[0], 'Bearer') self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) @@ -1084,8 +1083,10 @@ def test_get_raw_token(self, get_user_credential_mock): self.assertIsNone(sub) self.assertEqual(tenant, self.tenant_id) - @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential', return_value=credential_mock) + @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential') def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): + credential_mock_temp = CredentialMock() + get_service_principal_credential_mock.return_value = credential_mock_temp cli = DummyCli() # setup storage_mock = {'subscriptions': None} @@ -1098,7 +1099,7 @@ def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource) # verify - assert list(credential_mock.get_token_scopes) == self.msal_scopes + assert list(credential_mock_temp.get_token_scopes) == self.msal_scopes self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) @@ -1136,14 +1137,21 @@ def test_get_raw_token_msi_system_assigned(self, mock_msi_auth): True) profile._set_subscriptions(consolidated) - mock_msi_auth.side_effect = MSRestAzureAuthStub + mi_auth_instance = None + + def mi_auth_factory(*args, **kwargs): + nonlocal mi_auth_instance + mi_auth_instance = MSRestAzureAuthStub(*args, **kwargs) + return mi_auth_instance + + mock_msi_auth.side_effect = mi_auth_factory # action cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) # Make sure resource/scopes are passed to MSIAuthenticationWrapper - assert MSRestAzureAuthStub.return_value.resource == self.adal_resource - assert list(MSRestAzureAuthStub.return_value.get_token_scopes) == self.msal_scopes + assert mi_auth_instance.resource == self.adal_resource + assert list(mi_auth_instance.get_token_scopes) == self.msal_scopes self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') @@ -1176,14 +1184,21 @@ def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_conso consolidated[0]['user']['cloudShellID'] = True profile._set_subscriptions(consolidated) - mock_msi_auth.side_effect = MSRestAzureAuthStub + mi_auth_instance = None + + def mi_auth_factory(*args, **kwargs): + nonlocal mi_auth_instance + mi_auth_instance = MSRestAzureAuthStub(*args, **kwargs) + return mi_auth_instance + + mock_msi_auth.side_effect = mi_auth_factory # action cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) # Make sure resource/scopes are passed to MSIAuthenticationWrapper - assert MSRestAzureAuthStub.return_value.resource == self.adal_resource - assert list(MSRestAzureAuthStub.return_value.get_token_scopes) == self.msal_scopes + assert mi_auth_instance.resource == self.adal_resource + assert list(mi_auth_instance.get_token_scopes) == self.msal_scopes self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer')