diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index 203ed47..654eacd 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -15,7 +15,8 @@ from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn, \ + is_integration_pattern_valid LAMBDA_SERVICE_NAME = "lambda" GLUE_SERVICE_NAME = "glue" @@ -161,16 +162,25 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): class EcsRunTaskStep(Task): - """ Creates a Task State to run Amazon ECS or Fargate Tasks. See `Manage Amazon ECS or Fargate Tasks with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, **kwargs): + supported_integration_patterns = [ + IntegrationPattern.WaitForCompletion, + IntegrationPattern.WaitForTaskToken, + IntegrationPattern.CallAndContinue + ] + + def __init__(self, state_id, wait_for_completion=True, integration_pattern=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the ecs job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the ecs job and proceed to the next step. (default: True) + integration_pattern (stepfunctions.steps.integration_resources.IntegrationPattern, optional): Service integration pattern used to call the integrated service. This is mutually exclusive from wait_for_completion Supported integration patterns (default: None): + * WaitForCompletion: Wait for the state machine execution to complete before going to the next state. (See `Run A Job `_ for more details.) + * WaitForTaskToken: Wait for the state machine execution to return a task token before progressing to the next state (See `Wait for a Callback with the Task Token `_ for more details.) + * CallAndContinue: Call StartExecution and progress to the next state (See `Request Response `_ for more details.) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. @@ -181,20 +191,23 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ + if wait_for_completion and integration_pattern: + raise ValueError( + "Only one of wait_for_completion and integration_pattern set. " + "Set wait_for_completion to False if you wish to use integration_pattern." + ) + + # The old implementation type still has to be supported until a new + # major is realeased. if wait_for_completion: - """ - Example resource arn: arn:aws:states:::ecs:runTask.sync - """ - - kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, - EcsApi.RunTask, - IntegrationPattern.WaitForCompletion) - else: - """ - Example resource arn: arn:aws:states:::ecs:runTask - """ - - kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, - EcsApi.RunTask) + integration_pattern = IntegrationPattern.WaitForCompletion + if not wait_for_completion and not integration_pattern: + integration_pattern = IntegrationPattern.CallAndContinue + + is_integration_pattern_valid(integration_pattern, + self.supported_integration_patterns) + kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, + EcsApi.RunTask, + integration_pattern) super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/tests/unit/test_compute_steps.py b/tests/unit/test_compute_steps.py index 368010a..925b949 100644 --- a/tests/unit/test_compute_steps.py +++ b/tests/unit/test_compute_steps.py @@ -17,6 +17,7 @@ from unittest.mock import patch from stepfunctions.steps.compute import LambdaStep, GlueStartJobRunStep, BatchSubmitJobStep, EcsRunTaskStep +from stepfunctions.steps.integration_resources import IntegrationPattern @patch.object(boto3.session.Session, 'region_name', 'us-east-1') @@ -100,24 +101,72 @@ def test_batch_submit_job_step_creation(): @patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_ecs_run_task_step_creation(): - step = EcsRunTaskStep('Ecs Job', wait_for_completion=False) +@pytest.mark.parametrize( + ("task_kwargs",), + [ + ({},), + ({ + "integration_pattern": IntegrationPattern.WaitForCompletion, + "wait_for_completion": False, + },), + ] +) +def test_ecs_run_task_with_wait_for_completion(task_kwargs): + step = EcsRunTaskStep('ECS Job', **task_kwargs) assert step.to_dict() == { 'Type': 'Task', - 'Resource': 'arn:aws:states:::ecs:runTask', + 'Resource': 'arn:aws:states:::ecs:runTask.sync', 'End': True } - step = EcsRunTaskStep('Ecs Job', parameters={ - 'TaskDefinition': 'Task' - }) + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +@pytest.mark.parametrize( + ("task_kwargs",), + [ + ({ + "integration_pattern": IntegrationPattern.WaitForTaskToken, + "wait_for_completion": False, + },), + ] +) +def test_ecs_run_task_with_wait_for_task_token(task_kwargs): + step = EcsRunTaskStep('ECS Job', **task_kwargs) assert step.to_dict() == { 'Type': 'Task', - 'Resource': 'arn:aws:states:::ecs:runTask.sync', - 'Parameters': { - 'TaskDefinition': 'Task' - }, + 'Resource': 'arn:aws:states:::ecs:runTask.waitForTaskToken', + 'End': True + } + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +@pytest.mark.parametrize( + ("task_kwargs",), + [ + ({ + "wait_for_completion": False, + },), + ({ + "integration_pattern": IntegrationPattern.CallAndContinue, + "wait_for_completion": False, + },) + ] +) +def test_ecs_run_task_with_call_and_continue(task_kwargs): + step = EcsRunTaskStep('ECS Job', **task_kwargs) + + assert step.to_dict() == { + 'Type': 'Task', + 'Resource': 'arn:aws:states:::ecs:runTask', 'End': True } + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_ecs_run_task_with_conflicting_arguments(): + with pytest.raises(ValueError): + step = EcsRunTaskStep('Ecs Job', + wait_for_completion=True, + integration_pattern=IntegrationPattern.WaitForTaskToken)