Skip to content

Commit

Permalink
[AKS] az aks command invoke: Add progress spinner (#30274)
Browse files Browse the repository at this point in the history
  • Loading branch information
CustardTart32 authored Nov 13, 2024
1 parent e8bcbfb commit eb55296
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 4 deletions.
40 changes: 40 additions & 0 deletions src/azure-cli/azure/cli/command_modules/acs/_polling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from typing import Dict

from azure.core.polling.base_polling import LocationPolling, _is_empty, BadResponse, _as_json


class RunCommandLocationPolling(LocationPolling):
"""Extends LocationPolling but uses the body content instead of the status code for the status"""

@staticmethod
def _get_provisioning_state(response):
"""Attempt to get provisioning state from resource.
:param azure.core.pipeline.transport.HttpResponse response: latest REST call response.
:returns: Status if found, else 'None'.
"""
if _is_empty(response):
return None
body: Dict = _as_json(response)
return body.get("properties", {}).get("provisioningState")

def get_status(self, pipeline_response):
"""Process the latest status update retrieved from the same URL as
the previous request.
:param azure.core.pipeline.PipelineResponse response: latest REST call response.
:raises: BadResponse if status not 200 or 204.
"""
response = pipeline_response.http_response
if _is_empty(response):
raise BadResponse(
"The response from long running operation does not contain a body."
)

status = self._get_provisioning_state(response)
return status or "Unknown"
33 changes: 29 additions & 4 deletions src/azure-cli/azure/cli/command_modules/acs/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
CONST_AZURE_SERVICE_MESH_MODE_ISTIO,
CONST_MANAGED_CLUSTER_SKU_TIER_PREMIUM,
)

from azure.cli.command_modules.acs._polling import RunCommandLocationPolling
from azure.cli.command_modules.acs._helpers import get_snapshot_by_snapshot_id, check_is_private_link_cluster
from azure.cli.command_modules.acs._resourcegroup import get_rg_location
from azure.cli.command_modules.acs._validators import extract_comma_separated_string
Expand All @@ -95,9 +95,11 @@
from azure.cli.core.commands import LongRunningOperation
from azure.cli.core.commands.client_factory import get_subscription_id
from azure.cli.core.profiles import ResourceType
from azure.mgmt.core.polling.arm_polling import ARMPolling
from azure.cli.core.util import in_cloud_console, sdk_no_wait
from azure.core.exceptions import ResourceNotFoundError as ResourceNotFoundErrorAzCore
from azure.mgmt.containerservice.models import KubernetesSupportPlan
from humanfriendly.terminal.spinners import Spinner
from knack.log import get_logger
from knack.prompting import NoTTYException, prompt_y_n
from knack.util import CLIError
Expand Down Expand Up @@ -2034,6 +2036,7 @@ def aks_runcommand(cmd, client, resource_group_name, name, command_string="", co

if not command_string:
raise ValidationError('Command cannot be empty.')

RunCommandRequest = cmd.get_models('RunCommandRequest', resource_type=ResourceType.MGMT_CONTAINERSERVICE,
operation_group='managed_clusters')
request_payload = RunCommandRequest(command=command_string)
Expand All @@ -2046,8 +2049,15 @@ def aks_runcommand(cmd, client, resource_group_name, name, command_string="", co
request_payload.cluster_token = _get_dataplane_aad_token(
cmd.cli_ctx, "6dae42f8-4368-4678-94ff-3960e28e3630")

polling_interval = 5
retry_total = 0

command_result_poller = sdk_no_wait(
no_wait, client.begin_run_command, resource_group_name, name, request_payload, polling_interval=5, retry_total=0
no_wait, client.begin_run_command, resource_group_name, name, request_payload,
# NOTE: Note sure if retry_total is used in ARMPolling
polling=ARMPolling(polling_interval, lro_options={"final-state-via": "location"}, lro_algorithms=[RunCommandLocationPolling()], retry_total=retry_total),
polling_interval=polling_interval,
retry_total=retry_total
)
if no_wait:
# pylint: disable=protected-access
Expand All @@ -2058,7 +2068,22 @@ def aks_runcommand(cmd, client, resource_group_name, name, command_string="", co
command_id = command_id_regex.findall(command_result_polling_url)[0]
_aks_command_result_in_progess_helper(client, resource_group_name, name, command_id)
return
return _print_command_result(cmd.cli_ctx, command_result_poller.result(300))

spinner = Spinner(label='Running', stream=sys.stderr, hide_cursor=False)
progress_controller = cmd.cli_ctx.get_progress_controller(det=False, spinner=spinner)

now = datetime.datetime.now()
progress_controller.begin()
while not command_result_poller.done():
if datetime.datetime.now() - now >= datetime.timedelta(seconds=300):
break

progress_controller.add(message=command_result_poller.status())
progress_controller.update()
time.sleep(0.5)

progress_controller.end()
return _print_command_result(cmd.cli_ctx, command_result_poller.result(timeout=0))


def aks_command_result(cmd, client, resource_group_name, name, command_id=""):
Expand Down Expand Up @@ -2115,7 +2140,7 @@ def _print_command_result(cli_ctx, commandResult):
return

# *-ing state
print(f"{colorama.Fore.BLUE}command is in {commandResult.provisioning_state} state{colorama.Style.RESET_ALL}")
print(f"{colorama.Fore.BLUE}command (id: {commandResult.id}) is in {commandResult.provisioning_state} state{colorama.Style.RESET_ALL}")
return


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import unittest
from unittest.mock import Mock

from azure.cli.command_modules.acs._polling import RunCommandLocationPolling
from azure.core.pipeline import PipelineResponse
from azure.core.rest import HttpRequest, HttpResponse


class TestRunCommandPoller(unittest.TestCase):
def test_get_status(self):
poller = RunCommandLocationPolling()

mock_response = Mock(spec=HttpResponse)
mock_response.text.return_value = "{\"properties\": {\"provisioningState\": \"Scaling Up\"}}"

pipeline_response: PipelineResponse[HttpRequest, HttpResponse] = PipelineResponse(Mock(spec=HttpRequest), mock_response, Mock())

status = poller.get_status(pipeline_response)
assert status == "Scaling Up"

def test_get_status_no_provisioning_state(self):
poller = RunCommandLocationPolling()

mock_response = Mock(spec=HttpResponse)
mock_response.text.return_value = "{\"properties\": {\"status\": \"Scaling Up\"}}"

pipeline_response: PipelineResponse[HttpRequest, HttpResponse] = PipelineResponse(Mock(spec=HttpRequest), mock_response, Mock())

status = poller.get_status(pipeline_response)
assert status == "Unknown"

0 comments on commit eb55296

Please sign in to comment.