From 35de0de5dc4ae25e3feba3225159f562f643e477 Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Wed, 9 Jul 2025 19:20:25 +0200 Subject: [PATCH] refactor!: Rename `primary_keys` to `primary_key` --- dataframely/_base_collection.py | 16 ++++++------- dataframely/_base_schema.py | 12 +++++----- dataframely/_deprecation.py | 4 ++-- dataframely/collection.py | 30 ++++++++++++------------ dataframely/columns/_base.py | 4 ++-- tests/collection/test_base.py | 4 ++-- tests/collection/test_filter_validate.py | 20 ++++++++-------- tests/schema/test_base.py | 4 ++-- tests/test_deprecation.py | 4 +++- 9 files changed, 50 insertions(+), 48 deletions(-) diff --git a/dataframely/_base_collection.py b/dataframely/_base_collection.py index f0e447a..1a661e4 100644 --- a/dataframely/_base_collection.py +++ b/dataframely/_base_collection.py @@ -60,8 +60,8 @@ def my_filter(self) -> pl.DataFrame: # --------------------------------------- UTILS -------------------------------------- # -def _common_primary_keys(columns: Iterable[type[Schema]]) -> set[str]: - return set.intersection(*[set(schema.primary_keys()) for schema in columns]) +def _common_primary_key(columns: Iterable[type[Schema]]) -> set[str]: + return set.intersection(*[set(schema.primary_key()) for schema in columns]) # ------------------------------------------------------------------------------------ # @@ -117,7 +117,7 @@ def __new__( # 1) Check that there are overlapping primary keys that allow the application # of filters. if len(non_ignored_member_schemas) > 0 and len(result.filters) > 0: - if len(_common_primary_keys(non_ignored_member_schemas)) == 0: + if len(_common_primary_key(non_ignored_member_schemas)) == 0: raise ImplementationError( "Members of a collection must have an overlapping primary key " "but did not find any." @@ -145,11 +145,11 @@ def __new__( # 3) Check that inlining for sampling is configured correctly. if len(non_ignored_member_schemas) > 0: - common_primary_keys = _common_primary_keys(non_ignored_member_schemas) + common_primary_key = _common_primary_key(non_ignored_member_schemas) inlined_columns: set[str] = set() for member, info in result.members.items(): if info.inline_for_sampling: - if set(info.schema.primary_keys()) != common_primary_keys: + if set(info.schema.primary_key()) != common_primary_key: raise ImplementationError( f"Member '{member}' is inlined for sampling but its primary " "key is a superset of the common primary key. Such a member " @@ -157,7 +157,7 @@ def __new__( "for a single combination of the common primary key." ) non_primary_key_columns = ( - set(info.schema.column_names()) - common_primary_keys + set(info.schema.column_names()) - common_primary_key ) if len(inlined_columns & non_primary_key_columns): raise ImplementationError( @@ -317,10 +317,10 @@ def non_ignored_members(cls) -> set[str]: } @classmethod - def common_primary_keys(cls) -> list[str]: + def common_primary_key(cls) -> list[str]: """The primary keys shared by non ignored members of the collection.""" return sorted( - _common_primary_keys( + _common_primary_key( [ member.schema for member in cls.members().values() diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index 0ecc546..0866a31 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -28,9 +28,9 @@ def _build_rules( rules: dict[str, Rule] = copy(custom) # Add primary key validation to the list of rules if applicable - primary_keys = _primary_keys(columns) - if len(primary_keys) > 0: - rules["primary_key"] = Rule(~pl.struct(primary_keys).is_duplicated()) + primary_key = _primary_key(columns) + if len(primary_key) > 0: + rules["primary_key"] = Rule(~pl.struct(primary_key).is_duplicated()) # Add column-specific rules column_rules = { @@ -43,7 +43,7 @@ def _build_rules( return rules -def _primary_keys(columns: dict[str, Column]) -> list[str]: +def _primary_key(columns: dict[str, Column]) -> list[str]: return list(k for k, col in columns.items() if col.primary_key) @@ -178,9 +178,9 @@ def columns(cls) -> dict[str, Column]: return columns @classmethod - def primary_keys(cls) -> list[str]: + def primary_key(cls) -> list[str]: """The primary key columns in this schema (possibly empty).""" - return _primary_keys(cls.columns()) + return _primary_key(cls.columns()) @classmethod def _validation_rules(cls) -> dict[str, Rule]: diff --git a/dataframely/_deprecation.py b/dataframely/_deprecation.py index 03b390f..42cf951 100644 --- a/dataframely/_deprecation.py +++ b/dataframely/_deprecation.py @@ -40,9 +40,9 @@ def warn_nullable_default_change() -> None: @skip_if(env="DATAFRAMELY_NO_FUTURE_WARNINGS") -def warn_no_nullable_primary_keys() -> None: +def warn_no_nullable_primary_key() -> None: warnings.warn( - "Nullable primary keys are not supported. " + "Nullable primary key columns are not supported. " "Setting `nullable=True` on a primary key column is ignored " "and will raise an error in a future release.", FutureWarning, diff --git a/dataframely/collection.py b/dataframely/collection.py index 2c3140a..8512409 100644 --- a/dataframely/collection.py +++ b/dataframely/collection.py @@ -161,8 +161,8 @@ def sample( g = generator or Generator() - primary_keys = cls.common_primary_keys() - requires_dependent_sampling = len(cls.members()) > 1 and len(primary_keys) > 0 + primary_key = cls.common_primary_key() + requires_dependent_sampling = len(cls.members()) > 1 and len(primary_key) > 0 # 1) Preprocess all samples to make sampling efficient and ensure shared primary # keys. @@ -180,7 +180,7 @@ def sample( # can properly sample members. if requires_dependent_sampling: if not all( - all(k in sample for k in primary_keys) for sample in processed_samples + all(k in sample for k in primary_key) for sample in processed_samples ): raise ValueError("All samples must contain the common primary keys.") @@ -192,7 +192,7 @@ def sample( for member, schema in cls.member_schemas().items(): if ( not requires_dependent_sampling - or set(schema.primary_keys()) == set(primary_keys) + or set(schema.primary_key()) == set(primary_key) or member_infos[member].ignored_in_filters ): # If the primary keys are equal to the shared ones, each sample @@ -206,7 +206,7 @@ def sample( **( {} if member_infos[member].ignored_in_filters - else _extract_keys_if_exist(sample, primary_keys) + else _extract_keys_if_exist(sample, primary_key) ), **_extract_keys_if_exist( ( @@ -224,7 +224,7 @@ def sample( # observe values for the member member_overrides = [ { - **_extract_keys_if_exist(sample, primary_keys), + **_extract_keys_if_exist(sample, primary_key), **_extract_keys_if_exist(item, schema.column_names()), } for sample in processed_samples @@ -324,7 +324,7 @@ def _preprocess_sample( has common primary keys, this sample **must** include **all** common primary keys. """ - if len(cls.members()) > 1 and len(cls.common_primary_keys()) > 0: + if len(cls.members()) > 1 and len(cls.common_primary_key()) > 0: raise ValueError( "`_preprocess_sample` must be overwritten for collections with more " "than 1 member sharing a common primary key." @@ -448,16 +448,16 @@ def filter( filters = cls._filters() if len(filters) > 0: result_cls = cls._init(results) - primary_keys = cls.common_primary_keys() + primary_key = cls.common_primary_key() keep: dict[str, pl.DataFrame] = {} for name, filter in filters.items(): - keep[name] = filter.logic(result_cls).select(primary_keys).collect() + keep[name] = filter.logic(result_cls).select(primary_key).collect() # Using the filter results, we can define a joint data frame that we use to filter # the input. all_keep = join_all_inner( - [df.lazy() for df in keep.values()], on=primary_keys + [df.lazy() for df in keep.values()], on=primary_key ).collect() # Now we can iterate over the results where we do the following: @@ -467,14 +467,14 @@ def filter( for member_name, filtered in results.items(): if cls.members()[member_name].ignored_in_filters: continue - results[member_name] = filtered.join(all_keep, on=primary_keys) + results[member_name] = filtered.join(all_keep, on=primary_key) new_failure_names = list(filters.keys()) new_failure_pks = [ - filtered.select(primary_keys) + filtered.select(primary_key) .lazy() .unique() - .join(filter_keep.lazy(), on=primary_keys, how="anti") + .join(filter_keep.lazy(), on=primary_key, how="anti") .with_columns(pl.lit(False).alias(name)) for name, filter_keep in keep.items() ] @@ -482,7 +482,7 @@ def filter( # filtered out by all filters. In this case, we want to assign a validation # value of `True`. all_new_failure_pks = join_all_outer( - new_failure_pks, on=primary_keys + new_failure_pks, on=primary_key ).with_columns(pl.col(new_failure_names).fill_null(True)) # At this point, we have a data frame with the primary keys of the *excluded* @@ -500,7 +500,7 @@ def filter( lf=pl.concat( [ failure._lf, - filtered.lazy().join(all_new_failure_pks, on=primary_keys), + filtered.lazy().join(all_new_failure_pks, on=primary_key), ], how="diagonal", ), diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index ab28a46..f178b7c 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -13,7 +13,7 @@ from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._deprecation import ( - warn_no_nullable_primary_keys, + warn_no_nullable_primary_key, warn_nullable_default_change, ) from dataframely._polars import PolarsDataType @@ -73,7 +73,7 @@ def __init__( """ if nullable and primary_key: - warn_no_nullable_primary_keys() + warn_no_nullable_primary_key() if nullable is None: warn_nullable_default_change() diff --git a/tests/collection/test_base.py b/tests/collection/test_base.py index 3039b43..82d4cd1 100644 --- a/tests/collection/test_base.py +++ b/tests/collection/test_base.py @@ -21,8 +21,8 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[MySecondSchema] | None -def test_common_primary_keys() -> None: - assert MyCollection.common_primary_keys() == ["a"] +def test_common_primary_key() -> None: + assert MyCollection.common_primary_key() == ["a"] def test_members() -> None: diff --git a/tests/collection/test_filter_validate.py b/tests/collection/test_filter_validate.py index 4c6cb15..b4d0269 100644 --- a/tests/collection/test_filter_validate.py +++ b/tests/collection/test_filter_validate.py @@ -29,13 +29,13 @@ class MyCollection(dy.Collection): second: dy.LazyFrame[MySecondSchema] @dy.filter() - def equal_primary_keys(self) -> pl.LazyFrame: - return self.first.join(self.second, on=self.common_primary_keys()) + def equal_primary_key(self) -> pl.LazyFrame: + return self.first.join(self.second, on=self.common_primary_key()) @dy.filter() def first_b_greater_second_b(self) -> pl.LazyFrame: return self.first.join( - self.second, on=self.common_primary_keys(), how="full", coalesce=True + self.second, on=self.common_primary_key(), how="full", coalesce=True ).filter((pl.col("b") > pl.col("b_right")).fill_null(True)) @@ -128,11 +128,11 @@ def test_filter_with_filter_without_rule_violation( assert_frame_equal(out.first, pl.LazyFrame({"a": [3], "b": [3]})) assert_frame_equal(out.second, pl.LazyFrame({"a": [3], "b": [2]})) assert failure["first"].counts() == { - "equal_primary_keys": 1, + "equal_primary_key": 1, "first_b_greater_second_b": 1, } assert failure["second"].counts() == { - "equal_primary_keys": 2, + "equal_primary_key": 2, "first_b_greater_second_b": 1, } @@ -150,8 +150,8 @@ def test_filter_with_filter_with_rule_violation( assert isinstance(out, MyCollection) assert_frame_equal(out.first, pl.LazyFrame({"a": [3], "b": [3]})) assert_frame_equal(out.second, pl.LazyFrame({"a": [3], "b": [1]})) - assert failure["first"].counts() == {"equal_primary_keys": 2} - assert failure["second"].counts() == {"b|min": 1, "equal_primary_keys": 2} + assert failure["first"].counts() == {"equal_primary_key": 2} + assert failure["second"].counts() == {"b|min": 1, "equal_primary_key": 2} # -------------------------------- VALIDATE WITH DATA -------------------------------- # @@ -207,10 +207,10 @@ def test_validate_with_filter_without_rule_violation( MyCollection.validate(data) exc.match(r"Member 'first' failed validation") - exc.match(r"'equal_primary_keys' failed validation for 1 rows") + exc.match(r"'equal_primary_key' failed validation for 1 rows") exc.match(r"'first_b_greater_second_b' failed validation for 1 rows") exc.match(r"Member 'second' failed validation") - exc.match(r"'equal_primary_keys' failed validation for 2 rows") + exc.match(r"'equal_primary_key' failed validation for 2 rows") def test_validate_with_filter_with_rule_violation( @@ -228,6 +228,6 @@ def test_validate_with_filter_with_rule_violation( MyCollection.validate(data) exc.match(r"Member 'first' failed validation") - exc.match(r"'equal_primary_keys' failed validation for 2 rows") + exc.match(r"'equal_primary_key' failed validation for 2 rows") exc.match(r"Member 'second' failed validation") exc.match(r"'min' failed for 1 rows") diff --git a/tests/schema/test_base.py b/tests/schema/test_base.py index 8df1275..7404873 100644 --- a/tests/schema/test_base.py +++ b/tests/schema/test_base.py @@ -44,8 +44,8 @@ def test_nullability() -> None: assert columns["e"].nullable -def test_primary_keys() -> None: - assert MySchema.primary_keys() == ["a", "b"] +def test_primary_key() -> None: + assert MySchema.primary_key() == ["a", "b"] def test_no_rule_named_primary_key() -> None: diff --git a/tests/test_deprecation.py b/tests/test_deprecation.py index 6982b98..278e794 100644 --- a/tests/test_deprecation.py +++ b/tests/test_deprecation.py @@ -40,7 +40,9 @@ def test_warning_deprecated_nullable_primary_key( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("DATAFRAMELY_NO_FUTURE_WARNINGS", "") - with pytest.warns(FutureWarning, match="Nullable primary keys are not supported"): + with pytest.warns( + FutureWarning, match=r"Nullable primary key columns are not supported" + ): deprecated_nullable_primary_key()