Skip to content

fix(nodes): pydantic field type massaging improvements #7984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
126 changes: 106 additions & 20 deletions invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import inspect
import re
import sys
import types
import typing
import warnings
from abc import ABC, abstractmethod
from enum import Enum
Expand All @@ -20,6 +22,7 @@
Literal,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
Expand Down Expand Up @@ -93,6 +96,11 @@ class UIConfigBase(BaseModel):
)


class OriginalModelField(TypedDict):
annotation: Any
field_info: FieldInfo


class BaseInvocationOutput(BaseModel):
"""
Base class for all invocation outputs.
Expand Down Expand Up @@ -121,6 +129,9 @@ def get_type(cls) -> str:
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
return cls.model_fields["type"].default

_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation_output decorator."""

model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
Expand Down Expand Up @@ -251,6 +262,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
coerce_numbers_to_str=True,
)

_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation decorator."""


TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)

Expand Down Expand Up @@ -475,6 +489,47 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
return None


class NoDefaultSentinel:
pass


def validate_field_default(
cls_name: str, field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo
) -> None:
"""Validates the default value of a field against its pydantic field definition."""

assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"

# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
# We store the original default value in the json_schema_extra dict, so we can validate it here.
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)

if orig_default is NoDefaultSentinel:
return

TempDefaultValidator = create_model(cls_name, **{field_name: (annotation, field_info)})

# Validate the default value against the annotation
try:
TempDefaultValidator.model_validate({field_name: orig_default})
except Exception as e:
raise InvalidFieldError(
f'Default value for field "{field_name}" on invocation "{invocation_type}" is invalid, {e}'
) from e


def is_optional(annotation: Any) -> bool:
"""
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
"""
origin = typing.get_origin(annotation)
# PEP 604 unions (int|None) have origin types.UnionType
is_union = origin is typing.Union or origin is types.UnionType
if not is_union:
return False
return any(arg is type(None) for arg in typing.get_args(annotation))


def invocation(
invocation_type: str,
title: Optional[str] = None,
Expand Down Expand Up @@ -507,6 +562,24 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:

validate_fields(cls.model_fields, invocation_type)

fields: dict[str, tuple[Any, FieldInfo]] = {}

for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)

cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)

validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)

if field_info.default is None and not is_optional(annotation):
annotation = annotation | None

fields[field_name] = (annotation, field_info)

# Add OpenAPI schema extras
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
Expand Down Expand Up @@ -539,11 +612,17 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
# not work. Instead, we have to create a new class with the type field and patch the original class with it.

invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_annotation = Literal[invocation_type]
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)

# pydantic's Field function returns a FieldInfo, but they annotate it as returning a type so that type-checkers
# don't get confused by something like this:
# foo: str = Field() <-- this is a FieldInfo, not a str
# Unfortunately this means we need to use type: ignore here to avoid type-checker errors
fields["type"] = (invocation_type_annotation, invocation_type_field) # type: ignore

# Validate the `invoke()` method is implemented
if "invoke" in cls.__abstractmethods__:
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
Expand All @@ -565,17 +644,12 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
)

docstring = cls.__doc__
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(invocation_type_annotation, invocation_type_field),
)
cls.__doc__ = docstring
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
new_class.__doc__ = docstring

InvocationRegistry.register_invocation(cls)
InvocationRegistry.register_invocation(new_class)

return cls
return new_class

return wrapper

Expand All @@ -600,23 +674,35 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:

validate_fields(cls.model_fields, output_type)

fields: dict[str, tuple[Any, FieldInfo]] = {}

for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)

cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)

if field_info.default is not PydanticUndefined and is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)

# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_annotation = Literal[output_type]
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)

fields["type"] = (output_type_annotation, output_type_field) # type: ignore

docstring = cls.__doc__
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(output_type_annotation, output_type_field),
)
cls.__doc__ = docstring
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
new_class.__doc__ = docstring

InvocationRegistry.register_output(cls)
InvocationRegistry.register_output(new_class)

return cls
return new_class

return wrapper
4 changes: 0 additions & 4 deletions invokeai/app/invocations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""

images: list[ImageField] = InputField(
default=[],
min_length=1,
description="The images to batch over",
)
Expand Down Expand Up @@ -120,7 +119,6 @@ class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""

strings: list[str] = InputField(
default=[],
min_length=1,
description="The strings to batch over",
)
Expand Down Expand Up @@ -176,7 +174,6 @@ class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""

integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)
Expand Down Expand Up @@ -230,7 +227,6 @@ class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""

floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)
Expand Down
Loading