Refactor ORM DAG insertion logic #42358

Merged · 2 commits · Sep 23, 2024
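This change replaces the module-level collect_orm_dags / create_orm_dag / update_orm_dags helpers with two private functions (_find_orm_dags, _create_orm_dags) and a DagModelOperation named tuple that drives them. A minimal sketch of the resulting call flow, pieced together from the dag.py hunk further below; the surrounding session handling is illustrative and not part of the diff:

    from airflow.dag_processing.collection import DagModelOperation

    # `dags` is an iterable of parsed DAG objects, `session` an open ORM session.
    dag_op = DagModelOperation({dag.dag_id: dag for dag in dags})
    orm_dags = dag_op.add_dags(session=session)  # find existing DagModel rows, create missing ones
    dag_op.update_dags(orm_dags, processor_subdir=processor_subdir, session=session)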
200 changes: 104 additions & 96 deletions airflow/dag_processing/collection.py
@@ -61,46 +61,28 @@
log = logging.getLogger(__name__)


def collect_orm_dags(dags: dict[str, DAG], *, session: Session) -> dict[str, DagModel]:
"""
Collect DagModel objects from DAG objects.

An existing DagModel is fetched if there's a matching ID in the database.
Otherwise, a new DagModel is created and added to the session.
"""
def _find_orm_dags(dag_ids: Iterable[str], *, session: Session) -> dict[str, DagModel]:
"""Find existing DagModel objects from DAG objects."""
stmt = (
select(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
.where(DagModel.dag_id.in_(dags))
.where(DagModel.dag_id.in_(dag_ids))
.options(joinedload(DagModel.schedule_dataset_references))
.options(joinedload(DagModel.schedule_dataset_alias_references))
.options(joinedload(DagModel.task_outlet_dataset_references))
)
stmt = with_row_locks(stmt, of=DagModel, session=session)
existing_orm_dags = {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}


for dag_id, dag in dags.items():
if dag_id in existing_orm_dags:
continue
orm_dag = DagModel(dag_id=dag_id)
def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
for dag in dags:
orm_dag = DagModel(dag_id=dag.dag_id)
if dag.is_paused_upon_creation is not None:
orm_dag.is_paused = dag.is_paused_upon_creation
orm_dag.tags = []
log.info("Creating ORM DAG for %s", dag_id)
log.info("Creating ORM DAG for %s", dag.dag_id)
session.add(orm_dag)
existing_orm_dags[dag_id] = orm_dag

return existing_orm_dags


def create_orm_dag(dag: DAG, session: Session) -> DagModel:
orm_dag = DagModel(dag_id=dag.dag_id)
if dag.is_paused_upon_creation is not None:
orm_dag.is_paused = dag.is_paused_upon_creation
orm_dag.tags = []
log.info("Creating ORM DAG for %s", dag.dag_id)
session.add(orm_dag)
return orm_dag
yield orm_dag


def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select:
@@ -158,75 +140,101 @@ def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self:
)


def update_orm_dags(
source_dags: dict[str, DAG],
target_dags: dict[str, DagModel],
*,
processor_subdir: str | None = None,
session: Session,
) -> None:
"""
Apply DAG attributes to DagModel objects.

Objects in ``target_dags`` are modified in-place.
"""
run_info = _RunInfo.calculate(source_dags, session=session)

for dag_id, dm in sorted(target_dags.items()):
dag = source_dags[dag_id]
dm.fileloc = dag.fileloc
dm.owners = dag.owner
dm.is_active = True
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
dm._dag_display_property_value = dag._dag_display_property_value
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None for t in dag.tasks
)
dm.timetable_summary = dag.timetable.summary
dm.timetable_description = dag.timetable.description
dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
dm.processor_subdir = processor_subdir

last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
if last_automated_run is None:
last_automated_data_interval = None
else:
last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
dm.next_dagrun_create_after = None
def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
orm_tags = {t.name: t for t in dm.tags}
for name, orm_tag in orm_tags.items():
if name not in tag_names:
session.delete(orm_tag)
dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tag_names.difference(orm_tags))


def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, session: Session) -> None:
orm_dag_owner_attributes = {obj.owner: obj for obj in dm.dag_owner_links}
for owner, obj in orm_dag_owner_attributes.items():
try:
link = dag_owner_links[owner]
except KeyError:
session.delete(obj)
else:
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)

if not dag.timetable.dataset_condition:
dm.schedule_dataset_references = []
dm.schedule_dataset_alias_references = []
# FIXME: STORE NEW REFERENCES.

dag_tags = set(dag.tags or ())
for orm_tag in (dm_tags := list(dm.tags or [])):
if orm_tag.name not in dag_tags:
session.delete(orm_tag)
dm.tags.remove(orm_tag)
orm_tag_names = {t.name for t in dm_tags}
for dag_tag in dag_tags:
if dag_tag not in orm_tag_names:
dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
dm.tags.append(dag_tag_orm)
session.add(dag_tag_orm)

dm_links = dm.dag_owner_links or []
for dm_link in dm_links:
if dm_link not in dag.owner_links:
session.delete(dm_link)
for owner_name, owner_link in dag.owner_links.items():
dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
session.add(dag_owner_orm)
if obj.link != link:
obj.link = link
dm.dag_owner_links.extend(
DagOwnerAttributes(dag_id=dm.dag_id, owner=owner, link=link)
for owner, link in dag_owner_links.items()
if owner not in orm_dag_owner_attributes
)


class DagModelOperation(NamedTuple):
"""Collect DAG objects and perform database operations for them."""

dags: dict[str, DAG]

def add_dags(self, *, session: Session) -> dict[str, DagModel]:
orm_dags = _find_orm_dags(self.dags, session=session)
orm_dags.update(
(model.dag_id, model)
for model in _create_orm_dags(
(dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags),
session=session,
)
)
return orm_dags

def update_dags(
self,
orm_dags: dict[str, DagModel],
*,
processor_subdir: str | None = None,
session: Session,
) -> None:
run_info = _RunInfo.calculate(self.dags, session=session)

for dag_id, dm in sorted(orm_dags.items()):
dag = self.dags[dag_id]
dm.fileloc = dag.fileloc
dm.owners = dag.owner
dm.is_active = True
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
dm._dag_display_property_value = dag._dag_display_property_value
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)
dm.timetable_summary = dag.timetable.summary
dm.timetable_description = dag.timetable.description
dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
dm.processor_subdir = processor_subdir

last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
if last_automated_run is None:
last_automated_data_interval = None
else:
last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
dm.next_dagrun_create_after = None
else:
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)

if not dag.timetable.dataset_condition:
dm.schedule_dataset_references = []
dm.schedule_dataset_alias_references = []
# FIXME: STORE NEW REFERENCES.

if dag.tags:
_update_dag_tags(set(dag.tags), dm, session=session)
else: # Optimization: no references at all, just clear everything.
dm.tags = []
if dag.owner_links:
_update_dag_owner_links(dag.owner_links, dm, session=session)
else: # Optimization: no references at all, just clear everything.
dm.dag_owner_links = []


def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]:
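For reference, a tiny worked example of the reconciliation the new _update_dag_tags helper performs, using plain sets in place of the ORM rows (the tag names here are made up for illustration):

    existing = {"legacy", "etl"}       # tag names currently stored on the DagModel
    desired = {"etl", "reporting"}     # set(dag.tags) from the parsed DAG

    to_delete = existing - desired     # {"legacy"}: these DagTag rows get session.delete()
    to_create = desired - existing     # {"reporting"}: new DagTag(name=..., dag_id=...) rows are appended

_update_dag_owner_links follows the same pattern, with the extra step of updating the stored link on rows whose owner already exists.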
24 changes: 6 additions & 18 deletions airflow/models/dag.py
@@ -2643,28 +2643,16 @@ def bulk_write_to_db(
if not dags:
return

from airflow.dag_processing.collection import (
DatasetModelOperation,
collect_orm_dags,
create_orm_dag,
update_orm_dags,
)
from airflow.dag_processing.collection import DagModelOperation, DatasetModelOperation

log.info("Sync %s DAGs", len(dags))
dags_by_ids = {dag.dag_id: dag for dag in dags}
del dags

orm_dags = collect_orm_dags(dags_by_ids, session=session)
orm_dags.update(
(dag_id, create_orm_dag(dag, session=session))
for dag_id, dag in dags_by_ids.items()
if dag_id not in orm_dags
)
dag_op = DagModelOperation({dag.dag_id: dag for dag in dags})

update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session)
DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session)
orm_dags = dag_op.add_dags(session=session)
dag_op.update_dags(orm_dags, processor_subdir=processor_subdir, session=session)
DagCode.bulk_sync_to_db((dag.fileloc for dag in dags), session=session)

dataset_op = DatasetModelOperation.collect(dags_by_ids)
dataset_op = DatasetModelOperation.collect(dag_op.dags)

orm_datasets = dataset_op.add_datasets(session=session)
orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session)
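For completeness, a rough usage sketch of the entry point touched here; it assumes an iterable of parsed DAG objects, and the create_session wrapper is only for illustration:

    from airflow.models.dag import DAG
    from airflow.utils.session import create_session

    with create_session() as session:
        # Persists DagModel rows plus tag, owner-link, and dataset references in bulk.
        DAG.bulk_write_to_db(dags, processor_subdir=None, session=session)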