From 116c8f968b3f43ea36a2c2a31c8a91d90fd27ec2 Mon Sep 17 00:00:00 2001
From: Matt Ellis
Date: Wed, 12 May 2021 13:41:39 +1000
Subject: [PATCH 1/4] Add optional support for adopting orphaned task
 instances from AwsBatchExecutor

---
 airflow_aws_executors/batch_executor.py | 48 ++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/airflow_aws_executors/batch_executor.py b/airflow_aws_executors/batch_executor.py
index 98d9efc..3776b7c 100644
--- a/airflow_aws_executors/batch_executor.py
+++ b/airflow_aws_executors/batch_executor.py
@@ -1,4 +1,4 @@
-"""AWS Batch Executor. Each Airflow task gets deligated out to an AWS Batch Job"""
+"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job"""
 
 import time
 from copy import deepcopy
@@ -7,6 +7,7 @@
 import boto3
 from airflow.configuration import conf
 from airflow.executors.base_executor import BaseExecutor
+from airflow.models import TaskInstance
 from airflow.utils.module_loading import import_string
 from airflow.utils.state import State
 from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
@@ -73,6 +74,7 @@ def __init__(self, *args, **kwargs):
         self.active_workers: Optional[BatchJobCollection] = None
         self.batch = None
         self.submit_job_kwargs = None
+        self.adopt_task_instances = None
 
     def start(self):
         """Initialize Boto3 Batch Client, and other internal variables"""
@@ -80,6 +82,7 @@ def start(self):
         self.active_workers = BatchJobCollection()
         self.batch = boto3.client('batch', region_name=region)
         self.submit_job_kwargs = self._load_submit_kwargs()
+        self.adopt_task_instances = conf.getboolean('batch', 'adopt_task_instances', fallback=False)
 
     def sync(self):
         """Checks and update state on all running tasks"""
@@ -128,6 +131,9 @@ def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=No
         job_id = self._submit_job(key, command, queue, executor_config or {})
         self.active_workers.add_job(job_id, key)
 
+        # Add batch job_id to executor event buffer, which gets saved in TaskInstance.external_executor_id
+        self.event_buffer[key] = (State.QUEUED, job_id)
+
     def _submit_job(
         self,
         key: TaskInstanceKeyType,
@@ -184,7 +190,11 @@ def end(self, heartbeat_interval=10):
     def terminate(self):
         """
         Kill all Batch Jobs by calling Boto3's TerminateJob API.
+        Do not kill Batch Jobs if the [batch].adopt_task_instances option is set to True.
         """
+        if self.adopt_task_instances:
+            return
+
         for job_id in self.active_workers.get_all_jobs():
             self.batch.terminate_job(
                 jobId=job_id,
@@ -192,6 +202,42 @@ def terminate(self):
                 reason='Airflow Executor received a SIGTERM',
             )
         self.end()
 
+    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+        """
+        If the [batch].adopt_task_instances option is set to True, try to adopt running task instances
+        that have been abandoned by a dying SchedulerJob.
+        These task instances should have a corresponding AWS Batch Job which can be adopted by its unique job_id.
+
+        Anything that is not adopted will be cleared by the scheduler (and then become eligible for re-scheduling).
+
+        :return: any TaskInstances that were unable to be adopted
+        :rtype: list[airflow.models.TaskInstance]
+        """
+        if not self.adopt_task_instances:
+            # Do not try to adopt task instances, return all orphaned tasks for clearing
+            return tis
+
+        adopted_tis: List[TaskInstance] = []
+        not_adopted_tis: List[TaskInstance] = []
+
+        for ti in tis:
+            if ti.external_executor_id is not None:
+                self.active_workers.add_job(ti.external_executor_id, ti.key)
+                adopted_tis.append(ti)
+            else:
+                not_adopted_tis.append(ti)
+
+        if adopted_tis:
+            tasks = [f'{task} in state {task.state}' for task in adopted_tis]
+            task_instance_str = '\n\t'.join(tasks)
+            self.log.info(
+                'Adopted the following %d tasks from a dead executor:\n\t%s',
+                len(adopted_tis),
+                task_instance_str,
+            )
+
+        return not_adopted_tis
+
     @staticmethod
     def _load_submit_kwargs() -> dict:
         submit_kwargs = import_string(
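For readers reviewing this patch, the adoption contract has two halves: `execute_async` publishes the Batch `job_id` as the info field of a QUEUED event, the scheduler persists that value into `TaskInstance.external_executor_id`, and after a scheduler crash the replacement scheduler offers the orphaned instances back through `try_adopt_task_instances`. Below is a minimal, self-contained sketch of that round trip; `FakeBatchExecutor` and `TaskInstanceStub` are illustrative stand-ins, not part of this patch or of Airflow:

```python
# Sketch of the adoption handshake, under the assumption that the scheduler
# copies the info field of QUEUED events into TaskInstance.external_executor_id.
from typing import Dict, List, Optional, Tuple


class TaskInstanceStub:
    """Stand-in for airflow.models.TaskInstance with only the fields used here."""
    def __init__(self, key: str) -> None:
        self.key = key
        self.external_executor_id: Optional[str] = None


class FakeBatchExecutor:
    def __init__(self, adopt_task_instances: bool) -> None:
        self.adopt_task_instances = adopt_task_instances
        self.active_workers: Dict[str, str] = {}            # job_id -> task key
        self.event_buffer: Dict[str, Tuple[str, str]] = {}  # task key -> (state, info)

    def execute_async(self, key: str, job_id: str) -> None:
        self.active_workers[job_id] = key
        # Mirrors the patch: the Batch job_id rides along as the event's info field.
        self.event_buffer[key] = ('queued', job_id)

    def try_adopt_task_instances(self, tis: List[TaskInstanceStub]) -> List[TaskInstanceStub]:
        if not self.adopt_task_instances:
            return tis  # decline everything; the scheduler clears and reschedules
        not_adopted = []
        for ti in tis:
            if ti.external_executor_id is not None:
                self.active_workers[ti.external_executor_id] = ti.key
            else:
                not_adopted.append(ti)
        return not_adopted


# Scheduler #1 submits a job; the Batch job_id lands in the event buffer.
old_executor = FakeBatchExecutor(adopt_task_instances=True)
old_executor.execute_async(key='dag1.task1.run1', job_id='batch-job-123')

# The scheduler persists the buffered job_id onto the task instance.
ti = TaskInstanceStub(key='dag1.task1.run1')
ti.external_executor_id = old_executor.event_buffer[ti.key][1]

# Scheduler #2 starts after a crash: the orphan is adopted, not rescheduled.
new_executor = FakeBatchExecutor(adopt_task_instances=True)
assert new_executor.try_adopt_task_instances([ti]) == []
assert new_executor.active_workers == {'batch-job-123': 'dag1.task1.run1'}
```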
From 2dddc52a25d263c6951c395b444696d95e43c858 Mon Sep 17 00:00:00 2001
From: Matt Ellis
Date: Wed, 12 May 2021 13:44:37 +1000
Subject: [PATCH 2/4] Add optional support for adopting orphaned task
 instances from AwsEcsFargateExecutor

---
 airflow_aws_executors/ecs_fargate_executor.py | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/airflow_aws_executors/ecs_fargate_executor.py b/airflow_aws_executors/ecs_fargate_executor.py
index ee0c4fe..82da328 100644
--- a/airflow_aws_executors/ecs_fargate_executor.py
+++ b/airflow_aws_executors/ecs_fargate_executor.py
@@ -8,6 +8,7 @@
 import boto3
 from airflow.configuration import conf
 from airflow.executors.base_executor import BaseExecutor
+from airflow.models import TaskInstance
 from airflow.utils.module_loading import import_string
 from airflow.utils.state import State
 from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
@@ -93,6 +94,7 @@ def __init__(self, *args, **kwargs):
         self.pending_tasks: Optional[deque] = None
         self.ecs = None
         self.run_task_kwargs = None
+        self.adopt_task_instances = None
 
     def start(self):
         """Initialize Boto3 ECS Client, and other internal variables"""
@@ -103,6 +105,7 @@ def start(self):
         self.pending_tasks = deque()
         self.ecs = boto3.client('ecs', region_name=region)  # noqa
         self.run_task_kwargs = self._load_run_kwargs()
+        self.adopt_task_instances = conf.getboolean('ecs_fargate', 'adopt_task_instances', fallback=False)
 
     def sync(self):
         self.sync_running_tasks()
@@ -207,6 +210,8 @@ def attempt_task_runs(self):
             else:
                 task = run_task_response['tasks'][0]
                 self.active_workers.add_task(task, task_key, queue, cmd, exec_config)
+                # Add fargate task arn to executor event buffer, which gets saved in TaskInstance.external_executor_id
+                self.event_buffer[task_key] = (State.QUEUED, task.task_arn)
         if failure_reasons:
             self.log.debug('Pending tasks failed to launch for the following reasons: %s. Will retry later.',
                            dict(failure_reasons))
@@ -267,7 +272,11 @@ def end(self, heartbeat_interval=10):
     def terminate(self):
         """
         Kill all ECS processes by calling Boto3's StopTask API.
+        Do not kill ECS processes if the [ecs_fargate].adopt_task_instances option is set to True.
         """
+        if self.adopt_task_instances:
+            return
+
         for arn in self.active_workers.get_all_arns():
             self.ecs.stop_task(
                 cluster=self.cluster,
@@ -276,6 +285,44 @@ def terminate(self):
                 task=arn,
                 reason='Airflow Executor received a SIGTERM'
             )
         self.end()
 
+    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+        """
+        If the [ecs_fargate].adopt_task_instances option is set to True, try to adopt running task instances
+        that have been abandoned by a dying SchedulerJob.
+        These task instances should have a corresponding ECS process which can be adopted by its unique task arn.
+
+        Anything that is not adopted will be cleared by the scheduler (and then become eligible for re-scheduling).
+
+        :return: any TaskInstances that were unable to be adopted
+        :rtype: list[airflow.models.TaskInstance]
+        """
+        if not self.adopt_task_instances:
+            # Do not try to adopt task instances, return all orphaned tasks for clearing
+            return tis
+
+        adopted_tis: List[TaskInstance] = []
+
+        task_arns = [ti.external_executor_id for ti in tis if ti.external_executor_id]
+        if task_arns:
+            task_descriptions = self.__describe_tasks(task_arns).get('tasks', [])
+
+            for task in task_descriptions:
+                ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
+                self.active_workers.add_task(task, ti.key, ti.queue, ti.command_as_list(), ti.executor_config)
+                adopted_tis.append(ti)
+
+        if adopted_tis:
+            tasks = [f'{task} in state {task.state}' for task in adopted_tis]
+            task_instance_str = '\n\t'.join(tasks)
+            self.log.info(
+                'Adopted the following %d tasks from a dead executor:\n\t%s',
+                len(adopted_tis),
+                task_instance_str,
+            )
+
+        not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
+        return not_adopted_tis
+
     def _load_run_kwargs(self) -> dict:
         run_kwargs = import_string(
             conf.get(
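Unlike the Batch flow above, the ECS/Fargate path re-validates each candidate ARN against the DescribeTasks API before re-registering it, so an ARN that ECS no longer recognizes simply falls through to the not-adopted list. A rough standalone sketch of that filtering step follows; the function name and cluster default are illustrative, and actually running it requires AWS credentials. Note that DescribeTasks accepts at most 100 ARNs per call, so a large orphan set needs chunking:

```python
# Illustrative filter mirroring the DescribeTasks-based adoption check.
from typing import List

import boto3


def adoptable_task_arns(candidate_arns: List[str], cluster: str = 'airflow-cluster') -> List[str]:
    """Return the subset of candidate ARNs that ECS still reports as live tasks."""
    ecs = boto3.client('ecs')
    live: List[str] = []
    # DescribeTasks is limited to 100 task ARNs per request, so chunk the input.
    for start in range(0, len(candidate_arns), 100):
        chunk = candidate_arns[start:start + 100]
        response = ecs.describe_tasks(cluster=cluster, tasks=chunk)
        live.extend(task['taskArn'] for task in response.get('tasks', []))
    return live
```

ARNs that ECS has already reaped come back under the response's `failures` key rather than `tasks`, which is what makes this a safe pre-adoption check.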
From a0662da82e8509f86be511af8cc450b9bd7b5d7a Mon Sep 17 00:00:00 2001
From: Matt Ellis
Date: Wed, 12 May 2021 13:46:44 +1000
Subject: [PATCH 3/4] Add unit test cases for adopting orphaned task instances

---
 tests/test_batch_executor.py       | 55 +++++++++++++++++++--
 tests/test_ecs_fargate_executor.py | 78 ++++++++++++++++++++++++++++--
 2 files changed, 125 insertions(+), 8 deletions(-)

diff --git a/tests/test_batch_executor.py b/tests/test_batch_executor.py
index 78373ae..841736b 100644
--- a/tests/test_batch_executor.py
+++ b/tests/test_batch_executor.py
@@ -4,6 +4,7 @@
 from airflow_aws_executors.batch_executor import (
     AwsBatchExecutor, BatchJobDetailSchema, BatchJob, BatchJobCollection
 )
+from airflow.models import TaskInstance
 from airflow.utils.state import State
 
 from .botocore_helper import get_botocore_model, assert_botocore_call
@@ -104,14 +105,17 @@ def test_execute(self):
         # task is stored in active worker
         self.assertEqual(1, len(self.executor.active_workers))
 
+        # job_id is stored in executor event buffer
+        self.assertEqual((State.QUEUED, 'ABC'), self.executor.event_buffer[airflow_key])
+
     @mock.patch('airflow.executors.base_executor.BaseExecutor.fail')
     @mock.patch('airflow.executors.base_executor.BaseExecutor.success')
     def test_sync(self, success_mock, fail_mock):
         """Test synch from end-to-end. Mocks a successful job & makes sure it's removed"""
-        after_sync_reponse = self.__mock_sync()
+        after_sync_response = self.__mock_sync()
 
         # sanity check that container's status code is mocked to success
-        loaded_batch_job = BatchJobDetailSchema().load(after_sync_reponse)
+        loaded_batch_job = BatchJobDetailSchema().load(after_sync_response)
         self.assertEqual(State.SUCCESS, loaded_batch_job.get_job_state())
 
         self.executor.sync()
@@ -130,11 +134,11 @@ def test_sync(self, success_mock, fail_mock):
     @mock.patch('airflow.executors.base_executor.BaseExecutor.success')
     def test_failed_sync(self, success_mock, fail_mock):
         """Test failure states"""
-        after_sync_reponse = self.__mock_sync()
+        after_sync_response = self.__mock_sync()
 
         # set container's status code to failure & sanity-check
-        after_sync_reponse['status'] = 'FAILED'
-        self.assertEqual(State.FAILED, BatchJobDetailSchema().load(after_sync_reponse).get_job_state())
+        after_sync_response['status'] = 'FAILED'
+        self.assertEqual(State.FAILED, BatchJobDetailSchema().load(after_sync_response).get_job_state())
 
         self.executor.sync()
         # ensure that run_task is called correctly as defined by Botocore docs
@@ -158,6 +162,47 @@ def test_terminate(self):
         self.assertTrue(self.executor.batch.terminate_job.called)
         self.assert_botocore_call('TerminateJob', *self.executor.batch.terminate_job.call_args)
 
+    def test_terminate_with_task_adoption(self):
+        """Test that executor does not shut down active Batch jobs when 'adopt_task_instances' is set to True"""
+        self.executor.adopt_task_instances = True
+        self.executor.terminate()
+
+        # jobs are not terminated
+        self.assertFalse(self.executor.batch.terminate_job.called)
+
+    def test_try_adopt_task_instances(self):
+        """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event"""
+        self.executor.adopt_task_instances = True
+
+        orphaned_tasks = [
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+        ]
+        orphaned_tasks[0].external_executor_id = None  # One orphaned task has no external_executor_id
+        not_adopted_tasks = self.executor.try_adopt_task_instances(orphaned_tasks)
+
+        # adopted tasks are stored in active workers
+        self.assertEqual(len(orphaned_tasks) - 1, len(self.executor.active_workers))
+
+        # one task is unable to be adopted
+        self.assertEqual(1, len(not_adopted_tasks))
+
+    def test_try_adopt_task_instances_disabled(self):
+        """Test that executor won't adopt orphaned task instances if 'adopt_task_instances' is set to False (default)"""
+        orphaned_tasks = [
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+        ]
+        not_adopted_tasks = self.executor.try_adopt_task_instances(orphaned_tasks)
+
+        # no orphaned tasks are stored in active workers
+        self.assertEqual(0, len(self.executor.active_workers))
+
+        # all tasks are unable to be adopted
+        self.assertEqual(len(orphaned_tasks), len(not_adopted_tasks))
+
     def test_end(self):
         """The end() function should call sync 3 times, and the task should fail on the 3rd call"""
         sync_call_count = 0
diff --git a/tests/test_ecs_fargate_executor.py b/tests/test_ecs_fargate_executor.py
index e79d5f8..97bb4e4 100644
--- a/tests/test_ecs_fargate_executor.py
+++ b/tests/test_ecs_fargate_executor.py
@@ -4,6 +4,7 @@
 from airflow_aws_executors.ecs_fargate_executor import (
     AwsEcsFargateExecutor, BotoTaskSchema, EcsFargateTask, EcsFargateTaskCollection
 )
+from airflow.models import TaskInstance
 from airflow.utils.state import State
 
 from .botocore_helper import get_botocore_model, assert_botocore_call
@@ -184,6 +185,9 @@ def test_execute(self):
         self.assertEqual(1, len(self.executor.active_workers))
         self.assertIn(self.executor.active_workers.task_by_key(airflow_key).task_arn, '001')
 
+        # task_arn is stored in executor event buffer
+        self.assertEqual((State.QUEUED, '001'), self.executor.event_buffer[airflow_key])
+
     def test_failed_execute_api(self):
         """Test what happens when FARGATE refuses to execute a task"""
         self.executor.ecs.run_task.return_value = {
@@ -215,7 +219,7 @@ def test_sync(self, success_mock, fail_mock):
 
         self.executor.sync_running_tasks()
 
-        # ensure that run_task is called correctly as defined by Botocore docs
+        # ensure that describe_tasks is called correctly as defined by Botocore docs
         self.executor.ecs.describe_tasks.assert_called_once()
         self.assert_botocore_call('DescribeTasks', *self.executor.ecs.describe_tasks.call_args)
@@ -236,7 +240,7 @@ def test_failed_sync(self, success_mock, fail_mock):
         self.assertEqual(State.FAILED, BotoTaskSchema().load(after_fargate_json).get_task_state())
         self.executor.sync()
 
-        # ensure that run_task is called correctly as defined by Botocore docs
+        # ensure that describe_tasks is called correctly as defined by Botocore docs
         self.executor.ecs.describe_tasks.assert_called_once()
         self.assert_botocore_call('DescribeTasks', *self.executor.ecs.describe_tasks.call_args)
@@ -263,7 +267,7 @@ def test_failed_sync_api(self, success_mock, fail_mock):
         # Call Sync 3 times with failures
         for check_count in range(AwsEcsFargateExecutor.MAX_FAILURE_CHECKS):
             self.executor.sync_running_tasks()
-            # ensure that run_task is called correctly as defined by Botocore docs
+            # ensure that describe_tasks is called correctly as defined by Botocore docs
             self.assertEqual(self.executor.ecs.describe_tasks.call_count, check_count + 1)
             self.assert_botocore_call('DescribeTasks', *self.executor.ecs.describe_tasks.call_args)
@@ -291,6 +295,74 @@ def test_terminate(self):
         self.assertTrue(self.executor.ecs.stop_task.called)
         self.assert_botocore_call('StopTask', *self.executor.ecs.stop_task.call_args)
 
+    def test_terminate_with_task_adoption(self):
+        """Test that executor does not shut down active ECS tasks when 'adopt_task_instances' is set to True"""
+        self.executor.adopt_task_instances = True
+        self.executor.terminate()
+
+        # tasks are not terminated
+        self.assertFalse(self.executor.ecs.stop_task.called)
+
+    def test_try_adopt_task_instances(self):
+        """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event"""
+        self.executor.adopt_task_instances = True
+
+        self.executor.ecs.describe_tasks.return_value = {
+            'tasks': [
+                {
+                    'taskArn': '001',
+                    'lastStatus': 'RUNNING',
+                    'desiredStatus': 'RUNNING',
+                    'containers': [{'name': 'some-ecs-container'}]
+                },
+                {
+                    'taskArn': '002',
+                    'lastStatus': 'RUNNING',
+                    'desiredStatus': 'RUNNING',
+                    'containers': [{'name': 'another-ecs-container'}]
+                }
+            ],
+            'failures': []
+        }
+
+        orphaned_tasks = [
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+        ]
+        orphaned_tasks[0].external_executor_id = '001'  # Matches a running task_arn
+        orphaned_tasks[1].external_executor_id = '002'  # Matches a running task_arn
+        orphaned_tasks[2].external_executor_id = None  # One orphaned task has no external_executor_id
+        not_adopted_tasks = self.executor.try_adopt_task_instances(orphaned_tasks)
+
+        # ensure that describe_tasks is called correctly as defined by Botocore docs
+        self.executor.ecs.describe_tasks.assert_called_once()
+        self.assert_botocore_call('DescribeTasks', *self.executor.ecs.describe_tasks.call_args)
+
+        # adopted tasks are stored in active workers
+        self.assertEqual(len(orphaned_tasks) - 1, len(self.executor.active_workers))
+
+        # one task is unable to be adopted
+        self.assertEqual(1, len(not_adopted_tasks))
+
+    def test_try_adopt_task_instances_disabled(self):
+        """Test that executor won't adopt orphaned task instances if 'adopt_task_instances' is set to False (default)"""
+        orphaned_tasks = [
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+            mock.Mock(TaskInstance),
+        ]
+        not_adopted_tasks = self.executor.try_adopt_task_instances(orphaned_tasks)
+
+        # ensure that describe_tasks is not called
+        self.executor.ecs.describe_tasks.assert_not_called()
+
+        # no orphaned tasks are stored in active workers
+        self.assertEqual(0, len(self.executor.active_workers))
+
+        # all tasks are unable to be adopted
+        self.assertEqual(len(orphaned_tasks), len(not_adopted_tasks))
+
     def assert_botocore_call(self, method_name, args, kwargs):
         assert_botocore_call(self.ecs_model, method_name, args, kwargs)
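A note on the test setup in this patch: `mock.Mock(TaskInstance)` creates a mock spec'd against `TaskInstance`, and attribute access on a spec'd mock auto-creates a truthy child mock. That is why only the instances explicitly assigned `external_executor_id = None` count as un-adoptable, while the untouched mocks appear to carry a real id. A quick sketch of the behaviour, using an illustrative stand-in class:

```python
# Demonstrates why the tests assign external_executor_id = None explicitly:
# unset attributes on a spec'd Mock read back as truthy child Mocks.
from unittest import mock


class TaskInstanceStandIn:
    """Stand-in; the real tests spec against airflow.models.TaskInstance."""
    external_executor_id = None


ti = mock.Mock(TaskInstanceStandIn)
assert ti.external_executor_id is not None  # auto-created child Mock, truthy

ti.external_executor_id = None
assert ti.external_executor_id is None      # now reads as a true orphan
```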
From 42f2e0bc853ed0f2fae7ac9da91ad2211cab541a Mon Sep 17 00:00:00 2001
From: Matt Ellis
Date: Wed, 12 May 2021 13:47:37 +1000
Subject: [PATCH 4/4] Add 'adopt_task_instances' configuration option to
 readme.md

---
 readme.md | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/readme.md b/readme.md
index e6002fd..041ff61 100644
--- a/readme.md
+++ b/readme.md
@@ -146,6 +146,12 @@ task = PythonOperator(
     To change the parameters used to run a task in Batch, the user can overwrite the path to specify another python dictionary. More documentation can be found in the `Extensibility` section below.
     * **default**: airflow_aws_executors.conf.BATCH_SUBMIT_JOB_KWARGS
+* `adopt_task_instances`
+    * **description**: Boolean flag. If set to True, the executor will try to adopt orphaned task instances from a
+      SchedulerJob shutdown event (for example when a scheduler container is re-deployed or terminated).
+      If set to False (default), the executor will terminate all active AWS Batch Jobs when the scheduler shuts down.
+      More documentation can be found in the [airflow docs](https://airflow.apache.org/docs/apache-airflow/stable/scheduler.html#scheduler-tuneables).
+    * **default**: False
 
 #### ECS & FARGATE
 `[ecs_fargate]`
 * `region`
@@ -181,6 +187,12 @@ task = PythonOperator(
     To change the parameters used to run a task in FARGATE or ECS, the user can overwrite the path to specify another python dictionary. More documentation can be found in the `Extensibility` section below.
     * **default**: airflow_aws_executors.conf.ECS_FARGATE_RUN_TASK_KWARGS
+* `adopt_task_instances`
+    * **description**: Boolean flag. If set to True, the executor will try to adopt orphaned task instances from a
+      SchedulerJob shutdown event (for example when a scheduler container is re-deployed or terminated).
+      If set to False (default), the executor will terminate all active ECS Tasks when the scheduler shuts down.
+      More documentation can be found in the [airflow docs](https://airflow.apache.org/docs/apache-airflow/stable/scheduler.html#scheduler-tuneables).
+    * **default**: False
 
 *NOTE: Modify airflow.cfg or export environmental variables. For example:*
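For completeness, enabling the new option would look something like the snippet below. This is an illustrative sketch rather than part of the patch; Airflow's standard mapping of config keys to environment variables applies, so `AIRFLOW__BATCH__ADOPT_TASK_INSTANCES=True` is equivalent to the `[batch]` entry:

```ini
# airflow.cfg (illustrative values; set the section matching the executor in use)
[batch]
adopt_task_instances = True

[ecs_fargate]
adopt_task_instances = True
```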