Skip to content

Commit

Permalink
Improve Pydantic model detection robustness (#11)
Browse files Browse the repository at this point in the history
It was previously too easy to hit some false positives
  • Loading branch information
Viicos committed May 6, 2024
1 parent bf0572d commit d00e91f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 14 deletions.
78 changes: 67 additions & 11 deletions src/flake8_pydantic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_decorator_names(decorator_list: list[ast.expr]) -> set[str]:
return names


def _has_pydantic_model_base(node: ast.ClassDef, include_root_model: bool) -> bool:
def _has_pydantic_model_base(node: ast.ClassDef, *, include_root_model: bool) -> bool:
model_class_names = {"BaseModel"}
if include_root_model:
model_class_names.add("RootModel")
Expand All @@ -42,15 +42,55 @@ def _has_model_config(node: ast.ClassDef) -> bool:
return False


PYDANTIC_FIELD_ARGUMENTS = {
"default",
"default_factory",
"alias",
"alias_priority",
"validation_alias",
"title",
"description",
"examples",
"exclude",
"discriminator",
"json_schema_extra",
"frozen",
"validate_default",
"repr",
"init",
"init_var",
"kw_only",
"pattern",
"strict",
"gt",
"ge",
"lt",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_length",
"max_length",
"union_mode",
}


def _has_field_function(node: ast.ClassDef) -> bool:
for stmt in node.body:
if isinstance(stmt, (ast.Assign, ast.AnnAssign)) and isinstance(stmt.value, ast.Call):
if isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == "Field":
# f = Field(...)
return True
if isinstance(stmt.value.func, ast.Attribute) and stmt.value.func.attr == "Field":
# f = pydantic.Field(...)
return True
if (
isinstance(stmt, (ast.Assign, ast.AnnAssign))
and isinstance(stmt.value, ast.Call)
and (
(isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == "Field") # f = Field(...)
or (
isinstance(stmt.value.func, ast.Attribute) and stmt.value.func.attr == "Field"
) # f = pydantic.Field(...)
)
and all(kw.arg in PYDANTIC_FIELD_ARGUMENTS for kw in stmt.value.keywords if kw.arg is not None)
):
return True

return False


Expand Down Expand Up @@ -84,14 +124,30 @@ def _has_pydantic_decorator(node: ast.ClassDef) -> bool:
return False


PYDANTIC_METHODS = {
"model_construct",
"model_copy",
"model_dump",
"model_dump_json",
"model_json_schema",
"model_parametrized_name",
"model_rebuild",
"model_validate",
"model_validate_json",
"model_validate_strings",
}


def _has_pydantic_method(node: ast.ClassDef) -> bool:
for stmt in node.body:
if isinstance(stmt, ast.FunctionDef) and stmt.name.startswith(("model_", "__pydantic_")):
if isinstance(stmt, ast.FunctionDef) and (
stmt.name.startswith(("__pydantic_", "__get_pydantic_")) or stmt.name in PYDANTIC_METHODS
):
return True
return False


def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bool:
def is_pydantic_model(node: ast.ClassDef, *, include_root_model: bool = True) -> bool:
"""Determine if a class definition is a Pydantic model.
Multiple heuristics are use to determine if this is the case:
Expand All @@ -106,7 +162,7 @@ def is_pydantic_model(node: ast.ClassDef, include_root_model: bool = True) -> bo
return False

return (
_has_pydantic_model_base(node, include_root_model)
_has_pydantic_model_base(node, include_root_model=include_root_model)
or _has_model_config(node)
or _has_field_function(node)
or _has_annotated_field(node)
Expand Down
36 changes: 33 additions & 3 deletions tests/test_is_pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,34 @@ class SubModel(ParentModel):

HAS_FIELD_FUNCTION_1 = """
class SubModel(ParentModel):
a = Field()
a = Field(title="A")
"""

HAS_FIELD_FUNCTION_2 = """
class SubModel(ParentModel):
a: int = Field()
a: int = Field(gt=1)
"""

HAS_FIELD_FUNCTION_3 = """
class SubModel(ParentModel):
a = pydantic.Field()
a = pydantic.Field(alias="b")
"""

HAS_FIELD_FUNCTION_4 = """
class SubModel(ParentModel):
a: int = pydantic.Field(repr=True)
"""

HAS_FIELD_FUNCTION_5 = """
class SubModel(ParentModel):
a: int = pydantic.Field()
"""

HAS_FIELD_FUNCTION_6 = """
class SubModel(ParentModel):
a: int = pydantic.Field(1)
"""

USES_ANNOTATED_1 = """
class SubModel(ParentModel):
a: Annotated[int, ""]
Expand Down Expand Up @@ -86,12 +96,27 @@ class SubModel(ParentModel):
def __pydantic_some_method__(self): pass
"""

HAS_PYDANTIC_METHOD_3 = """
class SubModel(ParentModel):
def __get_pydantic_core_schema__(self): pass
"""

# Negative cases:
NO_BASES = """
class Model:
a = Field()
"""

UNRELATED_FIELD_ARG = """
class SubModel(ParentModel):
a: int = Field(some_arg=1)
"""

UNRELATED_MODEL_METHOD = """
class SubModel(ParentModel):
def model_unrelated(): pass
"""


@pytest.mark.parametrize(
["source", "expected"],
Expand All @@ -105,13 +130,18 @@ class Model:
(HAS_FIELD_FUNCTION_2, True),
(HAS_FIELD_FUNCTION_3, True),
(HAS_FIELD_FUNCTION_4, True),
(HAS_FIELD_FUNCTION_5, True),
(HAS_FIELD_FUNCTION_6, True),
(USES_ANNOTATED_1, True),
(USES_ANNOTATED_2, True),
(HAS_PYDANTIC_DECORATOR_1, True),
(HAS_PYDANTIC_DECORATOR_2, True),
(HAS_PYDANTIC_METHOD_1, True),
(HAS_PYDANTIC_METHOD_2, True),
(HAS_PYDANTIC_METHOD_3, True),
(NO_BASES, False),
(UNRELATED_FIELD_ARG, False),
(UNRELATED_MODEL_METHOD, False),
],
)
def test_is_pydantic_model(source: str, expected: bool) -> None:
Expand Down

0 comments on commit d00e91f

Please sign in to comment.