Try to make dataset objects totally unhashable, redux (apache#42066)

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
topherinternational and uranusjr authored Sep 7, 2024
1 parent af753c6 commit 025ce81
Showing 6 changed files with 94 additions and 39 deletions.
34 changes: 15 additions & 19 deletions airflow/datasets/__init__.py
@@ -206,20 +206,12 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
         raise NotImplementedError
 
 
-@attr.define()
+@attr.define(unsafe_hash=False)
 class DatasetAlias(BaseDataset):
     """A representation of a dataset alias, used to create datasets at runtime."""
 
     name: str
 
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, DatasetAlias):
-            return self.name == other.name
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return hash(self.name)
-
     def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
         """
         Iterate a dataset alias as dag dependency.
@@ -241,29 +233,33 @@ class DatasetAliasEvent(TypedDict):
     dest_dataset_uri: str
 
 
-@attr.define()
+def _set_extra_default(extra: dict | None) -> dict:
+    """
+    Automatically convert None to an empty dict.
+
+    This allows call sites to continue doing ``Dataset(uri, extra=None)``,
+    while still ensuring the ``extra`` attribute is always a dict.
+    """
+    if extra is None:
+        return {}
+    return extra
+
+
+@attr.define(unsafe_hash=False)
 class Dataset(os.PathLike, BaseDataset):
     """A representation of data dependencies between workflows."""
 
     uri: str = attr.field(
         converter=_sanitize_uri,
         validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
     )
-    extra: dict[str, Any] | None = None
+    extra: dict[str, Any] = attr.field(factory=dict, converter=_set_extra_default)
 
     __version__: ClassVar[int] = 1
 
     def __fspath__(self) -> str:
         return self.uri
 
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, self.__class__):
-            return self.uri == other.uri
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return hash(self.uri)
-
     @property
     def normalized_uri(self) -> str | None:
         """
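A note on the mechanics: attrs generates ``__eq__`` for these classes, and with ``unsafe_hash=False`` no ``__hash__`` is generated alongside it, so instances end up unhashable (Python treats a class with ``__eq__`` but no ``__hash__`` as unhashable). A minimal sketch of the resulting behavior, assuming a build with this commit applied (the URI is illustrative):

    from airflow.datasets import Dataset

    d = Dataset("s3://bucket/key", extra=None)
    assert d.extra == {}  # _set_extra_default converts None to {}

    try:
        hash(d)  # attrs leaves __hash__ as None, so this raises
    except TypeError:
        print("Dataset is unhashable")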
4 changes: 2 additions & 2 deletions airflow/models/dag.py
@@ -2786,12 +2786,12 @@ def bulk_write_to_db(
         curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
         for task in dag.tasks:
             dataset_outlets: list[Dataset] = []
-            dataset_alias_outlets: set[DatasetAlias] = set()
+            dataset_alias_outlets: list[DatasetAlias] = []
             for outlet in task.outlets:
                 if isinstance(outlet, Dataset):
                     dataset_outlets.append(outlet)
                 elif isinstance(outlet, DatasetAlias):
-                    dataset_alias_outlets.add(outlet)
+                    dataset_alias_outlets.append(outlet)
 
             if not dataset_outlets:
                 if curr_outlet_references:
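Since ``DatasetAlias`` no longer hashes, collecting outlets into a ``set`` would now raise ``TypeError``, so a plain list is used instead. If set-like deduplication were ever needed, one option is to key on the alias name, which is still a hashable string. An illustrative sketch, not code from this commit:

    from airflow.datasets import DatasetAlias

    aliases = [DatasetAlias("a"), DatasetAlias("b"), DatasetAlias("a")]
    # Deduplicate via a dict keyed by the hashable name attribute:
    unique = list({alias.name: alias for alias in aliases}.values())
    assert [a.name for a in unique] == ["a", "b"]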
4 changes: 4 additions & 0 deletions newsfragments/42054.significant.rst
@@ -0,0 +1,4 @@
+Dataset and DatasetAlias are no longer hashable
+
+This means they can no longer be used as dict keys or put into a set. Dataset's
+equality logic is also tweaked slightly to consider the extra dict.
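A short sketch of the user-facing change described in this newsfragment, assuming a build with this commit applied (URIs and extras are illustrative):

    from airflow.datasets import Dataset

    d1 = Dataset("s3://bucket/key", extra={"team": "a"})
    d2 = Dataset("s3://bucket/key", extra={"team": "b"})

    # Equality now compares every field, so differing extras compare unequal:
    assert d1 != d2

    # Hashing raises, so neither dict keys nor set membership work anymore:
    try:
        {d1: "value"}
    except TypeError as err:
        print(err)  # unhashable type: 'Dataset'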
14 changes: 4 additions & 10 deletions tests/datasets/test_dataset.py
@@ -120,12 +120,6 @@ def test_not_equal_when_different_uri():
     assert dataset1 != dataset2
 
 
-def test_hash():
-    uri = "s3://example/dataset"
-    dataset = Dataset(uri=uri)
-    hash(dataset)
-
-
 def test_dataset_logic_operations():
     result_or = dataset1 | dataset2
     assert isinstance(result_or, DatasetAny)
@@ -187,10 +181,10 @@ def test_datasetbooleancondition_evaluate_iter():
     assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False
 
     # Testing iter_datasets indirectly through the subclasses
-    datasets_any = set(any_condition.iter_datasets())
-    datasets_all = set(all_condition.iter_datasets())
-    assert datasets_any == {("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)}
-    assert datasets_all == {("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)}
+    datasets_any = dict(any_condition.iter_datasets())
+    datasets_all = dict(all_condition.iter_datasets())
+    assert datasets_any == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2}
+    assert datasets_all == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2}
 
 
 @pytest.mark.parametrize(
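The switch from ``set`` to ``dict`` here follows from the same change: a tuple hashes its members, so ``(uri, dataset)`` pairs containing an unhashable ``Dataset`` cannot go into a set, while a dict only needs its string keys to hash. A small illustrative sketch (bucket names mirror the test data):

    from airflow.datasets import Dataset

    dataset1 = Dataset("s3://bucket1/data1")
    dataset2 = Dataset("s3://bucket2/data2")
    pairs = [("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)]

    try:
        set(pairs)  # hashing a tuple hashes the Dataset inside it
    except TypeError:
        pass

    by_uri = dict(pairs)  # fine: only the str keys must be hashable
    assert by_uri["s3://bucket1/data1"] is dataset1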
75 changes: 68 additions & 7 deletions tests/lineage/test_hook.py
@@ -98,16 +98,77 @@ def test_create_dataset(self, mock_providers_manager):
         def create_dataset(arg1, arg2="default", extra=None):
             return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra)
 
-        mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
+        test_scheme = "myscheme"
+        mock_providers_manager.return_value.dataset_factories = {test_scheme: create_dataset}
+
+        test_uri = "urischeme://value_a/value_b"
+        test_kwargs = {"arg1": "value_1"}
+        test_kwargs_uri = "myscheme://value_1/default"
+        test_extra = {"key": "value"}
+
+        # test uri arg - should take precedence over the keyword args + scheme
+        assert self.collector.create_dataset(
+            scheme=test_scheme, uri=test_uri, dataset_kwargs=test_kwargs, dataset_extra=None
+        ) == Dataset(test_uri)
+        assert self.collector.create_dataset(
+            scheme=test_scheme, uri=test_uri, dataset_kwargs=test_kwargs, dataset_extra={}
+        ) == Dataset(test_uri)
+        assert self.collector.create_dataset(
+            scheme=test_scheme, uri=test_uri, dataset_kwargs=test_kwargs, dataset_extra=test_extra
+        ) == Dataset(test_uri, extra=test_extra)
+
+        # test keyword args
+        assert self.collector.create_dataset(
+            scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None
+        ) == Dataset(test_kwargs_uri)
         assert self.collector.create_dataset(
-            scheme="myscheme", uri=None, dataset_kwargs={"arg1": "value_1"}, dataset_extra=None
-        ) == Dataset("myscheme://value_1/default")
+            scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra={}
+        ) == Dataset(test_kwargs_uri)
         assert self.collector.create_dataset(
-            scheme="myscheme",
+            scheme=test_scheme,
             uri=None,
-            dataset_kwargs={"arg1": "value_1", "arg2": "value_2"},
-            dataset_extra={"key": "value"},
-        ) == Dataset("myscheme://value_1/value_2", extra={"key": "value"})
+            dataset_kwargs={**test_kwargs, "arg2": "value_2"},
+            dataset_extra=test_extra,
+        ) == Dataset("myscheme://value_1/value_2", extra=test_extra)
+
+        # missing both uri and scheme
+        assert (
+            self.collector.create_dataset(
+                scheme=None, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None
+            )
+            is None
+        )
+
+    @patch("airflow.lineage.hook.ProvidersManager")
+    def test_create_dataset_no_factory(self, mock_providers_manager):
+        test_scheme = "myscheme"
+        mock_providers_manager.return_value.dataset_factories = {}
+
+        test_kwargs = {"arg1": "value_1"}
+
+        assert (
+            self.collector.create_dataset(
+                scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None
+            )
+            is None
+        )
+
+    @patch("airflow.lineage.hook.ProvidersManager")
+    def test_create_dataset_factory_exception(self, mock_providers_manager):
+        def create_dataset(extra=None, **kwargs):
+            raise RuntimeError("Factory error")
+
+        test_scheme = "myscheme"
+        mock_providers_manager.return_value.dataset_factories = {test_scheme: create_dataset}
+
+        test_kwargs = {"arg1": "value_1"}
+
+        assert (
+            self.collector.create_dataset(
+                scheme=test_scheme, uri=None, dataset_kwargs=test_kwargs, dataset_extra=None
+            )
+            is None
+        )
 
     def test_collected_datasets(self):
         context_input = MagicMock()
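Taken together, these tests pin down the fallback behavior of ``create_dataset``: an explicit URI wins over scheme plus keyword args, and a missing scheme, an unregistered scheme, or a factory that raises all yield ``None`` instead of an error. The following is only a hypothetical sketch inferred from the tests; the real implementation lives in airflow/lineage/hook.py and may differ:

    from airflow.datasets import Dataset
    from airflow.providers_manager import ProvidersManager

    def create_dataset(scheme, uri, dataset_kwargs, dataset_extra):
        if uri is not None:
            return Dataset(uri=uri, extra=dataset_extra)  # explicit URI wins
        if scheme is None:
            return None  # nothing to dispatch on
        factory = ProvidersManager().dataset_factories.get(scheme)
        if factory is None:
            return None  # no factory registered for this scheme
        try:
            return factory(**(dataset_kwargs or {}), extra=dataset_extra)
        except Exception:
            return None  # a broken factory degrades to "no lineage", not an error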
2 changes: 1 addition & 1 deletion tests/timetables/test_datasets_timetable.py
@@ -134,7 +134,7 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
         "timetable": "mock_serialized_timetable",
         "dataset_condition": {
             "__type": "dataset_all",
-            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": None}],
+            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": {}}],
         },
     }