9898 from opentelemetry .sdk .trace import Span
9999 from pydantic import NonNegativeInt
100100 from sqlalchemy .orm import Query , Session
101- from sqlalchemy .sql .elements import Case
101+ from sqlalchemy .sql .elements import Case , ColumnElement
102102
103103 from airflow .models .dag_version import DagVersion
104104 from airflow .models .mappedoperator import MappedOperator
@@ -334,19 +334,23 @@ def __init__(
334334 else :
335335 self .data_interval_start , self .data_interval_end = data_interval
336336 self .bundle_version = bundle_version
337- self .dag_id = dag_id
338- self .run_id = run_id
337+ if dag_id is not None :
338+ self .dag_id = dag_id
339+ if run_id is not None :
340+ self .run_id = run_id
339341 self .logical_date = logical_date
340- self .run_after = run_after
342+ if run_after is not None :
343+ self .run_after = run_after
341344 self .start_date = start_date
342345 self .conf = conf or {}
343346 if state is not None :
344347 self .state = state
345348 if queued_at is NOTSET :
346349 self .queued_at = timezone .utcnow () if state == DagRunState .QUEUED else None
347- else :
350+ elif queued_at is not None :
348351 self .queued_at = queued_at
349- self .run_type = run_type
352+ if run_type is not None :
353+ self .run_type = run_type
350354 self .creating_job_id = creating_job_id
351355 self .backfill_id = backfill_id
352356 self .clear_number = 0
@@ -560,7 +564,7 @@ def active_runs_of_dags(
560564 )
561565 if exclude_backfill :
562566 query = query .where (cls .run_type != DagRunType .BACKFILL_JOB )
563- return dict (iter ( session .execute (query )))
567+ return dict (session .execute (query ). all ( ))
564568
565569 @classmethod
566570 @retry_db_transaction
@@ -589,16 +593,17 @@ def get_running_dag_runs_to_examine(cls, session: Session) -> Query:
589593 )
590594 .options (joinedload (cls .task_instances ))
591595 .order_by (
592- nulls_first (BackfillDagRun .sort_ordinal , session = session ),
593- nulls_first (cls .last_scheduling_decision , session = session ),
596+ nulls_first (cast ( "ColumnElement[Any]" , BackfillDagRun .sort_ordinal ) , session = session ),
597+ nulls_first (cast ( "ColumnElement[Any]" , cls .last_scheduling_decision ) , session = session ),
594598 cls .run_after ,
595599 )
596600 .limit (cls .DEFAULT_DAGRUNS_TO_EXAMINE )
597601 )
598602
599603 query = query .where (DagRun .run_after <= func .now ())
600604
601- return session .scalars (with_row_locks (query , of = cls , session = session , skip_locked = True )).unique ()
605+ result = session .scalars (with_row_locks (query , of = cls , session = session , skip_locked = True )).unique ()
606+ return result
602607
603608 @classmethod
604609 @retry_db_transaction
@@ -665,16 +670,16 @@ def get_queued_dag_runs_to_set_running(cls, session: Session) -> Query:
665670 coalesce (running_drs .c .num_running , text ("0" ))
666671 < coalesce (Backfill .max_active_runs , DagModel .max_active_runs ),
667672 # don't set paused dag runs as running
668- not_ (coalesce (Backfill .is_paused , False )),
673+ not_ (coalesce (cast ( "ColumnElement[bool]" , Backfill .is_paused ) , False )),
669674 )
670675 .order_by (
671676 # ordering by backfill sort ordinal first ensures that backfill dag runs
672677 # have lower priority than all other dag run types (since sort_ordinal >= 1).
673678 # additionally, sorting by sort_ordinal ensures that the backfill
674679 # dag runs are created in the right order when that matters.
675680 # todo: AIP-78 use row_number to avoid starvation; limit the number of returned runs per-dag
676- nulls_first (BackfillDagRun .sort_ordinal , session = session ),
677- nulls_first (cls .last_scheduling_decision , session = session ),
681+ nulls_first (cast ( "ColumnElement[Any]" , BackfillDagRun .sort_ordinal ) , session = session ),
682+ nulls_first (cast ( "ColumnElement[Any]" , cls .last_scheduling_decision ) , session = session ),
678683 nulls_first (running_drs .c .num_running , session = session ), # many running -> lower priority
679684 cls .run_after ,
680685 )
@@ -739,7 +744,7 @@ def find(
739744 if no_backfills :
740745 qry = qry .where (cls .run_type != DagRunType .BACKFILL_JOB )
741746
742- return session .scalars (qry .order_by (cls .logical_date )).all ()
747+ return list ( session .scalars (qry .order_by (cls .logical_date )).all () )
743748
744749 @classmethod
745750 @provide_session
@@ -806,7 +811,7 @@ def fetch_task_instances(
806811
807812 if task_ids is not None :
808813 tis = tis .where (TI .task_id .in_ (task_ids ))
809- return session .scalars (tis ).all ()
814+ return list ( session .scalars (tis ).all () )
810815
811816 def _check_last_n_dagruns_failed (self , dag_id , max_consecutive_failed_dag_runs , session ):
812817 """Check if last N dags failed."""
@@ -936,7 +941,7 @@ def get_previous_dagrun(
936941 :param session: SQLAlchemy ORM Session
937942 :param state: the dag run state
938943 """
939- if dag_run .logical_date is None :
944+ if not dag_run or dag_run .logical_date is None :
940945 return None
941946 filters = [
942947 DagRun .dag_id == dag_run .dag_id ,
@@ -959,7 +964,7 @@ def get_previous_scheduled_dagrun(
959964 :param session: SQLAlchemy ORM Session
960965 """
961966 dag_run = session .get (DagRun , dag_run_id )
962- if not dag_run .logical_date :
967+ if not dag_run or not dag_run .logical_date :
963968 return None
964969 return session .scalar (
965970 select (DagRun )
@@ -1150,9 +1155,13 @@ def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
11501155 def should_schedule (self ) -> bool :
11511156 return (
11521157 bool (self .tis )
1153- and all (not t .task .depends_on_past for t in self .tis ) # type: ignore[union-attr]
1154- and all (t .task .max_active_tis_per_dag is None for t in self .tis ) # type: ignore[union-attr]
1155- and all (t .task .max_active_tis_per_dagrun is None for t in self .tis ) # type: ignore[union-attr]
1158+ and all (not getattr (t .task , "depends_on_past" , False ) for t in self .tis if t .task )
1159+ and all (
1160+ getattr (t .task , "max_active_tis_per_dag" , None ) is None for t in self .tis if t .task
1161+ )
1162+ and all (
1163+ getattr (t .task , "max_active_tis_per_dagrun" , None ) is None for t in self .tis if t .task
1164+ )
11561165 and all (t .state != TaskInstanceState .DEFERRED for t in self .tis )
11571166 )
11581167
@@ -1414,7 +1423,7 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "
14141423 )
14151424 from airflow .sdk .execution_time .task_runner import RuntimeTaskInstance
14161425
1417- last_ti = self .get_last_ti (dag ) # type: ignore[arg-type]
1426+ last_ti = self .get_last_ti (cast ( "SerializedDAG" , dag ))
14181427 if last_ti :
14191428 last_ti_model = TIDataModel .model_validate (last_ti , from_attributes = True )
14201429 task = dag .get_task (last_ti .task_id )
@@ -1426,9 +1435,9 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "
14261435 data_interval_start = self .data_interval_start ,
14271436 data_interval_end = self .data_interval_end ,
14281437 run_after = self .run_after ,
1429- start_date = self .start_date ,
1438+ start_date = self .start_date or timezone . utcnow () ,
14301439 end_date = self .end_date ,
1431- run_type = self .run_type ,
1440+ run_type = DagRunType ( self .run_type ) ,
14321441 state = self .state ,
14331442 conf = self .conf ,
14341443 consumed_asset_events = [],
@@ -1848,7 +1857,8 @@ def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
18481857 for map_index in indexes :
18491858 ti = TI (task , run_id = self .run_id , map_index = map_index , dag_version_id = dag_version_id )
18501859 ti_mutation_hook (ti )
1851- created_counts [ti .operator ] += 1
1860+ if ti .operator :
1861+ created_counts [ti .operator ] += 1
18521862 yield ti
18531863
18541864 creator = create_ti
@@ -1910,7 +1920,7 @@ def _create_task_instances(
19101920 run_id = self .run_id
19111921 try :
19121922 if hook_is_noop :
1913- session .bulk_insert_mappings (TI , tasks )
1923+ session .bulk_insert_mappings (TI . __mapper__ , tasks )
19141924 else :
19151925 session .bulk_save_objects (tasks )
19161926
@@ -1995,12 +2005,14 @@ def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
19952005 .group_by (cls .dag_id )
19962006 .subquery ()
19972007 )
1998- return session .scalars (
1999- select (cls ).join (
2000- subquery ,
2001- and_ (cls .dag_id == subquery .c .dag_id , cls .logical_date == subquery .c .logical_date ),
2002- )
2003- ).all ()
2008+ return list (
2009+ session .scalars (
2010+ select (cls ).join (
2011+ subquery ,
2012+ and_ (cls .dag_id == subquery .c .dag_id , cls .logical_date == subquery .c .logical_date ),
2013+ )
2014+ ).all ()
2015+ )
20042016
20052017 @provide_session
20062018 def schedule_tis (
@@ -2054,7 +2066,7 @@ def schedule_tis(
20542066 schedulable_ti_ids , max_tis_per_query or len (schedulable_ti_ids )
20552067 )
20562068 for id_chunk in schedulable_ti_ids_chunks :
2057- count + = session .execute (
2069+ result = session .execute (
20582070 update (TI )
20592071 .where (TI .id .in_ (id_chunk ))
20602072 .values (
@@ -2069,13 +2081,14 @@ def schedule_tis(
20692081 ),
20702082 )
20712083 .execution_options (synchronize_session = False )
2072- ).rowcount
2084+ )
2085+ count += getattr (result , "rowcount" , 0 )
20732086
20742087 # Tasks using EmptyOperator should not be executed, mark them as success
20752088 if empty_ti_ids :
20762089 dummy_ti_ids_chunks = chunks (empty_ti_ids , max_tis_per_query or len (empty_ti_ids ))
20772090 for id_chunk in dummy_ti_ids_chunks :
2078- count + = session .execute (
2091+ result = session .execute (
20792092 update (TI )
20802093 .where (TI .id .in_ (id_chunk ))
20812094 .values (
@@ -2088,7 +2101,8 @@ def schedule_tis(
20882101 .execution_options (
20892102 synchronize_session = False ,
20902103 )
2091- ).rowcount
2104+ )
2105+ count += getattr (result , "rowcount" , 0 )
20922106
20932107 return count
20942108
@@ -2180,7 +2194,7 @@ def get_or_create_dagrun(
21802194
21812195 :return: The newly created DAG run.
21822196 """
2183- dr : DagRun = session .scalar (
2197+ dr = session .scalar (
21842198 select (DagRun ).where (DagRun .dag_id == dag .dag_id , DagRun .logical_date == logical_date )
21852199 )
21862200 if dr :
0 commit comments