From 1f4593a33604a1ecc3bbe0ca21e1301d0881f92d Mon Sep 17 00:00:00 2001 From: jason810496 Date: Tue, 14 Jan 2025 09:51:57 +0800 Subject: [PATCH 01/11] Fix FileTaskHandler only read from default executor --- airflow/utils/log/file_task_handler.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 73ee79126a9ee..62647eedb25d0 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -24,7 +24,6 @@ from collections.abc import Iterable from contextlib import suppress from enum import Enum -from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from urllib.parse import urljoin @@ -44,6 +43,7 @@ if TYPE_CHECKING: from pendulum import DateTime + from airflow.executors.base_executor import BaseExecutor from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -314,10 +314,18 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO def _read_grouped_logs(self): return False - @cached_property - def _executor_get_task_log(self) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]: - """This cached property avoids loading executor repeatedly.""" - executor = ExecutorLoader.get_default_executor() + def _get_executor_get_task_log( + self, ti: TaskInstance + ) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]: + """ + Get the get_task_log method from executor of current task instance. + + Since there might be multiple executors, so we need to get the executor of current task instance instead of getting from default executor. 
+ + :param ti: task instance object + :return: get_task_log method of the executor + """ + executor: BaseExecutor = ExecutorLoader.load_executor(ti.executor) return executor.get_task_log def _read( @@ -360,7 +368,8 @@ def _read( messages_list.extend(remote_messages) has_k8s_exec_pod = False if ti.state == TaskInstanceState.RUNNING: - response = self._executor_get_task_log(ti, try_number) + executor_get_task_log = self._get_executor_get_task_log(ti) + response = executor_get_task_log(ti, try_number) if response: executor_messages, executor_logs = response if executor_messages: From cc5fda22c1612fa295115f3808a9bf70233466eb Mon Sep 17 00:00:00 2001 From: jason810496 Date: Wed, 15 Jan 2025 11:43:25 +0800 Subject: [PATCH 02/11] Add cached_property back to avoid loading executors --- airflow/utils/log/file_task_handler.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 62647eedb25d0..93f8e32c70da2 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -24,6 +24,7 @@ from collections.abc import Iterable from contextlib import suppress from enum import Enum +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from urllib.parse import urljoin @@ -314,6 +315,16 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO def _read_grouped_logs(self): return False + @cached_property + def _available_executors(self) -> dict[str, BaseExecutor]: + """This cached property avoids loading executors repeatedly.""" + return {ex.__class__.__name__: ex for ex in ExecutorLoader.init_executors()} + + @cached_property + def _default_executor(self) -> BaseExecutor: + """This cached property avoids loading executors repeatedly.""" + return next(iter(self._available_executors.values())) + def _get_executor_get_task_log( self, ti: TaskInstance ) 
-> Callable[[TaskInstance, int], tuple[list[str], list[str]]]: @@ -325,7 +336,13 @@ def _get_executor_get_task_log( :param ti: task instance object :return: get_task_log method of the executor """ - executor: BaseExecutor = ExecutorLoader.load_executor(ti.executor) + executor_name = ti.executor + if executor_name is None: + executor = self._default_executor + elif executor_name in self._available_executors: + executor = self._available_executors[executor_name] + else: + raise AirflowException(f"Executor {executor_name} not found for task {ti}") return executor.get_task_log def _read( From 0a5b6e42cee90535eca824daa0bd78678c73d601 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Wed, 15 Jan 2025 11:45:27 +0800 Subject: [PATCH 03/11] Add test for multi-executors scenario --- tests/utils/test_log_handlers.py | 62 +++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 454af48d66763..ddc112a8cda42 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -33,7 +33,7 @@ from requests.adapters import Response from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.executors import executor_loader +from airflow.executors import executor_constants, executor_loader from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models.dagrun import DagRun @@ -187,6 +187,66 @@ def task_callable(ti): # Remove the generated tmp log file. 
os.remove(log_filename) + @pytest.mark.parametrize( + "executor_name", + [ + (executor_constants.LOCAL_KUBERNETES_EXECUTOR), + (executor_constants.CELERY_KUBERNETES_EXECUTOR), + (executor_constants.KUBERNETES_EXECUTOR), + ], + ) + @conf_vars( + { + ("core", "EXECUTOR"): ",".join( + [ + executor_constants.LOCAL_KUBERNETES_EXECUTOR, + executor_constants.CELERY_KUBERNETES_EXECUTOR, + executor_constants.KUBERNETES_EXECUTOR, + ] + ), + } + ) + def test_file_task_handler_with_multiple_executors(self, executor_name, create_task_instance): + reload(executor_loader) + executors_mapping = executor_loader.ExecutorLoader.executors + path_to_executor_class = executors_mapping[executor_name] + + with patch(f"{path_to_executor_class}.get_task_log") as mock_get_task_log: + mock_get_task_log.return_value = ([], []) + ti = create_task_instance( + dag_id="dag_for_testing_multiple_executors", + task_id="task_for_testing_multiple_executors", + run_type=DagRunType.SCHEDULED, + logical_date=DEFAULT_DATE, + ) + ti.executor = executor_name + ti.state = TaskInstanceState.RUNNING + ti.try_number = 1 + logger = ti.log + ti.log.disabled = False + + file_handler = next( + (handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None + ) + assert file_handler is not None + + set_context(logger, ti) + assert file_handler.handler is not None + # We expect set_context generates a file locally. 
+ log_filename = file_handler.handler.baseFilename + assert os.path.isfile(log_filename) + assert log_filename.endswith("1.log"), log_filename + + ti.run(ignore_ti_state=True) + + file_handler.flush() + file_handler.close() + + assert hasattr(file_handler, "read") + file_handler.read(ti) + os.remove(log_filename) + mock_get_task_log.assert_called_once() + def test_file_task_handler_running(self, dag_maker): def task_callable(ti): ti.log.info("test") From 04d79a5fab6c9102cd4b4c24c6ae6105f6e641c0 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 16 Jan 2025 05:55:25 +0800 Subject: [PATCH 04/11] Allow to call load_executor without init_executors --- airflow/executors/executor_loader.py | 4 ++++ tests/executors/test_executor_loader.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index 2651718bbad23..6d6b8d115bcc1 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -231,6 +231,10 @@ def init_executors(cls) -> list[BaseExecutor]: @classmethod def lookup_executor_name_by_str(cls, executor_name_str: str) -> ExecutorName: # lookup the executor by alias first, if not check if we're given a module path + if not _classname_to_executors or not _module_to_executors or not _alias_to_executors: + # if we haven't loaded the executors yet, such as directly calling load_executor + cls._get_executor_names() + if executor_name := _alias_to_executors.get(executor_name_str): return executor_name elif executor_name := _module_to_executors.get(executor_name_str): diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py index 87455bd841b3d..44e8b6629c737 100644 --- a/tests/executors/test_executor_loader.py +++ b/tests/executors/test_executor_loader.py @@ -343,6 +343,15 @@ def test_load_executor_alias(self): ) assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor) + 
@mock.patch( + "airflow.executors.executor_loader.ExecutorLoader._get_executor_names", + wraps=ExecutorLoader._get_executor_names, + ) + def test_call_load_executor_method_without_init_executors(self, mock_get_executor_names): + with conf_vars({("core", "executor"): "LocalExecutor"}): + ExecutorLoader.load_executor("LocalExecutor") + mock_get_executor_names.assert_called_once() + @mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor", autospec=True) def test_load_custom_executor_with_classname(self, mock_executor): with conf_vars( From 3ebd94008e7b23c9e77095f4cc0dbe3a8c7a7f48 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 16 Jan 2025 06:15:24 +0800 Subject: [PATCH 05/11] Refactor by caching necessary executors --- airflow/utils/log/file_task_handler.py | 29 ++++++++++---------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 93f8e32c70da2..a7d54ee522330 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -24,7 +24,6 @@ from collections.abc import Iterable from contextlib import suppress from enum import Enum -from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from urllib.parse import urljoin @@ -180,6 +179,8 @@ class FileTaskHandler(logging.Handler): inherits_from_empty_operator_log_message = ( "Operator inherits from empty operator and thus does not have logs" ) + executor_instances: dict[str, BaseExecutor] = {} + default_executor_key = "_default_executor" def __init__( self, @@ -315,16 +316,6 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO def _read_grouped_logs(self): return False - @cached_property - def _available_executors(self) -> dict[str, BaseExecutor]: - """This cached property avoids loading executors repeatedly.""" - return {ex.__class__.__name__: ex for ex in 
ExecutorLoader.init_executors()} - - @cached_property - def _default_executor(self) -> BaseExecutor: - """This cached property avoids loading executors repeatedly.""" - return next(iter(self._available_executors.values())) - def _get_executor_get_task_log( self, ti: TaskInstance ) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]: @@ -336,14 +327,16 @@ def _get_executor_get_task_log( :param ti: task instance object :return: get_task_log method of the executor """ - executor_name = ti.executor - if executor_name is None: - executor = self._default_executor - elif executor_name in self._available_executors: - executor = self._available_executors[executor_name] + executor_name = ti.executor or self.default_executor_key + executor = self.executor_instances.get(executor_name) + if executor is not None: + return executor.get_task_log + + if executor_name == self.default_executor_key: + self.executor_instances[executor_name] = ExecutorLoader.get_default_executor() else: - raise AirflowException(f"Executor {executor_name} not found for task {ti}") - return executor.get_task_log + self.executor_instances[executor_name] = ExecutorLoader.load_executor(executor_name) + return self.executor_instances[executor_name].get_task_log def _read( self, From c40e68defc11fb2c4b6dae0ab607f624561ab166 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 16 Jan 2025 06:16:26 +0800 Subject: [PATCH 06/11] Refactor test with default executor case --- tests/utils/test_log_handlers.py | 34 +++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index ddc112a8cda42..6905e833e728d 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -193,6 +193,7 @@ def task_callable(ti): (executor_constants.LOCAL_KUBERNETES_EXECUTOR), (executor_constants.CELERY_KUBERNETES_EXECUTOR), (executor_constants.KUBERNETES_EXECUTOR), + (None), ], ) @conf_vars( @@ -206,12 
+207,26 @@ def task_callable(ti): ), } ) - def test_file_task_handler_with_multiple_executors(self, executor_name, create_task_instance): - reload(executor_loader) + @patch( + "airflow.executors.executor_loader.ExecutorLoader.load_executor", + wraps=executor_loader.ExecutorLoader.load_executor, + ) + @patch( + "airflow.executors.executor_loader.ExecutorLoader.get_default_executor", + wraps=executor_loader.ExecutorLoader.get_default_executor, + ) + def test_file_task_handler_with_multiple_executors( + self, mock_get_default_executor, mock_load_executor, executor_name, create_task_instance + ): executors_mapping = executor_loader.ExecutorLoader.executors - path_to_executor_class = executors_mapping[executor_name] + default_executor_name = executor_loader.ExecutorLoader.get_default_executor_name() + path_to_executor_class: str + if executor_name is None: + path_to_executor_class = executors_mapping.get(default_executor_name.alias) + else: + path_to_executor_class = executors_mapping.get(executor_name) - with patch(f"{path_to_executor_class}.get_task_log") as mock_get_task_log: + with patch(f"{path_to_executor_class}.get_task_log", return_value=([], [])) as mock_get_task_log: mock_get_task_log.return_value = ([], []) ti = create_task_instance( dag_id="dag_for_testing_multiple_executors", @@ -219,7 +234,8 @@ def test_file_task_handler_with_multiple_executors(self, executor_name, create_t run_type=DagRunType.SCHEDULED, logical_date=DEFAULT_DATE, ) - ti.executor = executor_name + if executor_name is not None: + ti.executor = executor_name ti.state = TaskInstanceState.RUNNING ti.try_number = 1 logger = ti.log @@ -247,6 +263,14 @@ def test_file_task_handler_with_multiple_executors(self, executor_name, create_t os.remove(log_filename) mock_get_task_log.assert_called_once() + if executor_name is None: + mock_get_default_executor.assert_called_once() + # will be called in `ExecutorLoader.get_default_executor` method + 
mock_load_executor.assert_called_once_with(default_executor_name) + else: + mock_get_default_executor.assert_not_called() + mock_load_executor.assert_called_once_with(executor_name) + def test_file_task_handler_running(self, dag_maker): def task_callable(ti): ti.log.info("test") From b1d9652edde0e97e7c77f25424e06613575669ce Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 18 Jan 2025 22:12:41 +0800 Subject: [PATCH 07/11] Fix side effect from executor_loader --- .../deps/test_ready_to_reschedule_dep.py | 3 ++ tests/utils/test_log_handlers.py | 2 ++ tests_common/test_utils/executor_loader.py | 34 +++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 tests_common/test_utils/executor_loader.py diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index d982cf4b27107..07703c2ed309b 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -32,6 +32,7 @@ from airflow.utils.state import State from tests_common.test_utils import db +from tests_common.test_utils.executor_loader import clean_executor_loader pytestmark = pytest.mark.db_test @@ -54,6 +55,7 @@ class TestNotInReschedulePeriodDep: def setup_test_cases(self, request, create_task_instance): db.clear_db_runs() db.clear_rendered_ti_fields() + clean_executor_loader() self.dag_id = f"dag_{slugify(request.cls.__name__)}" self.task_id = f"task_{slugify(request.node.name, max_length=40)}" @@ -64,6 +66,7 @@ def setup_test_cases(self, request, create_task_instance): yield db.clear_rendered_ti_fields() db.clear_db_runs() + clean_executor_loader() def _get_task_instance(self, state, *, map_index=-1): """Helper which create fake task_instance""" diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 6905e833e728d..68e01f0edecda 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -55,6 +55,7 @@ from 
airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.executor_loader import clean_executor_loader pytestmark = pytest.mark.db_test @@ -218,6 +219,7 @@ def task_callable(ti): def test_file_task_handler_with_multiple_executors( self, mock_get_default_executor, mock_load_executor, executor_name, create_task_instance ): + clean_executor_loader() executors_mapping = executor_loader.ExecutorLoader.executors default_executor_name = executor_loader.ExecutorLoader.get_default_executor_name() path_to_executor_class: str diff --git a/tests_common/test_utils/executor_loader.py b/tests_common/test_utils/executor_loader.py new file mode 100644 index 0000000000000..0c00f3499631b --- /dev/null +++ b/tests_common/test_utils/executor_loader.py @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +import airflow.executors.executor_loader as executor_loader + +if TYPE_CHECKING: + from airflow.executors.executor_utils import ExecutorName + + +def clean_executor_loader(): + """Clean the executor loader state, as it stores global variables in the module, causing side effects for some tests.""" + executor_loader._alias_to_executors: dict[str, ExecutorName] = {} + executor_loader._module_to_executors: dict[str, ExecutorName] = {} + executor_loader._team_id_to_executors: dict[str | None, ExecutorName] = {} + executor_loader._classname_to_executors: dict[str, ExecutorName] = {} + executor_loader._executor_names: list[ExecutorName] = [] From e37dcdd5a1fadfd30749b0717d3343fc0bb87a01 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Wed, 22 Jan 2025 01:18:49 +0800 Subject: [PATCH 08/11] Fix KubernetesExecutor test - Previous test failure is caused by cache state of executor_instances - Should set ti.state = RUNNING after ti.run --- .../cncf/kubernetes/log_handlers/test_log_handlers.py | 9 ++++++++- tests/utils/test_log_handlers.py | 6 +++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py index 9cbebcf8df9ec..f408fd8817602 100644 --- a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -43,6 +43,7 @@ from tests_common.test_utils.compat import PythonOperator from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.executor_loader import clean_executor_loader from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: @@ -76,6 +77,7 @@ def teardown_method(self): @pytest.mark.parametrize("state", [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]) def test__read_for_k8s_executor(self, 
mock_k8s_get_task_log, create_task_instance, state): """Test for k8s executor, the log is read from get_task_log method""" + clean_executor_loader() mock_k8s_get_task_log.return_value = ([], []) executor_name = "KubernetesExecutor" ti = create_task_instance( @@ -86,6 +88,7 @@ def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instanc ) ti.state = state ti.triggerer_job = None + ti.executor = executor_name with conf_vars({("core", "executor"): executor_name}): reload(executor_loader) fth = FileTaskHandler("") @@ -105,11 +108,12 @@ def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instanc pytest.param(k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="pod-name-xxx")), "default"), ], ) - @patch.dict("os.environ", AIRFLOW__CORE__EXECUTOR="KubernetesExecutor") + @conf_vars({("core", "executor"): "KubernetesExecutor"}) @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client") def test_read_from_k8s_under_multi_namespace_mode( self, mock_kube_client, pod_override, namespace_to_call ): + reload(executor_loader) mock_read_log = mock_kube_client.return_value.read_namespaced_pod_log mock_list_pod = mock_kube_client.return_value.list_namespaced_pod @@ -139,6 +143,7 @@ def task_callable(ti): ) ti = TaskInstance(task=task, run_id=dagrun.run_id) ti.try_number = 3 + ti.executor = "KubernetesExecutor" logger = ti.log ti.log.disabled = False @@ -147,6 +152,8 @@ def task_callable(ti): set_context(logger, ti) ti.run(ignore_ti_state=True) ti.state = TaskInstanceState.RUNNING + # clear executor_instances cache + file_handler.executor_instances = {} file_handler.read(ti, 2) # first we find pod name diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 68e01f0edecda..acbce6aaca965 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -238,8 +238,8 @@ def test_file_task_handler_with_multiple_executors( ) if executor_name is not None: ti.executor = executor_name - ti.state = 
TaskInstanceState.RUNNING ti.try_number = 1 + ti.state = TaskInstanceState.RUNNING logger = ti.log ti.log.disabled = False @@ -249,14 +249,14 @@ def test_file_task_handler_with_multiple_executors( assert file_handler is not None set_context(logger, ti) + # clear executor_instances cache + file_handler.executor_instances = {} assert file_handler.handler is not None # We expect set_context generates a file locally. log_filename = file_handler.handler.baseFilename assert os.path.isfile(log_filename) assert log_filename.endswith("1.log"), log_filename - ti.run(ignore_ti_state=True) - file_handler.flush() file_handler.close() From d65e1a05f3792de304123ff56354869180d967e0 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Wed, 22 Jan 2025 21:47:38 +0800 Subject: [PATCH 09/11] Fix side effect from executor_loader - The side effect only show up in postgres as backend environment, as previous fix only resolve side effect in sqlite as backend environment. - Also refactor clean_executor_loader as pytest fixture with setup teardown --- .../log_handlers/test_log_handlers.py | 6 +- tests/executors/test_executor_loader.py | 81 +++++++++---------- .../deps/test_ready_to_reschedule_dep.py | 5 +- tests/utils/test_log_handlers.py | 9 ++- tests_common/pytest_plugin.py | 10 +++ tests_common/test_utils/executor_loader.py | 4 +- 6 files changed, 62 insertions(+), 53 deletions(-) diff --git a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py index f408fd8817602..ac016cfd37691 100644 --- a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -43,7 +43,6 @@ from tests_common.test_utils.compat import PythonOperator from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.executor_loader import clean_executor_loader from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if 
AIRFLOW_V_3_0_PLUS: @@ -75,9 +74,10 @@ def teardown_method(self): "airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor.get_task_log" ) @pytest.mark.parametrize("state", [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]) - def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instance, state): + def test__read_for_k8s_executor( + self, mock_k8s_get_task_log, create_task_instance, state, clean_executor_loader + ): """Test for k8s executor, the log is read from get_task_log method""" - clean_executor_loader() mock_k8s_get_task_log.return_value = ([], []) executor_name = "KubernetesExecutor" ti = create_task_instance( diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py index 44e8b6629c737..5d00ffca82d87 100644 --- a/tests/executors/test_executor_loader.py +++ b/tests/executors/test_executor_loader.py @@ -16,14 +16,13 @@ # under the License. from __future__ import annotations -from importlib import reload from unittest import mock import pytest from airflow.exceptions import AirflowConfigException from airflow.executors import executor_loader -from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName +from airflow.executors.executor_loader import ConnectorSource, ExecutorName from airflow.executors.local_executor import LocalExecutor from airflow.providers.amazon.aws.executors.ecs.ecs_executor import AwsEcsExecutor from airflow.providers.celery.executors.celery_executor import CeleryExecutor @@ -36,23 +35,17 @@ class FakeExecutor: class TestExecutorLoader: - def setup_method(self) -> None: - from airflow.executors import executor_loader + @pytest.fixture(autouse=True) + def setup_method(self, clean_executor_loader) -> None: + self.executor_loader_class = executor_loader.ExecutorLoader # type: ignore - reload(executor_loader) - global ExecutorLoader - ExecutorLoader = executor_loader.ExecutorLoader # type: ignore - - def 
teardown_method(self) -> None: - from airflow.executors import executor_loader - - reload(executor_loader) - ExecutorLoader.init_executors() + def teardown_method(self, clean_executor_loader) -> None: + self.executor_loader_class.init_executors() def test_no_executor_configured(self): with conf_vars({("core", "executor"): None}): with pytest.raises(AirflowConfigException, match=r".*not found in config$"): - ExecutorLoader.get_default_executor() + self.executor_loader_class.get_default_executor() @pytest.mark.parametrize( "executor_name", @@ -66,16 +59,18 @@ def test_no_executor_configured(self): ) def test_should_support_executor_from_core(self, executor_name): with conf_vars({("core", "executor"): executor_name}): - executor = ExecutorLoader.get_default_executor() + executor = self.executor_loader_class.get_default_executor() assert executor is not None assert executor_name == executor.__class__.__name__ assert executor.name is not None - assert executor.name == ExecutorName(ExecutorLoader.executors[executor_name], alias=executor_name) + assert executor.name == ExecutorName( + self.executor_loader_class.executors[executor_name], alias=executor_name + ) assert executor.name.connector_source == ConnectorSource.CORE def test_should_support_custom_path(self): with conf_vars({("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"}): - executor = ExecutorLoader.get_default_executor() + executor = self.executor_loader_class.get_default_executor() assert executor is not None assert executor.__class__.__name__ == "FakeExecutor" assert executor.name is not None @@ -249,17 +244,17 @@ def test_get_hybrid_executors_from_config( "airflow.executors.executor_loader.ExecutorLoader._get_team_executor_configs", return_value=team_executor_config, ): - executors = ExecutorLoader._get_executor_names() + executors = self.executor_loader_class._get_executor_names() assert executors == expected_executors_list def test_init_executors(self): with conf_vars({("core", 
"executor"): "CeleryExecutor"}): - executors = ExecutorLoader.init_executors() - executor_name = ExecutorLoader.get_default_executor_name() + executors = self.executor_loader_class.init_executors() + executor_name = self.executor_loader_class.get_default_executor_name() assert len(executors) == 1 assert isinstance(executors[0], CeleryExecutor) - assert "CeleryExecutor" in ExecutorLoader.executors - assert ExecutorLoader.executors["CeleryExecutor"] == executor_name.module_path + assert "CeleryExecutor" in self.executor_loader_class.executors + assert self.executor_loader_class.executors["CeleryExecutor"] == executor_name.module_path @pytest.mark.parametrize( "executor_config", @@ -276,7 +271,7 @@ def test_get_hybrid_executors_from_config_duplicates_should_fail(self, executor_ with pytest.raises( AirflowConfigException, match=r".+Duplicate executors are not yet supported.+" ): - ExecutorLoader._get_executor_names() + self.executor_loader_class._get_executor_names() @pytest.mark.parametrize( "executor_config", @@ -292,7 +287,7 @@ def test_get_hybrid_executors_from_config_duplicates_should_fail(self, executor_ def test_get_hybrid_executors_from_config_core_executors_bad_config_format(self, executor_config): with conf_vars({("core", "executor"): executor_config}): with pytest.raises(AirflowConfigException): - ExecutorLoader._get_executor_names() + self.executor_loader_class._get_executor_names() @pytest.mark.parametrize( ("executor_config", "expected_value"), @@ -308,7 +303,7 @@ def test_get_hybrid_executors_from_config_core_executors_bad_config_format(self, ) def test_should_support_import_executor_from_core(self, executor_config, expected_value): with conf_vars({("core", "executor"): executor_config}): - executor, import_source = ExecutorLoader.import_default_executor_cls() + executor, import_source = self.executor_loader_class.import_default_executor_cls() assert expected_value == executor.__name__ assert import_source == ConnectorSource.CORE @@ -322,34 +317,38 @@ 
def test_should_support_import_executor_from_core(self, executor_config, expecte ) def test_should_support_import_custom_path(self, executor_config): with conf_vars({("core", "executor"): executor_config}): - executor, import_source = ExecutorLoader.import_default_executor_cls() + executor, import_source = self.executor_loader_class.import_default_executor_cls() assert executor.__name__ == "FakeExecutor" assert import_source == ConnectorSource.CUSTOM_PATH def test_load_executor(self): with conf_vars({("core", "executor"): "LocalExecutor"}): - ExecutorLoader.init_executors() - assert isinstance(ExecutorLoader.load_executor("LocalExecutor"), LocalExecutor) - assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor) - assert isinstance(ExecutorLoader.load_executor(None), LocalExecutor) + self.executor_loader_class.init_executors() + assert isinstance(self.executor_loader_class.load_executor("LocalExecutor"), LocalExecutor) + assert isinstance( + self.executor_loader_class.load_executor(executor_loader._executor_names[0]), LocalExecutor + ) + assert isinstance(self.executor_loader_class.load_executor(None), LocalExecutor) def test_load_executor_alias(self): with conf_vars({("core", "executor"): "local_exec:airflow.executors.local_executor.LocalExecutor"}): - ExecutorLoader.init_executors() - assert isinstance(ExecutorLoader.load_executor("local_exec"), LocalExecutor) + self.executor_loader_class.init_executors() + assert isinstance(self.executor_loader_class.load_executor("local_exec"), LocalExecutor) assert isinstance( - ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"), + self.executor_loader_class.load_executor("airflow.executors.local_executor.LocalExecutor"), LocalExecutor, ) - assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor) + assert isinstance( + self.executor_loader_class.load_executor(executor_loader._executor_names[0]), LocalExecutor + ) 
@mock.patch( "airflow.executors.executor_loader.ExecutorLoader._get_executor_names", - wraps=ExecutorLoader._get_executor_names, + wraps=executor_loader.ExecutorLoader._get_executor_names, ) def test_call_load_executor_method_without_init_executors(self, mock_get_executor_names): with conf_vars({("core", "executor"): "LocalExecutor"}): - ExecutorLoader.load_executor("LocalExecutor") + self.executor_loader_class.load_executor("LocalExecutor") mock_get_executor_names.assert_called_once() @mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor", autospec=True) @@ -362,15 +361,15 @@ def test_load_custom_executor_with_classname(self, mock_executor): ): "my_alias:airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor" } ): - ExecutorLoader.init_executors() - assert isinstance(ExecutorLoader.load_executor("my_alias"), AwsEcsExecutor) - assert isinstance(ExecutorLoader.load_executor("AwsEcsExecutor"), AwsEcsExecutor) + self.executor_loader_class.init_executors() + assert isinstance(self.executor_loader_class.load_executor("my_alias"), AwsEcsExecutor) + assert isinstance(self.executor_loader_class.load_executor("AwsEcsExecutor"), AwsEcsExecutor) assert isinstance( - ExecutorLoader.load_executor( + self.executor_loader_class.load_executor( "airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor" ), AwsEcsExecutor, ) assert isinstance( - ExecutorLoader.load_executor(executor_loader._executor_names[0]), AwsEcsExecutor + self.executor_loader_class.load_executor(executor_loader._executor_names[0]), AwsEcsExecutor ) diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 07703c2ed309b..319d34fae6c04 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -32,7 +32,6 @@ from airflow.utils.state import State from tests_common.test_utils import db -from 
tests_common.test_utils.executor_loader import clean_executor_loader pytestmark = pytest.mark.db_test @@ -52,10 +51,9 @@ def side_effect(*args, **kwargs): class TestNotInReschedulePeriodDep: @pytest.fixture(autouse=True) - def setup_test_cases(self, request, create_task_instance): + def setup_test_cases(self, request, create_task_instance, clean_executor_loader): db.clear_db_runs() db.clear_rendered_ti_fields() - clean_executor_loader() self.dag_id = f"dag_{slugify(request.cls.__name__)}" self.task_id = f"task_{slugify(request.node.name, max_length=40)}" @@ -66,7 +64,6 @@ def setup_test_cases(self, request, create_task_instance): yield db.clear_rendered_ti_fields() db.clear_db_runs() - clean_executor_loader() def _get_task_instance(self, state, *, map_index=-1): """Helper which create fake task_instance""" diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index acbce6aaca965..fda432e01d1af 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -55,7 +55,6 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.executor_loader import clean_executor_loader pytestmark = pytest.mark.db_test @@ -217,9 +216,13 @@ def task_callable(ti): wraps=executor_loader.ExecutorLoader.get_default_executor, ) def test_file_task_handler_with_multiple_executors( - self, mock_get_default_executor, mock_load_executor, executor_name, create_task_instance + self, + mock_get_default_executor, + mock_load_executor, + executor_name, + create_task_instance, + clean_executor_loader, ): - clean_executor_loader() executors_mapping = executor_loader.ExecutorLoader.executors default_executor_name = executor_loader.ExecutorLoader.get_default_executor_name() path_to_executor_class: str diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 1b68f039eaa17..3b07835d275af 100644 --- a/tests_common/pytest_plugin.py +++ 
b/tests_common/pytest_plugin.py @@ -1567,6 +1567,16 @@ def clean_dags_and_dagruns(): clear_db_runs() +@pytest.fixture +def clean_executor_loader(): + from tests_common.test_utils.executor_loader import clean_executor_loader_module + + """Clean the executor_loader state, as it stores global variables in the module, causing side effects for some tests.""" + clean_executor_loader_module() + yield # Test runs here + clean_executor_loader_module() + + @pytest.fixture(scope="session") def app(): from tests_common.test_utils.config import conf_vars diff --git a/tests_common/test_utils/executor_loader.py b/tests_common/test_utils/executor_loader.py index 0c00f3499631b..f7dd98b726428 100644 --- a/tests_common/test_utils/executor_loader.py +++ b/tests_common/test_utils/executor_loader.py @@ -25,8 +25,8 @@ from airflow.executors.executor_utils import ExecutorName -def clean_executor_loader(): - """Clean the executor loader state, as it stores global variables in the module, causing side effects for some tests.""" +def clean_executor_loader_module(): + """Clean the executor_loader state, as it stores global variables in the module, causing side effects for some tests.""" executor_loader._alias_to_executors: dict[str, ExecutorName] = {} executor_loader._module_to_executors: dict[str, ExecutorName] = {} executor_loader._team_id_to_executors: dict[str | None, ExecutorName] = {} From ca36091f1b18e21339fd6a47c4ad3e49ec9474b8 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Wed, 22 Jan 2025 21:51:40 +0800 Subject: [PATCH 10/11] Capitalize default executor key --- airflow/utils/log/file_task_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index a7d54ee522330..21b745affbccd 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -180,7 +180,7 @@ class FileTaskHandler(logging.Handler): "Operator inherits from empty operator and 
thus does not have logs" ) executor_instances: dict[str, BaseExecutor] = {} - default_executor_key = "_default_executor" + DEFAULT_EXECUTOR_KEY = "_default_executor" def __init__( self, @@ -327,12 +327,12 @@ def _get_executor_get_task_log( :param ti: task instance object :return: get_task_log method of the executor """ - executor_name = ti.executor or self.default_executor_key + executor_name = ti.executor or self.DEFAULT_EXECUTOR_KEY executor = self.executor_instances.get(executor_name) if executor is not None: return executor.get_task_log - if executor_name == self.default_executor_key: + if executor_name == self.DEFAULT_EXECUTOR_KEY: self.executor_instances[executor_name] = ExecutorLoader.get_default_executor() else: self.executor_instances[executor_name] = ExecutorLoader.load_executor(executor_name) From 1bd1a300db82e3eb5357779b7de62f70f781e6b3 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 23 Jan 2025 14:53:40 +0800 Subject: [PATCH 11/11] Refactor clean_executor_loader fixture --- .../log_handlers/test_log_handlers.py | 5 +- tests/executors/test_executor_loader.py | 67 +++++++++---------- .../deps/test_ready_to_reschedule_dep.py | 3 +- tests_common/pytest_plugin.py | 5 +- 4 files changed, 41 insertions(+), 39 deletions(-) diff --git a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py index ac016cfd37691..d89fbdf6edb15 100644 --- a/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py +++ b/providers/tests/cncf/kubernetes/log_handlers/test_log_handlers.py @@ -74,9 +74,8 @@ def teardown_method(self): "airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor.get_task_log" ) @pytest.mark.parametrize("state", [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]) - def test__read_for_k8s_executor( - self, mock_k8s_get_task_log, create_task_instance, state, clean_executor_loader - ): + @pytest.mark.usefixtures("clean_executor_loader") + 
def test__read_for_k8s_executor(self, mock_k8s_get_task_log, create_task_instance, state): """Test for k8s executor, the log is read from get_task_log method""" mock_k8s_get_task_log.return_value = ([], []) executor_name = "KubernetesExecutor" diff --git a/tests/executors/test_executor_loader.py b/tests/executors/test_executor_loader.py index 5d00ffca82d87..de6703954b10d 100644 --- a/tests/executors/test_executor_loader.py +++ b/tests/executors/test_executor_loader.py @@ -34,18 +34,12 @@ class FakeExecutor: pass +@pytest.mark.usefixtures("clean_executor_loader") class TestExecutorLoader: - @pytest.fixture(autouse=True) - def setup_method(self, clean_executor_loader) -> None: - self.executor_loader_class = executor_loader.ExecutorLoader # type: ignore - - def teardown_method(self, clean_executor_loader) -> None: - self.executor_loader_class.init_executors() - def test_no_executor_configured(self): with conf_vars({("core", "executor"): None}): with pytest.raises(AirflowConfigException, match=r".*not found in config$"): - self.executor_loader_class.get_default_executor() + executor_loader.ExecutorLoader.get_default_executor() @pytest.mark.parametrize( "executor_name", @@ -59,18 +53,18 @@ def test_no_executor_configured(self): ) def test_should_support_executor_from_core(self, executor_name): with conf_vars({("core", "executor"): executor_name}): - executor = self.executor_loader_class.get_default_executor() + executor = executor_loader.ExecutorLoader.get_default_executor() assert executor is not None assert executor_name == executor.__class__.__name__ assert executor.name is not None assert executor.name == ExecutorName( - self.executor_loader_class.executors[executor_name], alias=executor_name + executor_loader.ExecutorLoader.executors[executor_name], alias=executor_name ) assert executor.name.connector_source == ConnectorSource.CORE def test_should_support_custom_path(self): with conf_vars({("core", "executor"): "tests.executors.test_executor_loader.FakeExecutor"}): 
- executor = self.executor_loader_class.get_default_executor() + executor = executor_loader.ExecutorLoader.get_default_executor() assert executor is not None assert executor.__class__.__name__ == "FakeExecutor" assert executor.name is not None @@ -244,17 +238,17 @@ def test_get_hybrid_executors_from_config( "airflow.executors.executor_loader.ExecutorLoader._get_team_executor_configs", return_value=team_executor_config, ): - executors = self.executor_loader_class._get_executor_names() + executors = executor_loader.ExecutorLoader._get_executor_names() assert executors == expected_executors_list def test_init_executors(self): with conf_vars({("core", "executor"): "CeleryExecutor"}): - executors = self.executor_loader_class.init_executors() - executor_name = self.executor_loader_class.get_default_executor_name() + executors = executor_loader.ExecutorLoader.init_executors() + executor_name = executor_loader.ExecutorLoader.get_default_executor_name() assert len(executors) == 1 assert isinstance(executors[0], CeleryExecutor) - assert "CeleryExecutor" in self.executor_loader_class.executors - assert self.executor_loader_class.executors["CeleryExecutor"] == executor_name.module_path + assert "CeleryExecutor" in executor_loader.ExecutorLoader.executors + assert executor_loader.ExecutorLoader.executors["CeleryExecutor"] == executor_name.module_path @pytest.mark.parametrize( "executor_config", @@ -271,7 +265,7 @@ def test_get_hybrid_executors_from_config_duplicates_should_fail(self, executor_ with pytest.raises( AirflowConfigException, match=r".+Duplicate executors are not yet supported.+" ): - self.executor_loader_class._get_executor_names() + executor_loader.ExecutorLoader._get_executor_names() @pytest.mark.parametrize( "executor_config", @@ -287,7 +281,7 @@ def test_get_hybrid_executors_from_config_duplicates_should_fail(self, executor_ def test_get_hybrid_executors_from_config_core_executors_bad_config_format(self, executor_config): with conf_vars({("core", "executor"): 
executor_config}): with pytest.raises(AirflowConfigException): - self.executor_loader_class._get_executor_names() + executor_loader.ExecutorLoader._get_executor_names() @pytest.mark.parametrize( ("executor_config", "expected_value"), @@ -303,7 +297,7 @@ def test_get_hybrid_executors_from_config_core_executors_bad_config_format(self, ) def test_should_support_import_executor_from_core(self, executor_config, expected_value): with conf_vars({("core", "executor"): executor_config}): - executor, import_source = self.executor_loader_class.import_default_executor_cls() + executor, import_source = executor_loader.ExecutorLoader.import_default_executor_cls() assert expected_value == executor.__name__ assert import_source == ConnectorSource.CORE @@ -317,29 +311,33 @@ def test_should_support_import_executor_from_core(self, executor_config, expecte ) def test_should_support_import_custom_path(self, executor_config): with conf_vars({("core", "executor"): executor_config}): - executor, import_source = self.executor_loader_class.import_default_executor_cls() + executor, import_source = executor_loader.ExecutorLoader.import_default_executor_cls() assert executor.__name__ == "FakeExecutor" assert import_source == ConnectorSource.CUSTOM_PATH def test_load_executor(self): with conf_vars({("core", "executor"): "LocalExecutor"}): - self.executor_loader_class.init_executors() - assert isinstance(self.executor_loader_class.load_executor("LocalExecutor"), LocalExecutor) + executor_loader.ExecutorLoader.init_executors() + assert isinstance(executor_loader.ExecutorLoader.load_executor("LocalExecutor"), LocalExecutor) assert isinstance( - self.executor_loader_class.load_executor(executor_loader._executor_names[0]), LocalExecutor + executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]), + LocalExecutor, ) - assert isinstance(self.executor_loader_class.load_executor(None), LocalExecutor) + assert isinstance(executor_loader.ExecutorLoader.load_executor(None), 
LocalExecutor) def test_load_executor_alias(self): with conf_vars({("core", "executor"): "local_exec:airflow.executors.local_executor.LocalExecutor"}): - self.executor_loader_class.init_executors() - assert isinstance(self.executor_loader_class.load_executor("local_exec"), LocalExecutor) + executor_loader.ExecutorLoader.init_executors() + assert isinstance(executor_loader.ExecutorLoader.load_executor("local_exec"), LocalExecutor) assert isinstance( - self.executor_loader_class.load_executor("airflow.executors.local_executor.LocalExecutor"), + executor_loader.ExecutorLoader.load_executor( + "airflow.executors.local_executor.LocalExecutor" + ), LocalExecutor, ) assert isinstance( - self.executor_loader_class.load_executor(executor_loader._executor_names[0]), LocalExecutor + executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]), + LocalExecutor, ) @mock.patch( @@ -348,7 +346,7 @@ def test_load_executor_alias(self): ) def test_call_load_executor_method_without_init_executors(self, mock_get_executor_names): with conf_vars({("core", "executor"): "LocalExecutor"}): - self.executor_loader_class.load_executor("LocalExecutor") + executor_loader.ExecutorLoader.load_executor("LocalExecutor") mock_get_executor_names.assert_called_once() @mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor", autospec=True) @@ -361,15 +359,16 @@ def test_load_custom_executor_with_classname(self, mock_executor): ): "my_alias:airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor" } ): - self.executor_loader_class.init_executors() - assert isinstance(self.executor_loader_class.load_executor("my_alias"), AwsEcsExecutor) - assert isinstance(self.executor_loader_class.load_executor("AwsEcsExecutor"), AwsEcsExecutor) + executor_loader.ExecutorLoader.init_executors() + assert isinstance(executor_loader.ExecutorLoader.load_executor("my_alias"), AwsEcsExecutor) + assert 
isinstance(executor_loader.ExecutorLoader.load_executor("AwsEcsExecutor"), AwsEcsExecutor) assert isinstance( - self.executor_loader_class.load_executor( + executor_loader.ExecutorLoader.load_executor( "airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor" ), AwsEcsExecutor, ) assert isinstance( - self.executor_loader_class.load_executor(executor_loader._executor_names[0]), AwsEcsExecutor + executor_loader.ExecutorLoader.load_executor(executor_loader._executor_names[0]), + AwsEcsExecutor, ) diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py index 319d34fae6c04..7e6f1b2253e17 100644 --- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -49,9 +49,10 @@ def side_effect(*args, **kwargs): yield m +@pytest.mark.usefixtures("clean_executor_loader") class TestNotInReschedulePeriodDep: @pytest.fixture(autouse=True) - def setup_test_cases(self, request, create_task_instance, clean_executor_loader): + def setup_test_cases(self, request, create_task_instance): db.clear_db_runs() db.clear_rendered_ti_fields() diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 3b07835d275af..969d0b2a61c8c 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -1569,12 +1569,15 @@ def clean_dags_and_dagruns(): @pytest.fixture def clean_executor_loader(): + """Clean the executor_loader state, as it stores global variables in the module, causing side effects for some tests.""" + from airflow.executors.executor_loader import ExecutorLoader + from tests_common.test_utils.executor_loader import clean_executor_loader_module - """Clean the executor_loader state, as it stores global variables in the module, causing side effects for some tests.""" clean_executor_loader_module() yield # Test runs here clean_executor_loader_module() + ExecutorLoader.init_executors() @pytest.fixture(scope="session")