Skip to content

Commit 8e9f67f

Browse files
kaxilgot686-yandex
authored andcommitted
Use Protocol for OutletEventAccessor (apache#45762)
Follow-up of apache#45727 to use Protocol to allow auto-completion on IDE while not introducing runtime dep
1 parent c15bdf7 commit 8e9f67f

File tree

13 files changed

+63
-28
lines changed

13 files changed

+63
-28
lines changed

airflow/models/taskinstance.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
from airflow.models.dagrun import DagRun
164164
from airflow.models.operator import Operator
165165
from airflow.sdk.definitions.dag import DAG
166-
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
166+
from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol
167167
from airflow.timetables.base import DataInterval
168168
from airflow.typing_compat import Literal, TypeGuard
169169
from airflow.utils.task_group import TaskGroup
@@ -2730,7 +2730,7 @@ def _run_raw_task(
27302730
)
27312731

27322732
def _register_asset_changes(
2733-
self, *, events: OutletEventAccessors, session: Session | None = None
2733+
self, *, events: OutletEventAccessorsProtocol, session: Session | None = None
27342734
) -> None:
27352735
if session:
27362736
TaskInstance._register_asset_changes_int(ti=self, events=events, session=session)
@@ -2740,7 +2740,7 @@ def _register_asset_changes(
27402740
@staticmethod
27412741
@provide_session
27422742
def _register_asset_changes_int(
2743-
ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION
2743+
ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION
27442744
) -> None:
27452745
if TYPE_CHECKING:
27462746
assert ti.task

airflow/serialization/serialized_objects.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from airflow.sdk.definitions.asset import (
5757
Asset,
5858
AssetAlias,
59+
AssetAliasEvent,
5960
AssetAliasUniqueKey,
6061
AssetAll,
6162
AssetAny,
@@ -64,7 +65,7 @@
6465
BaseAsset,
6566
)
6667
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
67-
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
68+
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
6869
from airflow.serialization.dag_dependency import DagDependency
6970
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
7071
from airflow.serialization.helpers import serialize_template_field
@@ -80,7 +81,6 @@
8081
from airflow.utils.context import (
8182
ConnectionAccessor,
8283
Context,
83-
OutletEventAccessors,
8484
VariableAccessor,
8585
)
8686
from airflow.utils.db import LazySelectSequence

airflow/utils/context.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from sqlalchemy.sql.expression import Select, TextClause
6464

6565
from airflow.models.baseoperator import BaseOperator
66+
from airflow.sdk.types import OutletEventAccessorsProtocol
6667

6768
# NOTE: Please keep this in sync with the following:
6869
# * Context in task_sdk/src/airflow/sdk/definitions/context.py
@@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context:
331332
return cast(Context, new)
332333

333334

334-
def context_get_outlet_events(context: Context) -> OutletEventAccessors:
335+
def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
335336
try:
336337
return context["outlet_events"]
337338
except KeyError:

airflow/utils/operator_helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from airflow.utils.types import NOTSET
3030

3131
if TYPE_CHECKING:
32-
from airflow.utils.context import OutletEventAccessors
32+
from airflow.sdk.types import OutletEventAccessorsProtocol
3333

3434
P = ParamSpec("P")
3535
R = TypeVar("R")
@@ -230,7 +230,7 @@ def run(*args, **kwargs): ...
230230

231231
def ExecutionCallableRunner(
232232
func: Callable[P, R],
233-
outlet_events: OutletEventAccessors,
233+
outlet_events: OutletEventAccessorsProtocol,
234234
*,
235235
logger: logging.Logger,
236236
) -> _ExecutionCallableRunner:

providers/edge/src/airflow/providers/edge/example_dags/win_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
if TYPE_CHECKING:
4848
try:
49-
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol as TaskInstance
49+
from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
5050
except ImportError:
5151
from airflow.models import TaskInstance # type: ignore[assignment]
5252
from airflow.utils.context import Context

providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
if TYPE_CHECKING:
3333
try:
34-
from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
34+
from airflow.sdk.types import RuntimeTaskInstanceProtocol
3535
except ImportError:
3636
from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol # type: ignore[assignment]
3737
from airflow.utils.context import Context

task_sdk/src/airflow/sdk/definitions/asset/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -660,3 +660,12 @@ def as_expression(self) -> Any:
660660
:meta private:
661661
"""
662662
return {"all": [o.as_expression() for o in self.objects]}
663+
664+
665+
@attrs.define
666+
class AssetAliasEvent:
667+
"""Representation of asset event to be triggered by an asset alias."""
668+
669+
source_alias_name: str
670+
dest_asset_key: AssetUniqueKey
671+
extra: dict[str, Any]

task_sdk/src/airflow/sdk/definitions/context.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from airflow.models.operator import Operator
2828
from airflow.sdk.definitions.baseoperator import BaseOperator
2929
from airflow.sdk.definitions.dag import DAG
30-
from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol
30+
from airflow.sdk.types import (
31+
DagRunProtocol,
32+
OutletEventAccessorsProtocol,
33+
RuntimeTaskInstanceProtocol,
34+
)
3135

3236

3337
class Context(TypedDict, total=False):
@@ -38,8 +42,7 @@ class Context(TypedDict, total=False):
3842
dag_run: DagRunProtocol
3943
data_interval_end: datetime | None
4044
data_interval_start: datetime | None
41-
# outlet_events: OutletEventAccessors
42-
outlet_events: Any
45+
outlet_events: OutletEventAccessorsProtocol
4346
ds: str
4447
ds_nodash: str
4548
expanded_ti_count: int | None

task_sdk/src/airflow/sdk/execution_time/context.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from airflow.sdk.definitions.asset import (
2929
Asset,
3030
AssetAlias,
31+
AssetAliasEvent,
3132
AssetAliasUniqueKey,
3233
AssetNameRef,
3334
AssetRef,
@@ -174,15 +175,6 @@ def __eq__(self, other: object) -> bool:
174175
return True
175176

176177

177-
@attrs.define
178-
class AssetAliasEvent:
179-
"""Representation of asset event to be triggered by an asset alias."""
180-
181-
source_alias_name: str
182-
dest_asset_key: AssetUniqueKey
183-
extra: dict[str, Any]
184-
185-
186178
@attrs.define
187179
class OutletEventAccessor:
188180
"""Wrapper to access an outlet asset event in template."""

task_sdk/src/airflow/sdk/execution_time/task_runner.py

-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def get_template_context(self) -> Context:
137137
}
138138
context.update(context_from_server)
139139

140-
# TODO: We should use/move TypeDict from airflow.utils.context.Context
141140
return context
142141

143142
def render_templates(

task_sdk/src/airflow/sdk/definitions/protocols.py task_sdk/src/airflow/sdk/types.py

+27
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from typing import TYPE_CHECKING, Any, Protocol
2121

2222
if TYPE_CHECKING:
23+
from collections.abc import Iterator
2324
from datetime import datetime
2425

26+
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey
2527
from airflow.sdk.definitions.baseoperator import BaseOperator
2628

2729

@@ -65,3 +67,28 @@ def xcom_pull(
6567
) -> Any: ...
6668

6769
def xcom_push(self, key: str, value: Any) -> None: ...
70+
71+
72+
class OutletEventAccessorProtocol(Protocol):
73+
"""Protocol for managing access to a specific outlet event accessor."""
74+
75+
key: BaseAssetUniqueKey
76+
extra: dict[str, Any]
77+
asset_alias_events: list[AssetAliasEvent]
78+
79+
def __init__(
80+
self,
81+
*,
82+
key: BaseAssetUniqueKey,
83+
extra: dict[str, Any],
84+
asset_alias_events: list[AssetAliasEvent],
85+
) -> None: ...
86+
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
87+
88+
89+
class OutletEventAccessorsProtocol(Protocol):
90+
"""Protocol for managing access to outlet event accessors."""
91+
92+
def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
93+
def __len__(self) -> int: ...
94+
def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ...

task_sdk/tests/execution_time/test_context.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,18 @@
2222
import pytest
2323

2424
from airflow.sdk import get_current_context
25-
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
25+
from airflow.sdk.definitions.asset import (
26+
Asset,
27+
AssetAlias,
28+
AssetAliasEvent,
29+
AssetAliasUniqueKey,
30+
AssetUniqueKey,
31+
)
2632
from airflow.sdk.definitions.connection import Connection
2733
from airflow.sdk.definitions.variable import Variable
2834
from airflow.sdk.exceptions import ErrorType
2935
from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult
3036
from airflow.sdk.execution_time.context import (
31-
AssetAliasEvent,
3237
ConnectionAccessor,
3338
OutletEventAccessor,
3439
OutletEventAccessors,

tests/serialization/test_serialized_objects.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@
4242
from airflow.models.xcom_arg import XComArg
4343
from airflow.operators.empty import EmptyOperator
4444
from airflow.providers.standard.operators.python import PythonOperator
45-
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
46-
from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
45+
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
46+
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
4747
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
4848
from airflow.serialization.serialized_objects import BaseSerialization
4949
from airflow.triggers.base import BaseTrigger
5050
from airflow.utils import timezone
51-
from airflow.utils.context import OutletEventAccessors
5251
from airflow.utils.db import LazySelectSequence
5352
from airflow.utils.operator_resources import Resources
5453
from airflow.utils.state import DagRunState, State

0 commit comments

Comments
 (0)