Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions tierkreis/tests/controller/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,20 @@
# [tool.uv.sources]
# tierkreis = { path = "../../../tierkreis", editable = true }
# ///
from pathlib import Path
from time import sleep
from sys import argv

from tierkreis import Worker, Value
from tierkreis.models import PType

worker = Worker("tests_worker")


@worker.task()
def sleep_and_return[T](*, output: T) -> Value[T]:
def sleep_and_return[T: PType](output: T) -> Value[T]:
sleep(10)
return Value(value=output)


def main() -> None:
node_definition_path = argv[1]
worker.run(Path(node_definition_path))


if __name__ == "__main__":
main()
worker.app(argv)
10 changes: 5 additions & 5 deletions tierkreis/tierkreis/builtins/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def impl_id[T: PType](value: T) -> T:


@worker.task()
def append[T](v: list[T], a: T) -> list[T]: # noqa: E741
def append[T: PType](v: list[T], a: T) -> list[T]: # noqa: E741
v.append(a)
return v

Expand All @@ -66,8 +66,8 @@ def head[T: PType](v: list[T]) -> Headed[T]: # noqa: E741
return Headed(head=head, rest=rest)


@worker.task(name="len")
def impl_len[A](v: list[A]) -> int:
@worker.task()
def impl_len[A: PType](v: list[A]) -> int:
logger.info("len: %s", v)
return len(v)

Expand Down Expand Up @@ -107,8 +107,8 @@ def concat(lhs: str, rhs: str) -> str:
return lhs + rhs


@worker.task(name="zip")
def zip_impl[U, V](a: list[U], b: list[V]) -> list[tuple[U, V]]:
@worker.task()
def zip_impl[U: PType, V: PType](a: list[U], b: list[V]) -> list[tuple[U, V]]:
return list(zip(a, b))


Expand Down
34 changes: 34 additions & 0 deletions tierkreis/tierkreis/controller/data/worker_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Callable, TypeVar

from tierkreis.controller.data.types import PType


_T0 = TypeVar("_T0", bound=PType, contravariant=True)
_T1 = TypeVar("_T1", bound=PType, contravariant=True)
_T2 = TypeVar("_T2", bound=PType, contravariant=True)
_T3 = TypeVar("_T3", bound=PType, contravariant=True)
_T4 = TypeVar("_T4", bound=PType, contravariant=True)
_T5 = TypeVar("_T5", bound=PType, contravariant=True)
_T6 = TypeVar("_T6", bound=PType, contravariant=True)
_T7 = TypeVar("_T7", bound=PType, contravariant=True)
_T8 = TypeVar("_T8", bound=PType, contravariant=True)
_T9 = TypeVar("_T9", bound=PType, contravariant=True)
_T10 = TypeVar("_T10", bound=PType, contravariant=True)
_T11 = TypeVar("_T11", bound=PType, contravariant=True)


WorkerFunction = (
Callable[[], PType]
| Callable[[_T0], PType]
| Callable[[_T0, _T1], PType]
| Callable[[_T0, _T1, _T2], PType]
| Callable[[_T0, _T1, _T2, _T3], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10], PType]
| Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10, _T11], PType]
)
4 changes: 2 additions & 2 deletions tierkreis/tierkreis/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tierkreis.controller.data.core import EmptyModel
from tierkreis.controller.data.models import TKR, portmapping
from tierkreis.controller.data.types import Struct
from tierkreis.controller.data.types import Struct, PType

__all__ = ["EmptyModel", "TKR", "portmapping", "Struct"]
__all__ = ["EmptyModel", "TKR", "portmapping", "Struct", "PType"]
4 changes: 2 additions & 2 deletions tierkreis/tierkreis/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from inspect import isclass
from logging import getLogger
from types import NoneType
from typing import Any, Callable
from typing import Any
from tierkreis.controller.data.models import PModel, PNamedModel, is_portmapping
from tierkreis.controller.data.types import PType, Struct, is_ptype
from tierkreis.controller.data.worker_function import WorkerFunction
from tierkreis.exceptions import TierkreisError

logger = getLogger(__name__)
WorkerFunction = Callable[..., PModel]


class TierkreisWorkerError(TierkreisError):
Expand Down
19 changes: 16 additions & 3 deletions tierkreis/tierkreis/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from tierkreis.controller.data.location import WorkerCallArgs
from tierkreis.controller.data.models import PModel, dict_from_pmodel
from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes

# fmt: off
from tierkreis.controller.data.worker_function import _T0, _T11, _T2, _T3, _T1, _T10, _T4, _T5, _T6, _T7, _T8, _T9
# fmt: on
from tierkreis.exceptions import TierkreisError
from tierkreis.namespace import Namespace, WorkerFunction
from tierkreis.worker.storage.filestorage import WorkerFileStorage
Expand Down Expand Up @@ -75,16 +79,25 @@ def wrapper(args: WorkerCallArgs):

return function_decorator

def task(self, name: str | None = None) -> Callable[[WorkerFunction], None]:
def task(
self, name: str | None = None
) -> Callable[
[WorkerFunction[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10, _T11]],
None,
]:
"""Register a function with the worker."""

def function_decorator(func: WorkerFunction) -> None:
def function_decorator(
func: WorkerFunction[
_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10, _T11
],
) -> None:
func_name = func.__name__
self.namespace.add_function(func)

def wrapper(node_definition: WorkerCallArgs):
kwargs = self._load_args(func, node_definition.inputs)
results = func(**kwargs)
results = func(**kwargs) # type: ignore
self._save_results(node_definition.outputs, results)

self.functions[func_name] = wrapper
Expand Down