Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions dataframely/_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def my_filter(self) -> pl.DataFrame:
# --------------------------------------- UTILS -------------------------------------- #


def _common_primary_keys(columns: Iterable[type[Schema]]) -> set[str]:
return set.intersection(*[set(schema.primary_keys()) for schema in columns])
def _common_primary_key(columns: Iterable[type[Schema]]) -> set[str]:
return set.intersection(*[set(schema.primary_key()) for schema in columns])


# ------------------------------------------------------------------------------------ #
Expand Down Expand Up @@ -117,7 +117,7 @@ def __new__(
# 1) Check that there are overlapping primary keys that allow the application
# of filters.
if len(non_ignored_member_schemas) > 0 and len(result.filters) > 0:
if len(_common_primary_keys(non_ignored_member_schemas)) == 0:
if len(_common_primary_key(non_ignored_member_schemas)) == 0:
raise ImplementationError(
"Members of a collection must have an overlapping primary key "
"but did not find any."
Expand Down Expand Up @@ -145,19 +145,19 @@ def __new__(

# 3) Check that inlining for sampling is configured correctly.
if len(non_ignored_member_schemas) > 0:
common_primary_keys = _common_primary_keys(non_ignored_member_schemas)
common_primary_key = _common_primary_key(non_ignored_member_schemas)
inlined_columns: set[str] = set()
for member, info in result.members.items():
if info.inline_for_sampling:
if set(info.schema.primary_keys()) != common_primary_keys:
if set(info.schema.primary_key()) != common_primary_key:
raise ImplementationError(
f"Member '{member}' is inlined for sampling but its primary "
"key is a superset of the common primary key. Such a member "
"must not be inlined to be able to provide multiple values "
"for a single combination of the common primary key."
)
non_primary_key_columns = (
set(info.schema.column_names()) - common_primary_keys
set(info.schema.column_names()) - common_primary_key
)
if len(inlined_columns & non_primary_key_columns):
raise ImplementationError(
Expand Down Expand Up @@ -317,10 +317,10 @@ def non_ignored_members(cls) -> set[str]:
}

@classmethod
def common_primary_keys(cls) -> list[str]:
def common_primary_key(cls) -> list[str]:
"""The primary keys shared by non ignored members of the collection."""
return sorted(
_common_primary_keys(
_common_primary_key(
[
member.schema
for member in cls.members().values()
Expand Down
12 changes: 6 additions & 6 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def _build_rules(
rules: dict[str, Rule] = copy(custom)

# Add primary key validation to the list of rules if applicable
primary_keys = _primary_keys(columns)
if len(primary_keys) > 0:
rules["primary_key"] = Rule(~pl.struct(primary_keys).is_duplicated())
primary_key = _primary_key(columns)
if len(primary_key) > 0:
rules["primary_key"] = Rule(~pl.struct(primary_key).is_duplicated())

# Add column-specific rules
column_rules = {
Expand All @@ -43,7 +43,7 @@ def _build_rules(
return rules


def _primary_keys(columns: dict[str, Column]) -> list[str]:
def _primary_key(columns: dict[str, Column]) -> list[str]:
return list(k for k, col in columns.items() if col.primary_key)


Expand Down Expand Up @@ -178,9 +178,9 @@ def columns(cls) -> dict[str, Column]:
return columns

@classmethod
def primary_keys(cls) -> list[str]:
def primary_key(cls) -> list[str]:
"""The primary key columns in this schema (possibly empty)."""
return _primary_keys(cls.columns())
return _primary_key(cls.columns())

@classmethod
def _validation_rules(cls) -> dict[str, Rule]:
Expand Down
4 changes: 2 additions & 2 deletions dataframely/_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def warn_nullable_default_change() -> None:


@skip_if(env="DATAFRAMELY_NO_FUTURE_WARNINGS")
def warn_no_nullable_primary_keys() -> None:
def warn_no_nullable_primary_key() -> None:
warnings.warn(
"Nullable primary keys are not supported. "
"Nullable primary key columns are not supported. "
"Setting `nullable=True` on a primary key column is ignored "
"and will raise an error in a future release.",
FutureWarning,
Expand Down
30 changes: 15 additions & 15 deletions dataframely/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def sample(

g = generator or Generator()

primary_keys = cls.common_primary_keys()
requires_dependent_sampling = len(cls.members()) > 1 and len(primary_keys) > 0
primary_key = cls.common_primary_key()
requires_dependent_sampling = len(cls.members()) > 1 and len(primary_key) > 0

# 1) Preprocess all samples to make sampling efficient and ensure shared primary
# keys.
Expand All @@ -180,7 +180,7 @@ def sample(
# can properly sample members.
if requires_dependent_sampling:
if not all(
all(k in sample for k in primary_keys) for sample in processed_samples
all(k in sample for k in primary_key) for sample in processed_samples
):
raise ValueError("All samples must contain the common primary keys.")

Expand All @@ -192,7 +192,7 @@ def sample(
for member, schema in cls.member_schemas().items():
if (
not requires_dependent_sampling
or set(schema.primary_keys()) == set(primary_keys)
or set(schema.primary_key()) == set(primary_key)
or member_infos[member].ignored_in_filters
):
# If the primary keys are equal to the shared ones, each sample
Expand All @@ -206,7 +206,7 @@ def sample(
**(
{}
if member_infos[member].ignored_in_filters
else _extract_keys_if_exist(sample, primary_keys)
else _extract_keys_if_exist(sample, primary_key)
),
**_extract_keys_if_exist(
(
Expand All @@ -224,7 +224,7 @@ def sample(
# observe values for the member
member_overrides = [
{
**_extract_keys_if_exist(sample, primary_keys),
**_extract_keys_if_exist(sample, primary_key),
**_extract_keys_if_exist(item, schema.column_names()),
}
for sample in processed_samples
Expand Down Expand Up @@ -324,7 +324,7 @@ def _preprocess_sample(
has common primary keys, this sample **must** include **all** common
primary keys.
"""
if len(cls.members()) > 1 and len(cls.common_primary_keys()) > 0:
if len(cls.members()) > 1 and len(cls.common_primary_key()) > 0:
raise ValueError(
"`_preprocess_sample` must be overwritten for collections with more "
"than 1 member sharing a common primary key."
Expand Down Expand Up @@ -448,16 +448,16 @@ def filter(
filters = cls._filters()
if len(filters) > 0:
result_cls = cls._init(results)
primary_keys = cls.common_primary_keys()
primary_key = cls.common_primary_key()

keep: dict[str, pl.DataFrame] = {}
for name, filter in filters.items():
keep[name] = filter.logic(result_cls).select(primary_keys).collect()
keep[name] = filter.logic(result_cls).select(primary_key).collect()

# Using the filter results, we can define a joint data frame that we use to filter
# the input.
all_keep = join_all_inner(
[df.lazy() for df in keep.values()], on=primary_keys
[df.lazy() for df in keep.values()], on=primary_key
).collect()

# Now we can iterate over the results where we do the following:
Expand All @@ -467,22 +467,22 @@ def filter(
for member_name, filtered in results.items():
if cls.members()[member_name].ignored_in_filters:
continue
results[member_name] = filtered.join(all_keep, on=primary_keys)
results[member_name] = filtered.join(all_keep, on=primary_key)

new_failure_names = list(filters.keys())
new_failure_pks = [
filtered.select(primary_keys)
filtered.select(primary_key)
.lazy()
.unique()
.join(filter_keep.lazy(), on=primary_keys, how="anti")
.join(filter_keep.lazy(), on=primary_key, how="anti")
.with_columns(pl.lit(False).alias(name))
for name, filter_keep in keep.items()
]
# NOTE: The outer join might generate NULL values if a primary key is not
# filtered out by all filters. In this case, we want to assign a validation
# value of `True`.
all_new_failure_pks = join_all_outer(
new_failure_pks, on=primary_keys
new_failure_pks, on=primary_key
).with_columns(pl.col(new_failure_names).fill_null(True))

# At this point, we have a data frame with the primary keys of the *excluded*
Expand All @@ -500,7 +500,7 @@ def filter(
lf=pl.concat(
[
failure._lf,
filtered.lazy().join(all_new_failure_pks, on=primary_keys),
filtered.lazy().join(all_new_failure_pks, on=primary_key),
],
how="diagonal",
),
Expand Down
4 changes: 2 additions & 2 deletions dataframely/columns/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from dataframely._compat import pa, sa, sa_TypeEngine
from dataframely._deprecation import (
warn_no_nullable_primary_keys,
warn_no_nullable_primary_key,
warn_nullable_default_change,
)
from dataframely._polars import PolarsDataType
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
"""

if nullable and primary_key:
warn_no_nullable_primary_keys()
warn_no_nullable_primary_key()

if nullable is None:
warn_nullable_default_change()
Expand Down
4 changes: 2 additions & 2 deletions tests/collection/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class MyCollection(dy.Collection):
second: dy.LazyFrame[MySecondSchema] | None


def test_common_primary_keys() -> None:
    """``MyCollection`` reports ``["a"]`` as its common primary keys."""
    expected = ["a"]
    actual = MyCollection.common_primary_keys()
    assert actual == expected
def test_common_primary_key() -> None:
    """``MyCollection`` reports ``["a"]`` as its common primary key."""
    expected = ["a"]
    actual = MyCollection.common_primary_key()
    assert actual == expected


def test_members() -> None:
Expand Down
20 changes: 10 additions & 10 deletions tests/collection/test_filter_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ class MyCollection(dy.Collection):
second: dy.LazyFrame[MySecondSchema]

@dy.filter()
def equal_primary_keys(self) -> pl.LazyFrame:
return self.first.join(self.second, on=self.common_primary_keys())
def equal_primary_key(self) -> pl.LazyFrame:
return self.first.join(self.second, on=self.common_primary_key())

@dy.filter()
def first_b_greater_second_b(self) -> pl.LazyFrame:
return self.first.join(
self.second, on=self.common_primary_keys(), how="full", coalesce=True
self.second, on=self.common_primary_key(), how="full", coalesce=True
).filter((pl.col("b") > pl.col("b_right")).fill_null(True))


Expand Down Expand Up @@ -128,11 +128,11 @@ def test_filter_with_filter_without_rule_violation(
assert_frame_equal(out.first, pl.LazyFrame({"a": [3], "b": [3]}))
assert_frame_equal(out.second, pl.LazyFrame({"a": [3], "b": [2]}))
assert failure["first"].counts() == {
"equal_primary_keys": 1,
"equal_primary_key": 1,
"first_b_greater_second_b": 1,
}
assert failure["second"].counts() == {
"equal_primary_keys": 2,
"equal_primary_key": 2,
"first_b_greater_second_b": 1,
}

Expand All @@ -150,8 +150,8 @@ def test_filter_with_filter_with_rule_violation(
assert isinstance(out, MyCollection)
assert_frame_equal(out.first, pl.LazyFrame({"a": [3], "b": [3]}))
assert_frame_equal(out.second, pl.LazyFrame({"a": [3], "b": [1]}))
assert failure["first"].counts() == {"equal_primary_keys": 2}
assert failure["second"].counts() == {"b|min": 1, "equal_primary_keys": 2}
assert failure["first"].counts() == {"equal_primary_key": 2}
assert failure["second"].counts() == {"b|min": 1, "equal_primary_key": 2}


# -------------------------------- VALIDATE WITH DATA -------------------------------- #
Expand Down Expand Up @@ -207,10 +207,10 @@ def test_validate_with_filter_without_rule_violation(
MyCollection.validate(data)

exc.match(r"Member 'first' failed validation")
exc.match(r"'equal_primary_keys' failed validation for 1 rows")
exc.match(r"'equal_primary_key' failed validation for 1 rows")
exc.match(r"'first_b_greater_second_b' failed validation for 1 rows")
exc.match(r"Member 'second' failed validation")
exc.match(r"'equal_primary_keys' failed validation for 2 rows")
exc.match(r"'equal_primary_key' failed validation for 2 rows")


def test_validate_with_filter_with_rule_violation(
Expand All @@ -228,6 +228,6 @@ def test_validate_with_filter_with_rule_violation(
MyCollection.validate(data)

exc.match(r"Member 'first' failed validation")
exc.match(r"'equal_primary_keys' failed validation for 2 rows")
exc.match(r"'equal_primary_key' failed validation for 2 rows")
exc.match(r"Member 'second' failed validation")
exc.match(r"'min' failed for 1 rows")
4 changes: 2 additions & 2 deletions tests/schema/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_nullability() -> None:
assert columns["e"].nullable


def test_primary_keys() -> None:
    """``MySchema`` reports ``["a", "b"]`` as its primary keys, in that order."""
    expected = ["a", "b"]
    actual = MySchema.primary_keys()
    assert actual == expected
def test_primary_key() -> None:
    """``MySchema`` reports ``["a", "b"]`` as its primary key, in that order."""
    expected = ["a", "b"]
    actual = MySchema.primary_key()
    assert actual == expected


def test_no_rule_named_primary_key() -> None:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def test_warning_deprecated_nullable_primary_key(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("DATAFRAMELY_NO_FUTURE_WARNINGS", "")
with pytest.warns(FutureWarning, match="Nullable primary keys are not supported"):
with pytest.warns(
FutureWarning, match=r"Nullable primary key columns are not supported"
):
deprecated_nullable_primary_key()


Expand Down
Loading