Make DataFrameModel iterable over the schema field names #1288

Status: Draft. Wants to merge 7 commits into base: main.
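
For context, the behavior this PR is meant to enable looks roughly like the following (a minimal sketch; `ProductSchema` is a hypothetical model, and the snippet assumes the iteration support added in this diff):

import pandas as pd
import pandera as pa
from pandera.typing import Series


class ProductSchema(pa.DataFrameModel):
    id: Series[int]
    product_name: Series[str]
    price: Series[float]


# With this change, iterating the model class yields the schema column names,
# so the class itself can be used to select and order dataframe columns.
df = pd.DataFrame({"price": [1.5], "id": [1], "product_name": ["widget"]})
df = df[list(ProductSchema)]  # columns become ["id", "product_name", "price"]
ProductSchema.validate(df)
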
Changes from all commits
pandera/api/pandas/model.py (64 changes: 29 additions & 35 deletions)
@@ -23,7 +23,7 @@

import pandas as pd

-from pandera.api.base.model import BaseModel
+from pandera.api.base.model import BaseModel, MetaModel
from pandera.api.checks import Check
from pandera.api.pandas.components import Column, Index, MultiIndex
from pandera.api.pandas.container import DataFrameSchema
@@ -125,7 +125,21 @@ def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]:
return checks


-class DataFrameModel(BaseModel):
+class MetaDataFrameModel(MetaModel):
+    """A metaclass for DataFrameModel to provide iter support."""
+
+    def to_schema(cls) -> DataFrameSchema:
+        """Create :class:`~pandera.DataFrameSchema` from the class."""
+        raise NotImplementedError
+
+    def __iter__(cls) -> Iterable[str]:
+        """Iterate over the fields of the schema"""
+        # False positive in metaclass context; pylint: disable=no-value-for-parameter
+        schema = cls.to_schema()
+        return iter(schema.columns)
+
+
+class DataFrameModel(BaseModel, metaclass=MetaDataFrameModel):
"""Definition of a :class:`~pandera.api.pandas.container.DataFrameSchema`.

*new in 0.5.0*
@@ -151,17 +165,13 @@ class DataFrameModel(BaseModel):
@docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__)
def __new__(cls, *args, **kwargs) -> DataFrameBase[TDataFrameModel]: # type: ignore [misc]
"""%(validate_doc)s"""
-        return cast(
-            DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs)
-        )
+        return cast(DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs))

def __init_subclass__(cls, **kwargs):
"""Ensure :class:`~pandera.api.pandas.model_components.FieldInfo` instances."""
if "Config" in cls.__dict__:
cls.Config.name = (
-                cls.Config.name
-                if hasattr(cls.Config, "name")
-                else cls.__name__
+                cls.Config.name if hasattr(cls.Config, "name") else cls.__name__
)
else:
cls.Config = type("Config", (BaseConfig,), {"name": cls.__name__})
@@ -201,9 +211,7 @@ def __class_getitem__(
Type[TDataFrameModel], GENERIC_SCHEMA_CACHE[(cls, params)]
)

-        param_dict: Dict[TypeVar, Type[Any]] = dict(
-            zip(__parameters__, params)
-        )
+        param_dict: Dict[TypeVar, Type[Any]] = dict(zip(__parameters__, params))
extra: Dict[str, Any] = {"__annotations__": {}}
for field, (annot_info, field_info) in cls._collect_fields().items():
if isinstance(annot_info.arg, TypeVar):
@@ -214,9 +222,7 @@ def __class_getitem__(
extra["__annotations__"][field] = raw_annot
extra[field] = copy.deepcopy(field_info)

-        parameterized_name = (
-            f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]"
-        )
+        parameterized_name = f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]"
parameterized_cls = type(parameterized_name, (cls,), extra)
GENERIC_SCHEMA_CACHE[(cls, params)] = parameterized_cls
return parameterized_cls
@@ -323,9 +329,7 @@ def example(
**kwargs,
) -> DataFrameBase[TDataFrameModel]:
"""%(example_doc)s"""
-        return cast(
-            DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs)
-        )
+        return cast(DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs))

@classmethod
def _build_columns_index( # pylint:disable=too-many-locals
@@ -335,8 +339,7 @@ def _build_columns_index( # pylint:disable=too-many-locals
**multiindex_kwargs: Any,
) -> Tuple[Dict[str, Column], Optional[Union[Index, MultiIndex]],]:
index_count = sum(
-            annotation.origin in INDEX_TYPES
-            for annotation, _ in fields.values()
+            annotation.origin in INDEX_TYPES for annotation, _ in fields.values()
)

columns: Dict[str, Column] = {}
@@ -385,9 +388,7 @@ def _build_columns_index( # pylint:disable=too-many-locals
or annotation.raw_annotation in INDEX_TYPES
):
if annotation.optional:
-                    raise SchemaInitError(
-                        f"Index '{field_name}' cannot be Optional."
-                    )
+                    raise SchemaInitError(f"Index '{field_name}' cannot be Optional.")

if check_name is False or (
# default single index
@@ -442,34 +443,30 @@ def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]:
raise SchemaInitError(f"Found missing annotations: {missing}")

fields = {}
-        for field_name, annotation in annotations.items():
+        for field_name, annotation in reversed(annotations.items()):
field = attrs[field_name] # __init_subclass__ guarantees existence
if not isinstance(field, FieldInfo):
raise SchemaInitError(
f"'{field_name}' can only be assigned a 'Field', "
+ f"not a '{type(field)}.'"
)
fields[field.name] = (AnnotationInfo(annotation), field)
-        return fields
+        return dict(reversed(fields.items()))

@classmethod
def _collect_config_and_extras(
cls,
) -> Tuple[Type[BaseConfig], Dict[str, Any]]:
"""Collect config options from bases, splitting off unknown options."""
bases = inspect.getmro(cls)[:-1]
-        bases = tuple(
-            base for base in bases if issubclass(base, DataFrameModel)
-        )
+        bases = tuple(base for base in bases if issubclass(base, DataFrameModel))
root_model, *models = reversed(bases)

options, extras = _extract_config_options_and_extras(root_model.Config)

for model in models:
config = getattr(model, _CONFIG_KEY, {})
-            base_options, base_extras = _extract_config_options_and_extras(
-                config
-            )
+            base_options, base_extras = _extract_config_options_and_extras(config)
options.update(base_options)
extras.update(base_extras)

@@ -482,9 +479,7 @@ def _collect_check_infos(cls, key: str) -> List[CheckInfo]:
walk the inheritance tree.
"""
bases = inspect.getmro(cls)[:-2] # bases -> DataFrameModel -> object
-        bases = tuple(
-            base for base in bases if issubclass(base, DataFrameModel)
-        )
+        bases = tuple(base for base in bases if issubclass(base, DataFrameModel))

method_names = set()
check_infos = []
@@ -648,7 +643,6 @@ def _field_json_schema(field):
"title": dataframe_schema.name or "pandera.DataFrameSchema",
"type": "object",
"properties": {
field["name"]: _field_json_schema(field)
for field in table_schema["fields"]
field["name"]: _field_json_schema(field) for field in table_schema["fields"]
},
}
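
For reviewers unfamiliar with the approach, here is a minimal standalone sketch of the metaclass iteration pattern the diff above relies on (the `Model`/`Product` classes and the dict-based `to_schema` are illustrative stand-ins, not the pandera implementation):

from typing import Dict, Iterator


class MetaModel(type):
    """Metaclass that makes the class itself iterable over its schema columns."""

    def __iter__(cls) -> Iterator[str]:
        # __iter__ is looked up on the metaclass, so list(SomeModel) works
        # on the class directly, without instantiating it.
        return iter(cls.to_schema())


class Model(metaclass=MetaModel):
    @classmethod
    def to_schema(cls) -> Dict[str, type]:
        # Stand-in for DataFrameSchema: annotated attributes play the role of columns.
        return dict(cls.__annotations__)


class Product(Model):
    id: int
    product_name: str
    price: float


print(list(Product))  # ['id', 'product_name', 'price']
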
tests/core/test_model.py (82 changes: 82 additions & 0 deletions)
@@ -1426,3 +1426,85 @@ class Config:
}
}
assert PanderaSchema.get_metadata() == expected


def test_iter_fieldnames():
    """
    Test we can iterate over the `DataFrameModel` to get the field names.
    """

    class PanderaSchema(pa.DataFrameModel):
        id: Series[int]
        product_name: Series[str]
        price: Series[float]

    expected = [
        PanderaSchema.id,
        PanderaSchema.product_name,
        PanderaSchema.price,
    ]
    assert list(PanderaSchema) == expected


def test_iter_fieldnames_inheritance():
    """
    Test iterating over the fieldnames respects the order of inheritance.
    """

    class PanderaSchema1(pa.DataFrameModel):
        id: Series[int]
        product_name: Series[str]
        price: Series[float]

    # Note: order of definition differs from order of inheritance
    class PanderaSchema3(pa.DataFrameModel):
        quality: Series[str]

    class PanderaSchema2(pa.DataFrameModel):
        quantity: Series[int]

    class CombinedSchema(PanderaSchema1, PanderaSchema2, PanderaSchema3):
        pass

    expected = [
        PanderaSchema1.id,
        PanderaSchema1.product_name,
        PanderaSchema1.price,
        PanderaSchema2.quantity,
        PanderaSchema3.quality,
    ]
    assert list(CombinedSchema) == expected


def test_iter_fieldnames_df_index():
    """
    Test iterating over the fieldnames as a way to index all columns of a dataframe.
    """

    class PanderaSchema(pa.DataFrameModel):
        id: Series[int]
        product_name: Series[str]
        price: Series[float]

        class Config:
            order = True

    df = pd.DataFrame(
        {
            PanderaSchema.price: [1.0, 2.0, 3.0],
            PanderaSchema.id: [1, 2, 3],
            PanderaSchema.product_name: ["A", "B", "C"],
        }
    )
    assert list(df.columns) == [
        PanderaSchema.price,
        PanderaSchema.id,
        PanderaSchema.product_name,
    ]
    df = df[list(PanderaSchema)].copy()
    assert list(df.columns) == [
        PanderaSchema.id,
        PanderaSchema.product_name,
        PanderaSchema.price,
    ]
    PanderaSchema.validate(df)