Skip to content

Commit 7e7f2a2

Browse files
authored
Fix MyPy type errors in airflow-core/src/airflow/models/ (dagrun,serialized_dag), api_fastapi/common/exceptions.py for Sqlalchemy 2 migration (#57576)
1 parent 9613558 commit 7e7f2a2

File tree

3 files changed

+55
-41
lines changed

3 files changed

+55
-41
lines changed

airflow-core/src/airflow/api_fastapi/common/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,4 @@ def exception_handler(self, request: Request, exc: DeserializationError):
118118
)
119119

120120

121-
ERROR_HANDLERS = [_UniqueConstraintErrorHandler(), DagErrorHandler()]
121+
ERROR_HANDLERS: list[BaseErrorHandler] = [_UniqueConstraintErrorHandler(), DagErrorHandler()]

airflow-core/src/airflow/models/dagrun.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
from opentelemetry.sdk.trace import Span
9999
from pydantic import NonNegativeInt
100100
from sqlalchemy.orm import Query, Session
101-
from sqlalchemy.sql.elements import Case
101+
from sqlalchemy.sql.elements import Case, ColumnElement
102102

103103
from airflow.models.dag_version import DagVersion
104104
from airflow.models.mappedoperator import MappedOperator
@@ -334,19 +334,23 @@ def __init__(
334334
else:
335335
self.data_interval_start, self.data_interval_end = data_interval
336336
self.bundle_version = bundle_version
337-
self.dag_id = dag_id
338-
self.run_id = run_id
337+
if dag_id is not None:
338+
self.dag_id = dag_id
339+
if run_id is not None:
340+
self.run_id = run_id
339341
self.logical_date = logical_date
340-
self.run_after = run_after
342+
if run_after is not None:
343+
self.run_after = run_after
341344
self.start_date = start_date
342345
self.conf = conf or {}
343346
if state is not None:
344347
self.state = state
345348
if queued_at is NOTSET:
346349
self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
347-
else:
350+
elif queued_at is not None:
348351
self.queued_at = queued_at
349-
self.run_type = run_type
352+
if run_type is not None:
353+
self.run_type = run_type
350354
self.creating_job_id = creating_job_id
351355
self.backfill_id = backfill_id
352356
self.clear_number = 0
@@ -560,7 +564,7 @@ def active_runs_of_dags(
560564
)
561565
if exclude_backfill:
562566
query = query.where(cls.run_type != DagRunType.BACKFILL_JOB)
563-
return dict(iter(session.execute(query)))
567+
return dict(session.execute(query).all())
564568

565569
@classmethod
566570
@retry_db_transaction
@@ -589,16 +593,17 @@ def get_running_dag_runs_to_examine(cls, session: Session) -> Query:
589593
)
590594
.options(joinedload(cls.task_instances))
591595
.order_by(
592-
nulls_first(BackfillDagRun.sort_ordinal, session=session),
593-
nulls_first(cls.last_scheduling_decision, session=session),
596+
nulls_first(cast("ColumnElement[Any]", BackfillDagRun.sort_ordinal), session=session),
597+
nulls_first(cast("ColumnElement[Any]", cls.last_scheduling_decision), session=session),
594598
cls.run_after,
595599
)
596600
.limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE)
597601
)
598602

599603
query = query.where(DagRun.run_after <= func.now())
600604

601-
return session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)).unique()
605+
result = session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)).unique()
606+
return result
602607

603608
@classmethod
604609
@retry_db_transaction
@@ -665,16 +670,16 @@ def get_queued_dag_runs_to_set_running(cls, session: Session) -> Query:
665670
coalesce(running_drs.c.num_running, text("0"))
666671
< coalesce(Backfill.max_active_runs, DagModel.max_active_runs),
667672
# don't set paused dag runs as running
668-
not_(coalesce(Backfill.is_paused, False)),
673+
not_(coalesce(cast("ColumnElement[bool]", Backfill.is_paused), False)),
669674
)
670675
.order_by(
671676
# ordering by backfill sort ordinal first ensures that backfill dag runs
672677
# have lower priority than all other dag run types (since sort_ordinal >= 1).
673678
# additionally, sorting by sort_ordinal ensures that the backfill
674679
# dag runs are created in the right order when that matters.
675680
# todo: AIP-78 use row_number to avoid starvation; limit the number of returned runs per-dag
676-
nulls_first(BackfillDagRun.sort_ordinal, session=session),
677-
nulls_first(cls.last_scheduling_decision, session=session),
681+
nulls_first(cast("ColumnElement[Any]", BackfillDagRun.sort_ordinal), session=session),
682+
nulls_first(cast("ColumnElement[Any]", cls.last_scheduling_decision), session=session),
678683
nulls_first(running_drs.c.num_running, session=session), # many running -> lower priority
679684
cls.run_after,
680685
)
@@ -739,7 +744,7 @@ def find(
739744
if no_backfills:
740745
qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)
741746

742-
return session.scalars(qry.order_by(cls.logical_date)).all()
747+
return list(session.scalars(qry.order_by(cls.logical_date)).all())
743748

744749
@classmethod
745750
@provide_session
@@ -806,7 +811,7 @@ def fetch_task_instances(
806811

807812
if task_ids is not None:
808813
tis = tis.where(TI.task_id.in_(task_ids))
809-
return session.scalars(tis).all()
814+
return list(session.scalars(tis).all())
810815

811816
def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
812817
"""Check if last N dags failed."""
@@ -936,7 +941,7 @@ def get_previous_dagrun(
936941
:param session: SQLAlchemy ORM Session
937942
:param state: the dag run state
938943
"""
939-
if dag_run.logical_date is None:
944+
if not dag_run or dag_run.logical_date is None:
940945
return None
941946
filters = [
942947
DagRun.dag_id == dag_run.dag_id,
@@ -959,7 +964,7 @@ def get_previous_scheduled_dagrun(
959964
:param session: SQLAlchemy ORM Session
960965
"""
961966
dag_run = session.get(DagRun, dag_run_id)
962-
if not dag_run.logical_date:
967+
if not dag_run or not dag_run.logical_date:
963968
return None
964969
return session.scalar(
965970
select(DagRun)
@@ -1150,9 +1155,13 @@ def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
11501155
def should_schedule(self) -> bool:
11511156
return (
11521157
bool(self.tis)
1153-
and all(not t.task.depends_on_past for t in self.tis) # type: ignore[union-attr]
1154-
and all(t.task.max_active_tis_per_dag is None for t in self.tis) # type: ignore[union-attr]
1155-
and all(t.task.max_active_tis_per_dagrun is None for t in self.tis) # type: ignore[union-attr]
1158+
and all(not getattr(t.task, "depends_on_past", False) for t in self.tis if t.task)
1159+
and all(
1160+
getattr(t.task, "max_active_tis_per_dag", None) is None for t in self.tis if t.task
1161+
)
1162+
and all(
1163+
getattr(t.task, "max_active_tis_per_dagrun", None) is None for t in self.tis if t.task
1164+
)
11561165
and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
11571166
)
11581167

@@ -1414,7 +1423,7 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "
14141423
)
14151424
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
14161425

1417-
last_ti = self.get_last_ti(dag) # type: ignore[arg-type]
1426+
last_ti = self.get_last_ti(cast("SerializedDAG", dag))
14181427
if last_ti:
14191428
last_ti_model = TIDataModel.model_validate(last_ti, from_attributes=True)
14201429
task = dag.get_task(last_ti.task_id)
@@ -1426,9 +1435,9 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "
14261435
data_interval_start=self.data_interval_start,
14271436
data_interval_end=self.data_interval_end,
14281437
run_after=self.run_after,
1429-
start_date=self.start_date,
1438+
start_date=self.start_date or timezone.utcnow(),
14301439
end_date=self.end_date,
1431-
run_type=self.run_type,
1440+
run_type=DagRunType(self.run_type),
14321441
state=self.state,
14331442
conf=self.conf,
14341443
consumed_asset_events=[],
@@ -1848,7 +1857,8 @@ def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
18481857
for map_index in indexes:
18491858
ti = TI(task, run_id=self.run_id, map_index=map_index, dag_version_id=dag_version_id)
18501859
ti_mutation_hook(ti)
1851-
created_counts[ti.operator] += 1
1860+
if ti.operator:
1861+
created_counts[ti.operator] += 1
18521862
yield ti
18531863

18541864
creator = create_ti
@@ -1910,7 +1920,7 @@ def _create_task_instances(
19101920
run_id = self.run_id
19111921
try:
19121922
if hook_is_noop:
1913-
session.bulk_insert_mappings(TI, tasks)
1923+
session.bulk_insert_mappings(TI.__mapper__, tasks)
19141924
else:
19151925
session.bulk_save_objects(tasks)
19161926

@@ -1995,12 +2005,14 @@ def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
19952005
.group_by(cls.dag_id)
19962006
.subquery()
19972007
)
1998-
return session.scalars(
1999-
select(cls).join(
2000-
subquery,
2001-
and_(cls.dag_id == subquery.c.dag_id, cls.logical_date == subquery.c.logical_date),
2002-
)
2003-
).all()
2008+
return list(
2009+
session.scalars(
2010+
select(cls).join(
2011+
subquery,
2012+
and_(cls.dag_id == subquery.c.dag_id, cls.logical_date == subquery.c.logical_date),
2013+
)
2014+
).all()
2015+
)
20042016

20052017
@provide_session
20062018
def schedule_tis(
@@ -2054,7 +2066,7 @@ def schedule_tis(
20542066
schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
20552067
)
20562068
for id_chunk in schedulable_ti_ids_chunks:
2057-
count += session.execute(
2069+
result = session.execute(
20582070
update(TI)
20592071
.where(TI.id.in_(id_chunk))
20602072
.values(
@@ -2069,13 +2081,14 @@ def schedule_tis(
20692081
),
20702082
)
20712083
.execution_options(synchronize_session=False)
2072-
).rowcount
2084+
)
2085+
count += getattr(result, "rowcount", 0)
20732086

20742087
# Tasks using EmptyOperator should not be executed, mark them as success
20752088
if empty_ti_ids:
20762089
dummy_ti_ids_chunks = chunks(empty_ti_ids, max_tis_per_query or len(empty_ti_ids))
20772090
for id_chunk in dummy_ti_ids_chunks:
2078-
count += session.execute(
2091+
result = session.execute(
20792092
update(TI)
20802093
.where(TI.id.in_(id_chunk))
20812094
.values(
@@ -2088,7 +2101,8 @@ def schedule_tis(
20882101
.execution_options(
20892102
synchronize_session=False,
20902103
)
2091-
).rowcount
2104+
)
2105+
count += getattr(result, "rowcount", 0)
20922106

20932107
return count
20942108

@@ -2180,7 +2194,7 @@ def get_or_create_dagrun(
21802194
21812195
:return: The newly created DAG run.
21822196
"""
2183-
dr: DagRun = session.scalar(
2197+
dr = session.scalar(
21842198
select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.logical_date == logical_date)
21852199
)
21862200
if dr:

airflow-core/src/airflow/models/serialized_dag.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def collect_asset_key_to_ids(self, asset_name_uris: set[tuple[str, str]]) -> dic
145145
)
146146
}
147147

148-
def collect_asset_name_ref_to_ids_names(self, asset_ref_names) -> dict[str, tuple[int, str]]:
148+
def collect_asset_name_ref_to_ids_names(self, asset_ref_names: set[str]) -> dict[str, tuple[int, str]]:
149149
return {
150150
name: (asset_id, name)
151151
for name, asset_id in self.session.execute(
@@ -155,7 +155,7 @@ def collect_asset_name_ref_to_ids_names(self, asset_ref_names) -> dict[str, tupl
155155
)
156156
}
157157

158-
def collect_asset_uri_ref_to_ids_names(self, asset_ref_uris) -> dict[str, tuple[int, str]]:
158+
def collect_asset_uri_ref_to_ids_names(self, asset_ref_uris: set[str]) -> dict[str, tuple[int, str]]:
159159
return {
160160
uri: (asset_id, name)
161161
for uri, name, asset_id in self.session.execute(
@@ -165,7 +165,7 @@ def collect_asset_uri_ref_to_ids_names(self, asset_ref_uris) -> dict[str, tuple[
165165
)
166166
}
167167

168-
def collect_alias_to_assets(self, asset_alias_names) -> dict[str, list[tuple[int, str]]]:
168+
def collect_alias_to_assets(self, asset_alias_names: set[str]) -> dict[str, list[tuple[int, str]]]:
169169
return {
170170
aam.name: [(am.id, am.name) for am in aam.assets]
171171
for aam in self.session.scalars(
@@ -219,7 +219,7 @@ def resolve_asset_ref_dag_dep(
219219
dependency_id=dep_id,
220220
)
221221

222-
def resolve_asset_name_ref_dag_dep(self, dep_data) -> Iterator[DagDependency]:
222+
def resolve_asset_name_ref_dag_dep(self, dep_data: dict) -> Iterator[DagDependency]:
223223
return self.resolve_asset_ref_dag_dep(dep_data=dep_data, ref_type="asset-name-ref")
224224

225225
def resolve_asset_uri_ref_dag_dep(self, dep_data: dict) -> Iterator[DagDependency]:

0 commit comments

Comments
 (0)