diff --git a/tierkreis/tests/controller/test_types.py b/tierkreis/tests/controller/test_types.py index 4beca82fc..a0c3f851a 100644 --- a/tierkreis/tests/controller/test_types.py +++ b/tierkreis/tests/controller/test_types.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from datetime import datetime from types import NoneType, UnionType -from typing import TypeVar from uuid import UUID import pytest @@ -11,7 +10,6 @@ from tierkreis.controller.data.types import ( PType, bytes_from_ptype, - generics_in_ptype, is_ptype, ptype_from_bytes, ) @@ -135,19 +133,3 @@ def test_ptype_from_annotation(ptype: type[PType]) -> None: @pytest.mark.parametrize("ptype", fail_list) def test_ptype_from_annotation_fails(ptype: type[PType]) -> None: assert not is_ptype(ptype) - - -S = TypeVar("S") -T = TypeVar("T") - -generic_types = [] -generic_types.append((list[T], {str(T)})) # type: ignore[valid-type] -generic_types.append((list[S | T], {str(S), str(T)})) # type: ignore[valid-type] -generic_types.append((list[list[list[T]]], {str(T)})) # type: ignore[valid-type] -generic_types.append((tuple[S, T], {str(S), str(T)})) # type: ignore[valid-type] -generic_types.append((UntupledModel[S, T], {str(S), str(T)})) # type: ignore[valid-type] - - -@pytest.mark.parametrize(("ptype", "generics"), generic_types) -def test_generic_types(ptype: type[PType], generics: set[type[PType]]) -> None: - assert generics_in_ptype(ptype) == generics diff --git a/tierkreis/tierkreis/builder.py b/tierkreis/tierkreis/builder.py index 9c7b6307e..7d3c7398d 100644 --- a/tierkreis/tierkreis/builder.py +++ b/tierkreis/tierkreis/builder.py @@ -70,8 +70,8 @@ class TypedGraphRef[Ins: TModel, Outs: TModel]: """ graph_ref: ValueRef - outputs_type: type[Outs] inputs_type: type[Ins] + outputs_type: type[Outs] class LoopOutput(TNamedModel, Protocol): @@ -151,7 +151,7 @@ def ref(self) -> TypedGraphRef[Inputs, Outputs]: :return: The ref of the typed graph. :rtype: TypedGraphRef[Inputs, Outputs] """ - return TypedGraphRef((-1, "body"), self.outputs_type, self.inputs_type) + return TypedGraphRef((-1, "body"), self.inputs_type, self.outputs_type) def outputs(self, outputs: Outputs) -> None: """Set output nodes of a graph. @@ -279,23 +279,11 @@ def task[Out: TModel](self, func: Function[Out]) -> Out: OutModel = func.out() # noqa: N806 return init_tmodel(OutModel, lambda p: (idx, p)) - @overload - def eval[A: TModel, B: TModel]( - self, - body: TypedGraphRef[A, B], - eval_inputs: A, - ) -> B: ... - @overload def eval[A: TModel, B: TModel]( self, - body: GraphBuilder[A, B], + body: GraphBuilder[A, B] | TypedGraphRef[A, B], eval_inputs: A, - ) -> B: ... - def eval[A: TModel, B: TModel]( - self, - body: GraphBuilder[A, B] | TypedGraphRef, - eval_inputs: Any, - ) -> Any: + ) -> B: """Add a evaluation node to the graph. This will evaluate a nested graph with the given inputs. @@ -314,20 +302,6 @@ def eval[A: TModel, B: TModel]( idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(eval_inputs))("dummy") return init_tmodel(body.outputs_type, lambda p: (idx, p)) - @overload - def loop[A: TModel, B: LoopOutput]( - self, - body: TypedGraphRef[A, B], - loop_inputs: A, - name: str | None = None, - ) -> B: ... - @overload - def loop[A: TModel, B: LoopOutput]( - self, - body: GraphBuilder[A, B], - loop_inputs: A, - name: str | None = None, - ) -> B: ... def loop[A: TModel, B: LoopOutput]( self, body: TypedGraphRef[A, B] | GraphBuilder[A, B], diff --git a/tierkreis/tierkreis/controller/data/models.py b/tierkreis/tierkreis/controller/data/models.py index ffdc43be9..df117053b 100644 --- a/tierkreis/tierkreis/controller/data/models.py +++ b/tierkreis/tierkreis/controller/data/models.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from inspect import isclass -from itertools import chain from typing import ( Any, Callable, @@ -25,7 +24,7 @@ RestrictedNamedTuple, ValueRef, ) -from tierkreis.controller.data.types import PType, generics_in_ptype +from tierkreis.controller.data.types import PType TKR_PORTMAPPING_FLAG = "__tkr_portmapping__" @@ -143,15 +142,3 @@ def init_tmodel[T: TModel](tmodel: type[T], input_fn: Callable[[str], ValueRef]) return cast("T", model(*args)) (ref,) = fields.values() return tmodel(*ref) - - -def generics_in_pmodel(pmodel: type[PModel]) -> set[str]: - if is_portmapping(pmodel): - origin = get_origin(pmodel) - if origin is not None: - return generics_in_pmodel(origin) - - x = [generics_in_ptype(pmodel.__annotations__[t]) for t in model_fields(pmodel)] - return set(chain(*x)) - - return generics_in_ptype(pmodel) diff --git a/tierkreis/tierkreis/controller/data/types.py b/tierkreis/tierkreis/controller/data/types.py index 7d3b8eff8..33ff544f1 100644 --- a/tierkreis/tierkreis/controller/data/types.py +++ b/tierkreis/tierkreis/controller/data/types.py @@ -1,7 +1,6 @@ """Valid Python types for annotating worker functions and their serialisation.""" # ruff: noqa: ANN001 ANN003 ANN401 due to serialization and inheritance from json -import collections.abc import json import logging import pickle @@ -9,7 +8,6 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from inspect import Parameter, _empty, isclass -from itertools import chain from types import NoneType, UnionType from typing import ( Annotated, @@ -26,7 +24,6 @@ ) from pydantic import BaseModel, ValidationError -from pydantic._internal._generics import get_args as pydantic_get_args from typing_extensions import TypeIs from tierkreis.controller.data.core import ( @@ -184,11 +181,11 @@ def _is_generic(o) -> TypeIs[type[TypeVar]]: def _is_list(ptype: object) -> TypeIs[type[Sequence[PType]]]: - return get_origin(ptype) == collections.abc.Sequence or get_origin(ptype) is list + return get_origin(ptype) == Sequence or get_origin(ptype) is list def _is_mapping(ptype: object) -> TypeIs[type[Mapping[str, PType]]]: - return get_origin(ptype) is collections.abc.Mapping or get_origin(ptype) is dict + return get_origin(ptype) is Mapping or get_origin(ptype) is dict def _is_tuple(o: object) -> TypeIs[type[tuple[Any, ...]]]: @@ -259,10 +256,10 @@ def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> JsonType: case tuple(): args = get_args(annotation) or [None] * len(ptype) return tuple([ser_from_ptype(p, args[i]) for i, p in enumerate(ptype)]) - case collections.abc.Sequence(): + case Sequence(): arg = get_args(annotation)[0] if get_args(annotation) else None return [ser_from_ptype(p, arg) for p in ptype] - case collections.abc.Mapping(): + case Mapping(): arg = get_args(annotation)[1] if get_args(annotation) else None return {k: ser_from_ptype(p, arg) for k, p in ptype.items()} case DictConvertible(): @@ -371,14 +368,14 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: } return cast("T", origin(**d)) - if issubclass(origin, collections.abc.Sequence): + if issubclass(origin, Sequence): args = get_args(annotation) if len(args) == 0: return ser return cast("T", [coerce_from_annotation(x, args[0]) for x in ser]) - if issubclass(origin, collections.abc.Mapping): + if issubclass(origin, Mapping): args = get_args(annotation) if len(args) == 0: return ser @@ -443,36 +440,6 @@ def ptype_from_bytes[T: PType](bs: bytes, annotation: type[T] | None = None) -> assert_never(method) -def generics_in_ptype(ptype: type[PType]) -> set[str]: - """Get the generics in a type annotation. - - :param ptype: The ptype to extract generics from. - :type ptype: type[PType] - :return: The set of generic names in the ptype. - :rtype: set[str] - """ - if _is_generic(ptype): - return {str(ptype)} - - if _is_union(ptype) or _is_tuple(ptype) or _is_list(ptype) or _is_mapping(ptype): - return set(chain(*[generics_in_ptype(x) for x in get_args(ptype)])) - - origin = get_origin(ptype) - if origin is not None: - return generics_in_ptype(origin) - - if issubclass(ptype, (bool, int, float, complex, str, bytes, NoneType)): - return set() - - if issubclass(ptype, (DictConvertible, ListConvertible, NdarraySurrogate, Struct)): - return set() - - if issubclass(ptype, BaseModel): - return {str(x) for x in pydantic_get_args(ptype)} - - assert_never(ptype) - - def has_default(t: Parameter) -> bool: """Check if a parameter has a default value. diff --git a/tierkreis/tierkreis/graphs/fold.py b/tierkreis/tierkreis/graphs/fold.py index 9cf4a1579..de28d7038 100644 --- a/tierkreis/tierkreis/graphs/fold.py +++ b/tierkreis/tierkreis/graphs/fold.py @@ -46,8 +46,8 @@ def _fold_graph_outer[A: PType, B: PType]() -> GraphBuilder[ # Apply the function if we were able to pop off a value. tgd = TypedGraphRef[_InnerFuncInput, TKR[B]]( func.value_ref(), - TKR[B], _InnerFuncInput, + TKR[B], ) applied_next = g.eval(tgd, _InnerFuncInput(accum, headed.head))