diff --git a/tierkreis/tests/controller/main.py b/tierkreis/tests/controller/main.py index a04069b61..de7b852f3 100644 --- a/tierkreis/tests/controller/main.py +++ b/tierkreis/tests/controller/main.py @@ -5,25 +5,20 @@ # [tool.uv.sources] # tierkreis = { path = "../../../tierkreis", editable = true } # /// -from pathlib import Path from time import sleep from sys import argv from tierkreis import Worker, Value +from tierkreis.models import PType worker = Worker("tests_worker") @worker.task() -def sleep_and_return[T](*, output: T) -> Value[T]: +def sleep_and_return[T: PType](output: T) -> Value[T]: sleep(10) return Value(value=output) -def main() -> None: - node_definition_path = argv[1] - worker.run(Path(node_definition_path)) - - if __name__ == "__main__": - main() + worker.app(argv) diff --git a/tierkreis/tierkreis/builtins/main.py b/tierkreis/tierkreis/builtins/main.py index 5132c3d86..937fd1815 100644 --- a/tierkreis/tierkreis/builtins/main.py +++ b/tierkreis/tierkreis/builtins/main.py @@ -49,7 +49,7 @@ def impl_id[T: PType](value: T) -> T: @worker.task() -def append[T](v: list[T], a: T) -> list[T]: # noqa: E741 +def append[T: PType](v: list[T], a: T) -> list[T]: # noqa: E741 v.append(a) return v @@ -66,8 +66,8 @@ def head[T: PType](v: list[T]) -> Headed[T]: # noqa: E741 return Headed(head=head, rest=rest) -@worker.task(name="len") -def impl_len[A](v: list[A]) -> int: +@worker.task() +def impl_len[A: PType](v: list[A]) -> int: logger.info("len: %s", v) return len(v) @@ -107,8 +107,8 @@ def concat(lhs: str, rhs: str) -> str: return lhs + rhs -@worker.task(name="zip") -def zip_impl[U, V](a: list[U], b: list[V]) -> list[tuple[U, V]]: +@worker.task() +def zip_impl[U: PType, V: PType](a: list[U], b: list[V]) -> list[tuple[U, V]]: return list(zip(a, b)) diff --git a/tierkreis/tierkreis/controller/data/worker_function.py b/tierkreis/tierkreis/controller/data/worker_function.py new file mode 100644 index 000000000..d8e030240 --- /dev/null +++ b/tierkreis/tierkreis/controller/data/worker_function.py @@ -0,0 +1,34 @@ +from typing import Callable, TypeVar + +from tierkreis.controller.data.types import PType + + +_T0 = TypeVar("_T0", bound=PType, contravariant=True) +_T1 = TypeVar("_T1", bound=PType, contravariant=True) +_T2 = TypeVar("_T2", bound=PType, contravariant=True) +_T3 = TypeVar("_T3", bound=PType, contravariant=True) +_T4 = TypeVar("_T4", bound=PType, contravariant=True) +_T5 = TypeVar("_T5", bound=PType, contravariant=True) +_T6 = TypeVar("_T6", bound=PType, contravariant=True) +_T7 = TypeVar("_T7", bound=PType, contravariant=True) +_T8 = TypeVar("_T8", bound=PType, contravariant=True) +_T9 = TypeVar("_T9", bound=PType, contravariant=True) +_T10 = TypeVar("_T10", bound=PType, contravariant=True) +_T11 = TypeVar("_T11", bound=PType, contravariant=True) + + +WorkerFunction = ( + Callable[[], PType] + | Callable[[_T0], PType] + | Callable[[_T0, _T1], PType] + | Callable[[_T0, _T1, _T2], PType] + | Callable[[_T0, _T1, _T2, _T3], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10], PType] + | Callable[[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10, _T11], PType] +) diff --git a/tierkreis/tierkreis/models.py b/tierkreis/tierkreis/models.py index 1405f4468..e4b3c4541 100644 --- a/tierkreis/tierkreis/models.py +++ b/tierkreis/tierkreis/models.py @@ -1,5 +1,5 @@ from tierkreis.controller.data.core import EmptyModel from tierkreis.controller.data.models import TKR, portmapping -from tierkreis.controller.data.types import Struct +from tierkreis.controller.data.types import Struct, PType -__all__ = ["EmptyModel", "TKR", "portmapping", "Struct"] +__all__ = ["EmptyModel", "TKR", "portmapping", "Struct", "PType"] diff --git a/tierkreis/tierkreis/namespace.py b/tierkreis/tierkreis/namespace.py index 5f0ac152e..fb2f09260 100644 --- a/tierkreis/tierkreis/namespace.py +++ b/tierkreis/tierkreis/namespace.py @@ -2,13 +2,13 @@ from inspect import isclass from logging import getLogger from types import NoneType -from typing import Any, Callable +from typing import Any from tierkreis.controller.data.models import PModel, PNamedModel, is_portmapping from tierkreis.controller.data.types import PType, Struct, is_ptype +from tierkreis.controller.data.worker_function import WorkerFunction from tierkreis.exceptions import TierkreisError logger = getLogger(__name__) -WorkerFunction = Callable[..., PModel] class TierkreisWorkerError(TierkreisError): diff --git a/tierkreis/tierkreis/worker/worker.py b/tierkreis/tierkreis/worker/worker.py index ba42b79c6..e7998aafe 100644 --- a/tierkreis/tierkreis/worker/worker.py +++ b/tierkreis/tierkreis/worker/worker.py @@ -12,6 +12,10 @@ from tierkreis.controller.data.location import WorkerCallArgs from tierkreis.controller.data.models import PModel, dict_from_pmodel from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes + +# fmt: off +from tierkreis.controller.data.worker_function import _T0, _T11, _T2, _T3, _T1, _T10, _T4, _T5, _T6, _T7, _T8, _T9 +# fmt: on from tierkreis.exceptions import TierkreisError from tierkreis.namespace import Namespace, WorkerFunction from tierkreis.worker.storage.filestorage import WorkerFileStorage @@ -75,16 +79,25 @@ def wrapper(args: WorkerCallArgs): return function_decorator - def task(self, name: str | None = None) -> Callable[[WorkerFunction], None]: + def task( + self, name: str | None = None + ) -> Callable[ + [WorkerFunction[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10, _T11]], + None, + ]: """Register a function with the worker.""" - def function_decorator(func: WorkerFunction) -> None: + def function_decorator( + func: WorkerFunction[ + _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, _T10, _T11 + ], + ) -> None: func_name = func.__name__ self.namespace.add_function(func) def wrapper(node_definition: WorkerCallArgs): kwargs = self._load_args(func, node_definition.inputs) - results = func(**kwargs) + results = func(**kwargs) # type: ignore self._save_results(node_definition.outputs, results) self.functions[func_name] = wrapper