Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions tierkreis/tests/controller/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from datetime import datetime
from types import NoneType, UnionType
from typing import TypeVar
from uuid import UUID

import pytest
Expand All @@ -11,7 +10,6 @@
from tierkreis.controller.data.types import (
PType,
bytes_from_ptype,
generics_in_ptype,
is_ptype,
ptype_from_bytes,
)
Expand Down Expand Up @@ -135,19 +133,3 @@ def test_ptype_from_annotation(ptype: type[PType]) -> None:
@pytest.mark.parametrize("ptype", fail_list)
def test_ptype_from_annotation_fails(ptype: type[PType]) -> None:
assert not is_ptype(ptype)


S = TypeVar("S")
T = TypeVar("T")

generic_types = []
generic_types.append((list[T], {str(T)})) # type: ignore[valid-type]
generic_types.append((list[S | T], {str(S), str(T)})) # type: ignore[valid-type]
generic_types.append((list[list[list[T]]], {str(T)})) # type: ignore[valid-type]
generic_types.append((tuple[S, T], {str(S), str(T)})) # type: ignore[valid-type]
generic_types.append((UntupledModel[S, T], {str(S), str(T)})) # type: ignore[valid-type]


@pytest.mark.parametrize(("ptype", "generics"), generic_types)
def test_generic_types(ptype: type[PType], generics: set[type[PType]]) -> None:
assert generics_in_ptype(ptype) == generics
34 changes: 4 additions & 30 deletions tierkreis/tierkreis/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class TypedGraphRef[Ins: TModel, Outs: TModel]:
"""

graph_ref: ValueRef
outputs_type: type[Outs]
inputs_type: type[Ins]
outputs_type: type[Outs]


class LoopOutput(TNamedModel, Protocol):
Expand Down Expand Up @@ -151,7 +151,7 @@ def ref(self) -> TypedGraphRef[Inputs, Outputs]:
:return: The ref of the typed graph.
:rtype: TypedGraphRef[Inputs, Outputs]
"""
return TypedGraphRef((-1, "body"), self.outputs_type, self.inputs_type)
return TypedGraphRef((-1, "body"), self.inputs_type, self.outputs_type)

def outputs(self, outputs: Outputs) -> None:
"""Set output nodes of a graph.
Expand Down Expand Up @@ -279,23 +279,11 @@ def task[Out: TModel](self, func: Function[Out]) -> Out:
OutModel = func.out() # noqa: N806
return init_tmodel(OutModel, lambda p: (idx, p))

@overload
def eval[A: TModel, B: TModel](
self,
body: TypedGraphRef[A, B],
eval_inputs: A,
) -> B: ...
@overload
def eval[A: TModel, B: TModel](
self,
body: GraphBuilder[A, B],
body: GraphBuilder[A, B] | TypedGraphRef[A, B],
eval_inputs: A,
) -> B: ...
def eval[A: TModel, B: TModel](
self,
body: GraphBuilder[A, B] | TypedGraphRef,
eval_inputs: Any,
) -> Any:
) -> B:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possibly changing the type of parameter eval_inputs and the return type, both from Any, to A/B might have some impact???

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but I think it is a correct change though (and the lint passes) so I'm in favour of doing this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a very positive change

"""Add a evaluation node to the graph.

This will evaluate a nested graph with the given inputs.
Expand All @@ -314,20 +302,6 @@ def eval[A: TModel, B: TModel](
idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(eval_inputs))("dummy")
return init_tmodel(body.outputs_type, lambda p: (idx, p))

@overload
def loop[A: TModel, B: LoopOutput](
self,
body: TypedGraphRef[A, B],
loop_inputs: A,
name: str | None = None,
) -> B: ...
@overload
def loop[A: TModel, B: LoopOutput](
self,
body: GraphBuilder[A, B],
loop_inputs: A,
name: str | None = None,
) -> B: ...
def loop[A: TModel, B: LoopOutput](
self,
body: TypedGraphRef[A, B] | GraphBuilder[A, B],
Expand Down
15 changes: 1 addition & 14 deletions tierkreis/tierkreis/controller/data/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from dataclasses import dataclass
from inspect import isclass
from itertools import chain
from typing import (
Any,
Callable,
Expand All @@ -25,7 +24,7 @@
RestrictedNamedTuple,
ValueRef,
)
from tierkreis.controller.data.types import PType, generics_in_ptype
from tierkreis.controller.data.types import PType

TKR_PORTMAPPING_FLAG = "__tkr_portmapping__"

Expand Down Expand Up @@ -143,15 +142,3 @@ def init_tmodel[T: TModel](tmodel: type[T], input_fn: Callable[[str], ValueRef])
return cast("T", model(*args))
(ref,) = fields.values()
return tmodel(*ref)


def generics_in_pmodel(pmodel: type[PModel]) -> set[str]:
if is_portmapping(pmodel):
origin = get_origin(pmodel)
if origin is not None:
return generics_in_pmodel(origin)

x = [generics_in_ptype(pmodel.__annotations__[t]) for t in model_fields(pmodel)]
return set(chain(*x))

return generics_in_ptype(pmodel)
45 changes: 6 additions & 39 deletions tierkreis/tierkreis/controller/data/types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""Valid Python types for annotating worker functions and their serialisation."""

# ruff: noqa: ANN001 ANN003 ANN401 due to serialization and inheritance from json
import collections.abc
import json
import logging
import pickle
from base64 import b64decode, b64encode
from collections import defaultdict
from collections.abc import Mapping, Sequence
from inspect import Parameter, _empty, isclass
from itertools import chain
from types import NoneType, UnionType
from typing import (
Annotated,
Expand All @@ -26,7 +24,6 @@
)

from pydantic import BaseModel, ValidationError
from pydantic._internal._generics import get_args as pydantic_get_args
from typing_extensions import TypeIs

from tierkreis.controller.data.core import (
Expand Down Expand Up @@ -184,11 +181,11 @@ def _is_generic(o) -> TypeIs[type[TypeVar]]:


def _is_list(ptype: object) -> TypeIs[type[Sequence[PType]]]:
return get_origin(ptype) == collections.abc.Sequence or get_origin(ptype) is list
return get_origin(ptype) == Sequence or get_origin(ptype) is list


def _is_mapping(ptype: object) -> TypeIs[type[Mapping[str, PType]]]:
return get_origin(ptype) is collections.abc.Mapping or get_origin(ptype) is dict
return get_origin(ptype) is Mapping or get_origin(ptype) is dict


def _is_tuple(o: object) -> TypeIs[type[tuple[Any, ...]]]:
Expand Down Expand Up @@ -259,10 +256,10 @@ def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> JsonType:
case tuple():
args = get_args(annotation) or [None] * len(ptype)
return tuple([ser_from_ptype(p, args[i]) for i, p in enumerate(ptype)])
case collections.abc.Sequence():
case Sequence():
arg = get_args(annotation)[0] if get_args(annotation) else None
return [ser_from_ptype(p, arg) for p in ptype]
case collections.abc.Mapping():
case Mapping():
arg = get_args(annotation)[1] if get_args(annotation) else None
return {k: ser_from_ptype(p, arg) for k, p in ptype.items()}
case DictConvertible():
Expand Down Expand Up @@ -371,14 +368,14 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T:
}
return cast("T", origin(**d))

if issubclass(origin, collections.abc.Sequence):
if issubclass(origin, Sequence):
args = get_args(annotation)
if len(args) == 0:
return ser

return cast("T", [coerce_from_annotation(x, args[0]) for x in ser])

if issubclass(origin, collections.abc.Mapping):
if issubclass(origin, Mapping):
args = get_args(annotation)
if len(args) == 0:
return ser
Expand Down Expand Up @@ -443,36 +440,6 @@ def ptype_from_bytes[T: PType](bs: bytes, annotation: type[T] | None = None) ->
assert_never(method)


def generics_in_ptype(ptype: type[PType]) -> set[str]:
"""Get the generics in a type annotation.

:param ptype: The ptype to extract generics from.
:type ptype: type[PType]
:return: The set of generic names in the ptype.
:rtype: set[str]
"""
if _is_generic(ptype):
return {str(ptype)}

if _is_union(ptype) or _is_tuple(ptype) or _is_list(ptype) or _is_mapping(ptype):
return set(chain(*[generics_in_ptype(x) for x in get_args(ptype)]))

origin = get_origin(ptype)
if origin is not None:
return generics_in_ptype(origin)

if issubclass(ptype, (bool, int, float, complex, str, bytes, NoneType)):
return set()

if issubclass(ptype, (DictConvertible, ListConvertible, NdarraySurrogate, Struct)):
return set()

if issubclass(ptype, BaseModel):
return {str(x) for x in pydantic_get_args(ptype)}

assert_never(ptype)


def has_default(t: Parameter) -> bool:
"""Check if a parameter has a default value.

Expand Down
2 changes: 1 addition & 1 deletion tierkreis/tierkreis/graphs/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def _fold_graph_outer[A: PType, B: PType]() -> GraphBuilder[
# Apply the function if we were able to pop off a value.
tgd = TypedGraphRef[_InnerFuncInput, TKR[B]](
func.value_ref(),
TKR[B],
_InnerFuncInput,
TKR[B],
)
applied_next = g.eval(tgd, _InnerFuncInput(accum, headed.head))

Expand Down