kubernetes executor cleanup_stuck_queued_tasks optimization #41220

Merged
airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py (50 additions, 53 deletions)
@@ -177,6 +177,37 @@ def _make_safe_label_value(self, input_value: str | datetime) -> str:
             return pod_generator.datetime_to_label_safe_datestring(input_value)
         return pod_generator.make_safe_label_value(input_value)
 
+    def get_pod_labels_combined_str_to_pod_map(self) -> dict[str, k8s.V1Pod]:
+        """
+        List the worker pods owned by this scheduler and build a map of combined label search string -> pod.
+
+        For every pod, it creates one entry in the map with the below format:
+        dag_id={dag_id},task_id={task_id},airflow-worker={airflow_worker},<map_index={map_index}>,run_id={run_id}
+        """
+        # airflow worker label selector batch call
+        kwargs = {"label_selector": f"airflow-worker={self._make_safe_label_value(str(self.job_id))}"}
+        if self.kube_config.kube_client_request_args:
+            kwargs.update(self.kube_config.kube_client_request_args)
+        pod_list = self._list_pods(kwargs)
+
+        # build the map keyed by each pod's combined label search string
+        pod_labels_combined_str_to_pod_map = {}
+        for pod in pod_list:
+            dag_id = pod.metadata.labels.get("dag_id", None)
+            task_id = pod.metadata.labels.get("task_id", None)
+            airflow_worker = pod.metadata.labels.get("airflow-worker", None)
+            map_index = pod.metadata.labels.get("map_index", None)
+            run_id = pod.metadata.labels.get("run_id", None)
+            if dag_id is None or task_id is None or airflow_worker is None:
+                continue
+            label_search_base_str = f"dag_id={dag_id},task_id={task_id},airflow-worker={airflow_worker}"
+            if map_index is not None:
+                label_search_base_str += f",map_index={map_index}"
+            if run_id is not None:
+                label_search_str = f"{label_search_base_str},run_id={run_id}"
+                pod_labels_combined_str_to_pod_map[label_search_str] = pod
+        return pod_labels_combined_str_to_pod_map
+
     @provide_session
     def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> None:
         """
@@ -216,32 +247,7 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> None:
         if not queued_tis:
             return
 
-        # airflow worker label selector batch call
-        kwargs = {"label_selector": f"airflow-worker={self._make_safe_label_value(str(self.job_id))}"}
-        if self.kube_config.kube_client_request_args:
-            kwargs.update(self.kube_config.kube_client_request_args)
-        pod_list = self._list_pods(kwargs)
-
-        # create a set against pod query label fields
-        label_search_set = set()
-        for pod in pod_list:
-            dag_id = pod.metadata.labels.get("dag_id", None)
-            task_id = pod.metadata.labels.get("task_id", None)
-            airflow_worker = pod.metadata.labels.get("airflow-worker", None)
-            map_index = pod.metadata.labels.get("map_index", None)
-            run_id = pod.metadata.labels.get("run_id", None)
-            execution_date = pod.metadata.labels.get("execution_date", None)
-            if dag_id is None or task_id is None or airflow_worker is None:
-                continue
-            label_search_base_str = f"dag_id={dag_id},task_id={task_id},airflow-worker={airflow_worker}"
-            if map_index is not None:
-                label_search_base_str += f",map_index={map_index}"
-            if run_id is not None:
-                label_search_str = f"{label_search_base_str},run_id={run_id}"
-                label_search_set.add(label_search_str)
-            if execution_date is not None:
-                label_search_str = f"{label_search_base_str},execution_date={execution_date}"
-                label_search_set.add(label_search_str)
+        pod_labels_combined_str_to_pod_map = self.get_pod_labels_combined_str_to_pod_map()
 
         for ti in queued_tis:
             self.log.debug("Checking task instance %s", ti)
@@ -262,13 +268,7 @@ def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> None:
 
             # Try run_id first
             label_search_str = f"{base_label_selector},run_id={self._make_safe_label_value(ti.run_id)}"
-            if label_search_str in label_search_set:
-                continue
-            # Fallback to old style of using execution_date
-            label_search_str = (
-                f"{base_label_selector},execution_date={self._make_safe_label_value(ti.execution_date)}"
-            )
-            if label_search_str in label_search_set:
+            if label_search_str in pod_labels_combined_str_to_pod_map:
                 continue
             self.log.info("TaskInstance: %s found in queued state but was not launched, rescheduling", ti)
             session.execute(
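In isolation, the hunk above replaces two membership tests per task instance (a run_id key, then a legacy execution_date fallback) with a single lookup against the prebuilt map. A rough sketch of the resulting control flow, using plain dicts as hypothetical stand-ins for task instances and the executor's pod map:

```python
def find_unlaunched_queued_tis(queued_tis, pod_map):
    """Return the queued TIs that have no matching worker pod (hypothetical helper)."""
    missing = []
    for ti in queued_tis:
        # Mirror the executor's key format: base labels, optional map_index, then run_id.
        key = f"dag_id={ti['dag_id']},task_id={ti['task_id']},airflow-worker={ti['worker']}"
        if ti.get("map_index", -1) >= 0:
            key += f",map_index={ti['map_index']}"
        key += f",run_id={ti['run_id']}"
        if key not in pod_map:  # O(1) membership test; the single API call happened earlier
            missing.append(ti)
    return missing


# One queued TI has a pod, the other does not and would be rescheduled.
pod_map = {"dag_id=d,task_id=a,airflow-worker=1,run_id=r": object()}
queued = [
    {"dag_id": "d", "task_id": "a", "worker": "1", "run_id": "r"},
    {"dag_id": "d", "task_id": "b", "worker": "1", "run_id": "r"},
]
print([ti["task_id"] for ti in find_unlaunched_queued_tis(queued, pod_map)])  # ['b']
```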
@@ -601,34 +601,31 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
         :param tis: List of Task Instances to clean up
         :return: List of readable task instances for a warning message
         """
-        from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
-
         if TYPE_CHECKING:
             assert self.kube_client
             assert self.kube_scheduler
-        readable_tis = []
+        readable_tis: list[str] = []
+        if not tis:
+            return readable_tis
+        pod_labels_combined_str_to_pod_map = self.get_pod_labels_combined_str_to_pod_map()
         for ti in tis:
-            selector = PodGenerator.build_selector_for_k8s_executor_pod(
-                dag_id=ti.dag_id,
-                task_id=ti.task_id,
-                try_number=ti.try_number,
-                map_index=ti.map_index,
-                run_id=ti.run_id,
-                airflow_worker=ti.queued_by_job_id,
+            # Build the pod selector
+            base_label_selector = (
+                f"dag_id={self._make_safe_label_value(ti.dag_id)},"
+                f"task_id={self._make_safe_label_value(ti.task_id)},"
+                f"airflow-worker={self._make_safe_label_value(str(ti.queued_by_job_id))}"
             )
-            namespace = self._get_pod_namespace(ti)
-            pod_list = self.kube_client.list_namespaced_pod(
-                namespace=namespace,
-                label_selector=selector,
-            ).items
-            if not pod_list:
+            if ti.map_index >= 0:
+                # Old tasks _couldn't_ be mapped, so we don't have to worry about compat
+                base_label_selector += f",map_index={ti.map_index}"
+
+            label_search_str = f"{base_label_selector},run_id={self._make_safe_label_value(ti.run_id)}"
+            pod = pod_labels_combined_str_to_pod_map.get(label_search_str, None)
+            if not pod:
self.log.warning("Cannot find pod for ti %s", ti)
continue
-            elif len(pod_list) > 1:
-                self.log.warning("Found multiple pods for ti %s: %s", ti, pod_list)
-                continue
             readable_tis.append(repr(ti))
-            self.kube_scheduler.delete_pod(pod_name=pod_list[0].metadata.name, namespace=namespace)
+            self.kube_scheduler.delete_pod(pod_name=pod.metadata.name, namespace=pod.metadata.namespace)
         return readable_tis
 
     def adopt_launched_task(
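Taken together, the change to cleanup_stuck_queued_tasks trades one list_namespaced_pod API call per stuck task instance for a single labeled list call plus O(1) dictionary lookups. A schematic comparison of the two shapes, with a hypothetical stub client that merely counts list calls (not the real Kubernetes API):

```python
class StubKubeClient:
    """Hypothetical stand-in that counts list calls instead of talking to a cluster."""

    def __init__(self, pods):
        self.pods = pods  # list of (combined_label_key, pod_name) pairs
        self.calls = 0

    def list_pods(self):
        self.calls += 1
        return self.pods


def cleanup_old(tis, client):
    """Old shape: one pod-list round trip per stuck task instance."""
    found = []
    for ti in tis:
        pods = dict(client.list_pods())
        if ti in pods:
            found.append(ti)
    return found, client.calls


def cleanup_new(tis, client):
    """New shape: one pod-list round trip, then O(1) dict lookups."""
    pod_map = dict(client.list_pods())
    return [ti for ti in tis if ti in pod_map], client.calls


pods = [(f"ti-{i}", f"pod-{i}") for i in range(100)]
stuck = [f"ti-{i}" for i in range(100)]
_, old_calls = cleanup_old(stuck, StubKubeClient(pods))
_, new_calls = cleanup_new(stuck, StubKubeClient(pods))
print(old_calls, new_calls)  # 100 1
```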
tests/providers/cncf/kubernetes/executors/test_kubernetes_executor.py
@@ -19,7 +19,7 @@
 import random
 import re
 import string
-from datetime import datetime, timedelta
+from datetime import datetime
 from unittest import mock
 
 import pytest
@@ -1191,28 +1191,46 @@ def test_not_adopt_unassigned_task(self, mock_kube_client):
         assert tis_to_flush_by_key == {"foobar": {}}
 
     @pytest.mark.db_test
@mock.patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
@mock.patch(
"airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils.AirflowKubernetesScheduler.delete_pod"
)
def test_cleanup_stuck_queued_tasks(self, mock_delete_pod, mock_kube_client, dag_maker, session):
@mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor.DynamicClient")
def test_cleanup_stuck_queued_tasks(self, mock_kube_dynamic_client, dag_maker, create_dummy_dag, session):
"""Delete any pods associated with a task stuck in queued."""
-        executor = KubernetesExecutor()
-        executor.start()
-        executor.scheduler_job_id = "123"
-        with dag_maker(dag_id="test_cleanup_stuck_queued_tasks"):
-            op = BashOperator(task_id="bash", bash_command=["echo 0", "echo 1"])
+        mock_kube_client = mock.MagicMock()
+        mock_kube_dynamic_client.return_value = mock.MagicMock()
+        mock_pod_resource = mock.MagicMock()
+        mock_kube_dynamic_client.return_value.resources.get.return_value = mock_pod_resource
+        mock_kube_dynamic_client.return_value.get.return_value = k8s.V1PodList(
+            items=[
+                k8s.V1Pod(
+                    metadata=k8s.V1ObjectMeta(
+                        labels={
+                            "role": "airflow-worker",
+                            "dag_id": "test_cleanup_stuck_queued_tasks",
+                            "task_id": "bash",
+                            "airflow-worker": 123,
+                            "run_id": "test",
+                            "try_number": 0,
+                        },
+                    ),
+                    status=k8s.V1PodStatus(phase="Pending"),
+                )
+            ]
+        )
+        create_dummy_dag(dag_id="test_cleanup_stuck_queued_tasks", task_id="bash", with_dagrun_type=None)
         dag_run = dag_maker.create_dagrun()
-        ti = dag_run.get_task_instance(op.task_id, session)
-        ti.retries = 1
+        ti = dag_run.task_instances[0]
         ti.state = State.QUEUED
-        ti.queued_dttm = timezone.utcnow() - timedelta(minutes=30)
         ti.queued_by_job_id = 123
         session.flush()
 
+        executor = self.kubernetes_executor
+        executor.job_id = 123
+        executor.kube_client = mock_kube_client
+        executor.kube_scheduler = mock.MagicMock()
+        ti.refresh_from_db()
         tis = [ti]
         executor.cleanup_stuck_queued_tasks(tis)
-        mock_delete_pod.assert_called_once()
+        executor.kube_scheduler.delete_pod.assert_called_once()
         assert executor.running == set()
         executor.end()
 
     @pytest.mark.parametrize(
         "raw_multi_namespace_mode, raw_value_namespace_list, expected_value_in_kube_config",
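For reference, a condensed sketch of the mocking pattern the updated test relies on: the dynamic client is patched with a MagicMock whose list call returns a canned V1PodList, so the executor's single batched pod query sees exactly one Pending pod. The label values here are illustrative, not taken from the test:

```python
from unittest import mock

from kubernetes.client import models as k8s

# Patch target replaced by a MagicMock; its get() call plays the role of the
# batched pod-list request and returns a canned V1PodList.
mock_dynamic_client = mock.MagicMock()
mock_dynamic_client.resources.get.return_value = mock.MagicMock()
mock_dynamic_client.get.return_value = k8s.V1PodList(
    items=[
        k8s.V1Pod(
            metadata=k8s.V1ObjectMeta(
                labels={"dag_id": "d", "task_id": "t", "airflow-worker": "123", "run_id": "r"},
            ),
            status=k8s.V1PodStatus(phase="Pending"),
        )
    ]
)

# The combined-label key the executor would derive from this canned response:
pod = mock_dynamic_client.get.return_value.items[0]
labels = pod.metadata.labels
key = f"dag_id={labels['dag_id']},task_id={labels['task_id']},airflow-worker={labels['airflow-worker']},run_id={labels['run_id']}"
print(key)  # dag_id=d,task_id=t,airflow-worker=123,run_id=r
```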