Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from airflow._shared.timezones import timezone
from airflow.assets.manager import asset_manager
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
from airflow.listeners.listener import get_listener_manager
from airflow.models.asset import AssetEvent, AssetModel
from airflow.models.base import Base, StringID, TaskInstanceDependencies
Expand Down Expand Up @@ -556,6 +557,11 @@ def insert_mapping(
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)

executor = task.executor
if executor is None:
executor_name = ExecutorLoader.get_default_executor_name()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we call on task instance potentially often and there is a bunch of logic behind this - but w/o DB access - to wire up executor details... should the called get_default_executor_name() method being implement the @cache decorator?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jscheffl, would this return the right executor in a multi executor context?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the executor field is None then the default executor is used during execution - so the logic on API for display is consistent here. If the field on the TaskInstance is filled, then it references to the executor that specifically is wanted.

executor = executor_name.alias or executor_name.module_path

return {
"dag_id": task.dag_id,
"task_id": task.task_id,
Expand All @@ -569,7 +575,7 @@ def insert_mapping(
"priority_weight": priority_weight,
"run_as_user": task.run_as_user,
"max_tries": task.retries,
"executor": task.executor,
"executor": executor,
"executor_config": task.executor_config,
"operator": task.task_type,
"custom_operator_name": getattr(task, "operator_name", None),
Expand Down Expand Up @@ -753,7 +759,11 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) ->
self.run_as_user = task.run_as_user
# Do not set max_tries to task.retries here because max_tries is a cumulative
# value that needs to be stored in the db.
self.executor = task.executor
if task.executor is None:
executor_name = ExecutorLoader.get_default_executor_name()
self.executor = executor_name.alias or executor_name.module_path
else:
self.executor = task.executor
self.executor_config = task.executor_config
self.operator = task.task_type
op_name = getattr(task, "operator_name", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ export const Details = () => {
</Table.Row>
<Table.Row>
<Table.Cell>{translate("taskInstance.executor")}</Table.Cell>
<Table.Cell>{tryInstance?.executor}</Table.Cell>
</Table.Row>
<Table.Row>
<Table.Cell>{translate("taskInstance.executorConfig")}</Table.Cell>
<Table.Cell>{tryInstance?.executor_config}</Table.Cell>
</Table.Row>
</Table.Body>
Expand Down
66 changes: 66 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
AirflowFailException,
AirflowSkipException,
)
from airflow.executors.executor_utils import ExecutorName
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
from airflow.models.connection import Connection
from airflow.models.dag_version import DagVersion
Expand Down Expand Up @@ -2821,6 +2822,71 @@ def mock_policy(task_instance: TaskInstance):
assert ti.max_tries == expected_max_tries


@pytest.mark.parametrize(
"task_executor,expected_executor",
[
(None, "LocalExecutor"), # Default executor should be resolved
("LocalExecutor", "LocalExecutor"), # Explicit executor should be preserved
("CeleryExecutor", "CeleryExecutor"), # Explicit executor should be preserved
],
)
def test_refresh_from_task_resolves_executor(task_executor, expected_executor, monkeypatch):
"""Test that refresh_from_task resolves None executor to default executor name."""
# Mock the default executor
mock_executor_name = ExecutorName(
module_path="airflow.executors.local_executor.LocalExecutor", alias="LocalExecutor"
)

with mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor_name",
return_value=mock_executor_name,
):
task = EmptyOperator(task_id="test_executor", executor=task_executor)
ti = TI(task, run_id=None, dag_version_id=mock.MagicMock())
ti.refresh_from_task(task)

assert ti.executor == expected_executor


def test_insert_mapping_resolves_executor_to_default():
"""Test that insert_mapping resolves None executor to default executor name."""
mock_executor_name = ExecutorName(
module_path="airflow.executors.local_executor.LocalExecutor", alias="LocalExecutor"
)

with mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor_name",
return_value=mock_executor_name,
):
task = EmptyOperator(
task_id="test_task",
executor=None, # No executor specified
)

mapping = TI.insert_mapping(
run_id="test_run",
task=task,
map_index=-1,
dag_version_id=mock.MagicMock(),
)

assert mapping["executor"] == "LocalExecutor"


def test_insert_mapping_preserves_explicit_executor():
"""Test that insert_mapping preserves explicitly set executor."""
task = EmptyOperator(task_id="test_task", executor="CeleryExecutor")

mapping = TI.insert_mapping(
run_id="test_run",
task=task,
map_index=-1,
dag_version_id=mock.MagicMock(),
)

assert mapping["executor"] == "CeleryExecutor"


class TestRunRawTaskQueriesCount:
"""
These tests are designed to detect changes in the number of queries executed
Expand Down
Loading