diff --git a/tierkreis/pyproject.toml b/tierkreis/pyproject.toml index 3011342ab..8b300449c 100644 --- a/tierkreis/pyproject.toml +++ b/tierkreis/pyproject.toml @@ -30,7 +30,38 @@ build-backend = "hatchling.build" [tool.ruff] target-version = "py312" -extend-exclude = [] +extend-exclude = [ + # Ignore worker stubs: + "*_worker.py", + "stubs.py", + "stubs_output.py", + # Ignore docs + "tierkreis/docs/*", +] +[tool.ruff.lint] +select = ["ALL"] +isort.known-first-party = ["tierkreis", "tierkreis_visualizer"] +pydocstyle.convention = "pep257" + +# Ignore specific rules that might be redundant or annoying +ignore = [ + "D203", # Conflict: 1 blank line before class (D203) vs no blank lines (D211) + "D213", # Conflict: Multi-line docstring summary start (D212) vs (D213) + "ISC001", # Single line implicit string concatenation + "S603", # Subprocess calls + "D107", # In favor of documenting classes directly +] + +[tool.ruff.lint.per-file-ignores] +"tierkreis/tests/*" = [ + "S101", # asserts allowed in tests... + "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... + "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() + "PLR2004", # Magic value used in comparison, ... 
+ "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "D", +] + [tool.pyright] include = ["."] diff --git a/tierkreis/tests/cli/test_run_workflow.py b/tierkreis/tests/cli/test_run_workflow.py index 251f044c1..bdd29651c 100644 --- a/tierkreis/tests/cli/test_run_workflow.py +++ b/tierkreis/tests/cli/test_run_workflow.py @@ -1,17 +1,17 @@ -import pytest import json from pathlib import Path -from uuid import UUID from unittest import mock +from uuid import UUID +import pytest -from tierkreis.controller.data.graph import GraphData -from tierkreis.cli.run_workflow import run_workflow from tests.controller.sample_graphdata import simple_eval +from tierkreis.cli.run_workflow import run_workflow +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import ptype_from_bytes -@pytest.fixture() +@pytest.fixture def graph() -> GraphData: return simple_eval() @@ -19,7 +19,7 @@ def graph() -> GraphData: def test_run_workflow(graph: GraphData) -> None: inputs = {} run_workflow(inputs=inputs, graph=graph, run_id=31415) - with open( + with Path.open( Path.home() / ".tierkreis" / "checkpoints" @@ -32,18 +32,24 @@ def test_run_workflow(graph: GraphData) -> None: assert c == 12 -def test_run_workflow_with_output(graph: GraphData, capfd) -> None: +def test_run_workflow_with_output(graph: GraphData, capfd) -> None: # noqa: ANN001 inputs = {} run_workflow(inputs=inputs, graph=graph, run_id=31415, print_output=True) out, _ = capfd.readouterr() - assert "simple_eval_output: b'12'\n" in out + assert "simple_eval_output: 12\n" in out + + +@pytest.fixture +def _patch_uuid4() -> mock.Mock: + with mock.patch("uuid.uuid4", return_value=UUID(int=31415)) as m: + yield m -@mock.patch("uuid.uuid4", return_value=UUID(int=31415)) -def test_run_workflow_default_run_id(_, graph: GraphData) -> None: +@pytest.mark.usefixtures("_patch_uuid4", "graph") +def test_run_workflow_default_run_id(graph: GraphData) -> None: inputs = {} 
run_workflow(inputs=inputs, graph=graph) - with open( + with Path.open( Path.home() / ".tierkreis" / "checkpoints" @@ -61,6 +67,6 @@ def test_run_workflow_uv_executor(graph: GraphData) -> None: inputs=inputs, graph=graph, run_id=31415, - use_uv_worker=True, - registry_path=Path("."), + use_uv_executor=True, + registry_path=Path(), ) diff --git a/tierkreis/tests/cli/test_tkr.py b/tierkreis/tests/cli/test_tkr.py index 92545f019..e9b4158da 100644 --- a/tierkreis/tests/cli/test_tkr.py +++ b/tierkreis/tests/cli/test_tkr.py @@ -1,17 +1,17 @@ import json -import pytest import sys from pathlib import Path from unittest import mock from uuid import UUID -from tierkreis.cli.tkr import load_graph, _load_inputs, main +import pytest + +from tests.controller.sample_graphdata import simple_eval +from tierkreis.cli.tkr import _load_inputs, load_graph, main from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType from tierkreis.exceptions import TierkreisError -from tests.controller.sample_graphdata import simple_eval - simple_eval_graph = simple_eval() graph_params = [ @@ -20,9 +20,13 @@ ] -@pytest.mark.parametrize("input,graph", graph_params, ids=["load_module", "load_file"]) -def test_load_graph(input: str, graph: GraphData) -> None: - assert load_graph(input) == graph +@pytest.mark.parametrize( + ("inputs", "graph"), + graph_params, + ids=["load_module", "load_file"], +) +def test_load_graph(inputs: str, graph: GraphData) -> None: + assert load_graph(inputs) == graph def test_load_graph_invalid() -> None: @@ -55,10 +59,12 @@ def test_load_graph_invalid() -> None: @pytest.mark.parametrize( - "input,result", input_params, ids=["json_input", "binary_input"] + ("inputs", "result"), + input_params, + ids=["json_input", "binary_input"], ) -def test_load_inputs(input: list[str], result: dict[str, PType]) -> None: - assert _load_inputs(input) == result +def test_load_inputs(inputs: list[str], result: dict[str, PType]) -> None: + assert 
_load_inputs(inputs) == result def test_load_inputs_invalid() -> None: @@ -91,12 +97,12 @@ def test_load_inputs_invalid() -> None: cli_params = [ ( - default_args + ["-f", "tierkreis/tests/cli/data/sample_graph"], + [*default_args, "-f", "tierkreis/tests/cli/data/sample_graph"], {"simple_eval_output": 12}, ), ( - default_args - + [ + [ + *default_args, "-g", "tests.controller.sample_graphdata:factorial", "-i", @@ -109,13 +115,15 @@ def test_load_inputs_invalid() -> None: @pytest.mark.parametrize( - "args,result", cli_params, ids=["simple_eval_cli", "factorial_cli"] + ("args", "result"), + cli_params, + ids=["simple_eval_cli", "factorial_cli"], ) def test_end_to_end(args: list[str], result: dict[str, bytes]) -> None: with mock.patch.object(sys, "argv", args): main() for key, value in result.items(): - with open( + with Path.open( Path.home() / ".tierkreis" / "checkpoints" diff --git a/tierkreis/tests/conftest.py b/tierkreis/tests/conftest.py index afe4ba84f..62a9a2f45 100644 --- a/tierkreis/tests/conftest.py +++ b/tierkreis/tests/conftest.py @@ -1,17 +1,23 @@ import pytest -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( - "--optional", action="store_true", default=False, help="run optional tests" + "--optional", + action="store_true", + default=False, + help="run optional tests", ) -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line("markers", "optional: mark test as optional to run") -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems( + config: pytest.Config, + items: list[pytest.Item], +) -> None: if config.getoption("--optional"): return skip_slow = pytest.mark.skip(reason="need --optional option to run") diff --git a/tierkreis/tests/controller/loop_graphdata.py b/tierkreis/tests/controller/loop_graphdata.py index 46467e1cc..426a87ea1 100644 --- a/tierkreis/tests/controller/loop_graphdata.py +++ 
b/tierkreis/tests/controller/loop_graphdata.py @@ -1,7 +1,8 @@ from typing import NamedTuple + import tierkreis.builtins.stubs as tkr_builtins -from tierkreis.controller.data.core import EmptyModel from tierkreis.builder import GraphBuilder +from tierkreis.controller.data.core import EmptyModel from tierkreis.controller.data.graph import GraphData from tierkreis.models import TKR @@ -30,7 +31,7 @@ def _loop_body_multiple_acc_untyped() -> GraphData: "acc1": new_acc, "acc2": new_acc2, "acc3": new_acc3, - } + }, ) return g @@ -136,10 +137,10 @@ def _loop_body_scoping() -> GraphBuilder[Scoping, ScopingOut]: one = g.const(1) - next = g.task(tkr_builtins.iadd(g.inputs.current, one)) + next_val = g.task(tkr_builtins.iadd(g.inputs.current, one)) should_continue = g.task(tkr_builtins.neq(g.inputs.end, g.inputs.current)) - g.outputs(ScopingOut(should_continue=should_continue, current=next)) + g.outputs(ScopingOut(should_continue=should_continue, current=next_val)) return g diff --git a/tierkreis/tests/controller/main.py b/tierkreis/tests/controller/main.py index d134ab716..97cca9a55 100644 --- a/tierkreis/tests/controller/main.py +++ b/tierkreis/tests/controller/main.py @@ -6,8 +6,8 @@ # tierkreis = { path = "../../../tierkreis", editable = true } # /// from pathlib import Path -from time import sleep from sys import argv +from time import sleep from tierkreis import Worker diff --git a/tierkreis/tests/controller/sample_graphdata.py b/tierkreis/tests/controller/sample_graphdata.py index ee46a88a4..a9f0a7465 100644 --- a/tierkreis/tests/controller/sample_graphdata.py +++ b/tierkreis/tests/controller/sample_graphdata.py @@ -38,10 +38,10 @@ def loop_body() -> GraphData: g = GraphData() a = g.input("loop_acc") one = g.const(1) - N = g.const(10) + n_val: tuple[int, str] = g.const(10) a_plus = g.func("builtins.iadd", {"a": a, "b": one})("value") - pred = g.func("builtins.igt", {"a": N, "b": a_plus})("value") + pred = g.func("builtins.igt", {"a": n_val, "b": a_plus})("value") 
g.output({"loop_acc": a_plus, "should_continue": pred}) return g @@ -58,11 +58,11 @@ def simple_loop() -> GraphData: def simple_map() -> GraphData: g = GraphData() six = g.const(6) - Ns_const = g.const(list(range(21))) - Ns = g.func("builtins.unfold_values", {Labels.VALUE: Ns_const}) + n_consts = g.const(list(range(21))) + n_vals = g.func("builtins.unfold_values", {Labels.VALUE: n_consts}) doubler_const = g.const(doubler_plus()) - m = g.map(doubler_const, {"doubler_input": Ns("*"), "intercept": six}) + m = g.map(doubler_const, {"doubler_input": n_vals("*"), "intercept": six}) folded = g.func("builtins.fold_values", {"values_glob": m("*")}) g.output({"value": folded(Labels.VALUE)}) return g @@ -82,11 +82,11 @@ def maps_in_series() -> GraphData: g = GraphData() zero = g.const(0) - Ns_const = g.const(list(range(21))) - Ns = g.func("builtins.unfold_values", {Labels.VALUE: Ns_const}) + n_consts = g.const(list(range(21))) + n_vals = g.func("builtins.unfold_values", {Labels.VALUE: n_consts}) doubler_const = g.const(doubler_plus()) - m = g.map(doubler_const, {"doubler_input": Ns("*"), "intercept": zero}) + m = g.map(doubler_const, {"doubler_input": n_vals("*"), "intercept": zero}) m2 = g.map(doubler_const, {"doubler_input": m("*"), "intercept": zero}) folded = g.func("builtins.fold_values", {"values_glob": m2("*")}) @@ -97,11 +97,11 @@ def maps_in_series() -> GraphData: def map_with_str_keys() -> GraphData: g = GraphData() zero = g.const(0) - Ns_const = g.const({"one": 1, "two": 2, "three": 3}) - Ns = g.func("builtins.unfold_dict", {Labels.VALUE: Ns_const}) + n_consts = g.const({"one": 1, "two": 2, "three": 3}) + n_vals = g.func("builtins.unfold_dict", {Labels.VALUE: n_consts}) doubler_const = g.const(doubler_plus()) - m = g.map(doubler_const, {"doubler_input": Ns("*"), "intercept": zero}) + m = g.map(doubler_const, {"doubler_input": n_vals("*"), "intercept": zero}) folded = g.func("builtins.fold_dict", {"values_glob": m("*")}) g.output({"value": folded(Labels.VALUE)}) 
return g diff --git a/tierkreis/tests/controller/test_codegen.py b/tierkreis/tests/controller/test_codegen.py index 7c011b850..6d3996ace 100644 --- a/tierkreis/tests/controller/test_codegen.py +++ b/tierkreis/tests/controller/test_codegen.py @@ -1,5 +1,7 @@ from types import NoneType + import pytest + from tierkreis.codegen import format_generic_type from tierkreis.controller.data.types import PType from tierkreis.idl.models import GenericType @@ -21,8 +23,10 @@ ] -@pytest.mark.parametrize("ttype,expected", formats) -def test_format_ttype(ttype: type[PType], expected: str): +@pytest.mark.parametrize(("ttype", "expected"), formats) +def test_format_ttype(ttype: type[PType], expected: str) -> None: generic_type = GenericType.from_type(ttype) - assert format_generic_type(generic_type, False, False) == expected + assert ( + format_generic_type(generic_type, include_bound=False, is_tkr=False) == expected + ) diff --git a/tierkreis/tests/controller/test_eagerifelse.py b/tierkreis/tests/controller/test_eagerifelse.py index 4bfe0d6f3..6afd9af51 100644 --- a/tierkreis/tests/controller/test_eagerifelse.py +++ b/tierkreis/tests/controller/test_eagerifelse.py @@ -1,20 +1,20 @@ import json -import pytest from pathlib import Path from uuid import UUID +import pytest + from tests.controller.sample_graphdata import ( simple_eagerifelse, simple_ifelse, ) - from tierkreis.controller import run_graph +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import PType from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.executor.uv_executor import UvExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage -from tierkreis.controller.data.graph import GraphData def eagerifelse_long_running() -> GraphData: @@ -34,8 +34,8 @@ def eagerifelse_long_running() -> GraphData: params = [({"pred": True}, 1), ({"pred": False}, 2)] 
-@pytest.mark.parametrize("input, output", params) -def test_eagerifelse_long_running(input: dict[str, PType], output: int) -> None: +@pytest.mark.parametrize(("inputs", "output"), params) +def test_eagerifelse_long_running(inputs: dict[str, PType], output: int) -> None: g = eagerifelse_long_running() storage = ControllerFileStorage(UUID(int=150), name="eagerifelse_long_running") @@ -43,7 +43,7 @@ def test_eagerifelse_long_running(input: dict[str, PType], output: int) -> None: executor = UvExecutor(registry_path=registry_path, logs_path=storage.logs_path) storage.clean_graph_files() - run_graph(storage, executor, g, input, n_iterations=20000) + run_graph(storage, executor, g, inputs, n_iterations=20000) actual_output = json.loads(storage.read_output(Loc(), "simple_eagerifelse_output")) assert actual_output == output @@ -58,7 +58,7 @@ def test_eagerifelse_nodes() -> None: assert storage.is_node_finished(Loc("-.N4")) -def test_ifelse_nodes(): +def test_ifelse_nodes() -> None: g = simple_ifelse() storage = ControllerFileStorage(UUID(int=152), name="simple_if_else") executor = ShellExecutor(Path("./python/examples/launchers"), storage.workflow_dir) diff --git a/tierkreis/tests/controller/test_graphdata.py b/tierkreis/tests/controller/test_graphdata.py index fa9d0ac89..92c2407f2 100644 --- a/tierkreis/tests/controller/test_graphdata.py +++ b/tierkreis/tests/controller/test_graphdata.py @@ -1,10 +1,11 @@ import pytest -from tierkreis.exceptions import TierkreisError + from tierkreis.controller.data.graph import GraphData +from tierkreis.exceptions import TierkreisError -def test_only_one_output(): +def test_only_one_output() -> None: + g = GraphData() + g.output({"one": g.const(1)}) with pytest.raises(TierkreisError): - g = GraphData() - g.output({"one": g.const(1)}) g.output({"two": g.const(2)}) diff --git a/tierkreis/tests/controller/test_graphdata_storage.py b/tierkreis/tests/controller/test_graphdata_storage.py index 0d053a105..d64fc6b67 100644 --- 
a/tierkreis/tests/controller/test_graphdata_storage.py +++ b/tierkreis/tests/controller/test_graphdata_storage.py @@ -1,5 +1,7 @@ from uuid import UUID + import pytest + from tests.controller.sample_graphdata import simple_eval, simple_map from tierkreis.controller.data.core import PortID from tierkreis.controller.data.graph import ( @@ -16,7 +18,7 @@ @pytest.mark.parametrize( - ["node_location_str", "graph", "target"], + ("node_location_str", "graph", "target"), [ ("-.N0", simple_eval(), Const(0, outputs={"value": 3})), ("-.N4.M0", simple_map(), Eval((-1, "body"), {})), @@ -31,14 +33,14 @@ def test_read_nodedef(node_location_str: str, graph: GraphData, target: str) -> @pytest.mark.parametrize( - ["node_location_str", "graph", "port", "target"], + ("node_location_str", "graph", "port", "target"), [ ("-.N0", simple_eval(), "value", b"null"), ("-.N4.M0", simple_map(), "0", b"null"), ], ) def test_read_output( - node_location_str: str, graph: GraphData, port: PortID, target: str + node_location_str: str, graph: GraphData, port: PortID, target: str, ) -> None: loc = Loc(node_location_str) storage = GraphDataStorage(UUID(int=0), graph) @@ -54,14 +56,14 @@ def test_raises() -> None: @pytest.mark.parametrize( - ["node_location_str", "graph", "target"], + ("node_location_str", "graph", "target"), [ ("-.N0", simple_eval(), ["value"]), ("-.N4.M0", simple_map(), ["0"]), ], ) def test_read_output_ports( - node_location_str: str, graph: GraphData, target: str + node_location_str: str, graph: GraphData, target: str, ) -> None: loc = Loc(node_location_str) storage = GraphDataStorage(UUID(int=0), graph) @@ -70,7 +72,7 @@ def test_read_output_ports( @pytest.mark.parametrize( - ["node_location_str", "graph", "target"], + ("node_location_str", "graph", "target"), [ ("-.N0", simple_eval(), Const(0, outputs={"value": 3})), ("-.N3.N1", simple_eval(), Input("intercept", outputs={"intercept": 4})), @@ -99,7 +101,7 @@ def test_read_output_ports( ], ) def test_graph_node_from_loc( - 
node_location_str: str, graph: GraphData, target: str + node_location_str: str, graph: GraphData, target: str, ) -> None: loc = Loc(node_location_str) node_def, _ = graph_node_from_loc(loc, graph) diff --git a/tierkreis/tests/controller/test_locs.py b/tierkreis/tests/controller/test_locs.py index 52051cf1f..91765cc9f 100644 --- a/tierkreis/tests/controller/test_locs.py +++ b/tierkreis/tests/controller/test_locs.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize( - ["node_location", "loc_str"], + ("node_location", "loc_str"), [ (node_location_1, "-.N1.L0.N3.L2.N0.M7.N10"), (node_location_2, "-.N0.L0.N3.N8.N0"), @@ -35,7 +35,7 @@ (node_location_4, "-"), ], ) -def test_to_from_str(node_location: Loc, loc_str: str): +def test_to_from_str(node_location: Loc, loc_str: str) -> None: node_location_str = str(node_location) assert node_location_str == loc_str @@ -44,7 +44,7 @@ def test_to_from_str(node_location: Loc, loc_str: str): @pytest.mark.parametrize( - ["node_location", "loc_str"], + ("node_location", "loc_str"), [ (node_location_1, "-.N1.L0.N3.L2.N0.M7"), (node_location_2, "-.N0.L0.N3.N8"), @@ -61,7 +61,7 @@ def test_parent(node_location: Loc, loc_str: str) -> None: @pytest.mark.parametrize( - ["node_location", "node_step", "loc_str"], + ("node_location", "node_step", "loc_str"), [ (node_location_1, ("N", 1), "-.L0.N3.L2.N0.M7.N10"), (node_location_2, ("N", 0), "-.L0.N3.N8.N0"), @@ -77,7 +77,7 @@ def test_pop_first(node_location: Loc, node_step: NodeStep, loc_str: str) -> Non @pytest.mark.parametrize( - ["node_location", "node_step", "loc_str"], + ("node_location", "node_step", "loc_str"), [ (node_location_1, ("N", 10), "-.N1.L0.N3.L2.N0.M7"), (node_location_2, ("N", 0), "-.N0.L0.N3.N8"), @@ -157,7 +157,7 @@ def test_pop_last_multiple() -> None: @pytest.mark.parametrize( - ["node_location", "index"], + ("node_location", "index"), [ (node_location_1, 10), (node_location_2, 0), @@ -171,7 +171,7 @@ def test_get_last_index(node_location: Loc, index: int) -> None: 
@pytest.mark.parametrize( - ["node_location", "expected"], + ("node_location", "expected"), [ ( node_location_1, @@ -201,5 +201,5 @@ def test_get_last_index(node_location: Loc, index: int) -> None: (node_location_4, [Loc()]), ], ) -def test_partial_paths(node_location: Loc, expected: list[Loc]): +def test_partial_paths(node_location: Loc, expected: list[Loc]) -> None: assert expected == node_location.partial_locs() diff --git a/tierkreis/tests/controller/test_models.py b/tierkreis/tests/controller/test_models.py index fc96f8c48..d08206cc5 100644 --- a/tierkreis/tests/controller/test_models.py +++ b/tierkreis/tests/controller/test_models.py @@ -1,9 +1,11 @@ from types import NoneType from typing import NamedTuple + import pytest + +from tests.controller.test_types import ptypes from tierkreis.controller.data.models import PModel, dict_from_pmodel, portmapping from tierkreis.controller.data.types import PType -from tests.controller.test_types import ptypes @portmapping @@ -20,7 +22,7 @@ class NamedPModel(NamedTuple): @pytest.mark.parametrize("pmodel", ptypes) -def test_dict_from_pmodel_unnested(pmodel: PModel): +def test_dict_from_pmodel_unnested(pmodel: PModel) -> None: assert dict_from_pmodel(pmodel) == {"value": pmodel} @@ -50,6 +52,6 @@ def test_dict_from_pmodel_unnested(pmodel: PModel): pmodels = [(named_p_model, named_p_model_expected)] -@pytest.mark.parametrize("pmodel,expected", pmodels) -def test_dict_from_pmodel_nested(pmodel: PModel, expected: dict[str, PType]): +@pytest.mark.parametrize(("pmodel", "expected"), pmodels) +def test_dict_from_pmodel_nested(pmodel: PModel, expected: dict[str, PType]) -> None: assert dict_from_pmodel(pmodel) == expected diff --git a/tierkreis/tests/controller/test_read_loop_trace.py b/tierkreis/tests/controller/test_read_loop_trace.py index f949d976e..5ac829b64 100644 --- a/tierkreis/tests/controller/test_read_loop_trace.py +++ b/tierkreis/tests/controller/test_read_loop_trace.py @@ -1,24 +1,23 @@ from pathlib import Path -from 
typing import Any, Type from uuid import UUID import pytest + from tests.controller.loop_graphdata import loop_multiple_acc, loop_multiple_acc_untyped from tierkreis.controller import run_graph +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage -from tierkreis.controller.data.graph import GraphData from tierkreis.storage import read_loop_trace - return_value = [ {"acc1": x, "acc2": y, "acc3": z} - for x, y, z in zip(range(1, 7), range(2, 13, 2), range(3, 19, 3)) + for x, y, z in zip(range(1, 7), range(2, 13, 2), range(3, 19, 3), strict=False) ] -params: list[tuple[GraphData, Any, str, int]] = [ +params: list[tuple[GraphData, list[dict[str, int]], str, int]] = [ ( loop_multiple_acc_untyped(), return_value, @@ -42,17 +41,17 @@ @pytest.mark.parametrize("storage_class", storage_classes, ids=storage_ids) -@pytest.mark.parametrize("graph,output,name,id", params, ids=ids) +@pytest.mark.parametrize(("graph", "output", "name", "workflow_id"), params, ids=ids) def test_read_loop_trace( - storage_class: Type[ControllerFileStorage | ControllerInMemoryStorage], + storage_class: type[ControllerFileStorage | ControllerInMemoryStorage], graph: GraphData, - output: Any, + output: list[dict[str, int]], name: str, - id: int, -): + workflow_id: int, +) -> None: g = graph - storage = storage_class(UUID(int=id), name=name) - executor = ShellExecutor(Path("./python/examples/launchers"), Path("")) + storage = storage_class(UUID(int=workflow_id), name=name) + executor = ShellExecutor(Path("./python/examples/launchers"), Path()) if isinstance(storage, ControllerInMemoryStorage): executor = InMemoryExecutor(Path("./tierkreis/tierkreis"), storage=storage) storage.clean_graph_files() diff --git 
a/tierkreis/tests/controller/test_resume.py b/tierkreis/tests/controller/test_resume.py index 38928fe12..b65b49bd2 100644 --- a/tierkreis/tests/controller/test_resume.py +++ b/tierkreis/tests/controller/test_resume.py @@ -1,9 +1,14 @@ from pathlib import Path -from typing import Any, Type +from typing import Any from uuid import UUID import pytest +from tests.controller.loop_graphdata import ( + loop_multiple_acc, + loop_multiple_acc_untyped, + loop_scoping, +) from tests.controller.sample_graphdata import ( maps_in_series, simple_eagerifelse, @@ -12,12 +17,9 @@ simple_loop, simple_map, ) -from tests.controller.loop_graphdata import ( - loop_multiple_acc, - loop_multiple_acc_untyped, - loop_scoping, -) from tests.controller.typed_graphdata import ( + factorial, + gcd, tkr_conj, tkr_list_conj, tuple_untuple, @@ -25,20 +27,20 @@ typed_eval, typed_loop, typed_map, - factorial, - gcd, typed_map_simple, ) from tierkreis.controller import run_graph +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage -from tierkreis.controller.data.graph import GraphData from tierkreis.storage import read_outputs -param_data: list[tuple[GraphData, Any, str, dict[str, PType] | PType]] = [ +param_data: list[ + tuple[GraphData, dict[str, PType] | PType, str, dict[str, PType] | PType] +] = [ (simple_eval(), {"simple_eval_output": 12}, "simple_eval", {}), (simple_loop(), 10, "simple_loop", {}), (simple_map(), list(range(6, 47, 2)), "simple_map", {}), @@ -138,18 +140,22 @@ @pytest.mark.parametrize("storage_class", storage_classes, ids=storage_ids) -@pytest.mark.parametrize("graph,output,name,id,inputs", params, ids=ids) -def test_resume( - 
storage_class: Type[ControllerFileStorage | ControllerInMemoryStorage], +@pytest.mark.parametrize( + ("graph", "output", "name", "workflow_id", "inputs"), + params, + ids=ids, +) +def test_resume( # noqa: PLR0913 + storage_class: type[ControllerFileStorage | ControllerInMemoryStorage], graph: GraphData, - output: Any, + output: dict[str, PType] | PType, name: str, - id: int, + workflow_id: int, inputs: dict[str, PType] | PType, -): +) -> None: g = graph - storage = storage_class(UUID(int=id), name=name) - executor = ShellExecutor(Path("./python/examples/launchers"), Path("")) + storage = storage_class(UUID(int=workflow_id), name=name) + executor = ShellExecutor(Path("./python/examples/launchers"), Path()) if isinstance(storage, ControllerInMemoryStorage): executor = InMemoryExecutor(Path("./tierkreis/tierkreis"), storage=storage) storage.clean_graph_files() diff --git a/tierkreis/tests/controller/test_types.py b/tierkreis/tests/controller/test_types.py index f261d9e61..4beca82fc 100644 --- a/tierkreis/tests/controller/test_types.py +++ b/tierkreis/tests/controller/test_types.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime from types import NoneType, UnionType -from typing import Mapping, Sequence, TypeVar +from typing import TypeVar from uuid import UUID -from pydantic import BaseModel + import pytest +from pydantic import BaseModel + from tierkreis.controller.data.types import ( PType, bytes_from_ptype, @@ -96,7 +99,7 @@ def from_list(cls, args: list) -> "DummyListConvertible": @pytest.mark.parametrize("ptype", ptypes) -def test_bytes_roundtrip(ptype: PType): +def test_bytes_roundtrip(ptype: PType) -> None: bs = bytes_from_ptype(ptype) new_type = ptype_from_bytes(bs, type(ptype)) assert ptype == new_type @@ -117,7 +120,7 @@ def test_bytes_roundtrip(ptype: PType): @pytest.mark.parametrize("annotated_ptype", annotated_ptypes) -def test_annotated_bytes_roundtrip(annotated_ptype: 
tuple[PType, type]): +def test_annotated_bytes_roundtrip(annotated_ptype: tuple[PType, type]) -> None: ptype, annotation = annotated_ptype bs = bytes_from_ptype(ptype) new_type = ptype_from_bytes(bs, annotation) @@ -125,12 +128,12 @@ def test_annotated_bytes_roundtrip(annotated_ptype: tuple[PType, type]): @pytest.mark.parametrize("ptype", type_list) -def test_ptype_from_annotation(ptype: type[PType]): +def test_ptype_from_annotation(ptype: type[PType]) -> None: assert is_ptype(ptype) @pytest.mark.parametrize("ptype", fail_list) -def test_ptype_from_annotation_fails(ptype: type[PType]): +def test_ptype_from_annotation_fails(ptype: type[PType]) -> None: assert not is_ptype(ptype) @@ -138,13 +141,13 @@ def test_ptype_from_annotation_fails(ptype: type[PType]): T = TypeVar("T") generic_types = [] -generic_types.append((list[T], {str(T)})) # type: ignore -generic_types.append((list[S | T], {str(S), str(T)})) # type: ignore -generic_types.append((list[list[list[T]]], {str(T)})) # type: ignore -generic_types.append((tuple[S, T], {str(S), str(T)})) # type: ignore -generic_types.append((UntupledModel[S, T], {str(S), str(T)})) # type: ignore +generic_types.append((list[T], {str(T)})) # type: ignore[valid-type] +generic_types.append((list[S | T], {str(S), str(T)})) # type: ignore[valid-type] +generic_types.append((list[list[list[T]]], {str(T)})) # type: ignore[valid-type] +generic_types.append((tuple[S, T], {str(S), str(T)})) # type: ignore[valid-type] +generic_types.append((UntupledModel[S, T], {str(S), str(T)})) # type: ignore[valid-type] -@pytest.mark.parametrize("ptype,generics", generic_types) -def test_generic_types(ptype: type[PType], generics: set[type[PType]]): +@pytest.mark.parametrize(("ptype", "generics"), generic_types) +def test_generic_types(ptype: type[PType], generics: set[type[PType]]) -> None: assert generics_in_ptype(ptype) == generics diff --git a/tierkreis/tests/controller/typed_graphdata.py b/tierkreis/tests/controller/typed_graphdata.py index 
cd0bd7f71..75545043e 100644 --- a/tierkreis/tests/controller/typed_graphdata.py +++ b/tierkreis/tests/controller/typed_graphdata.py @@ -1,15 +1,16 @@ from typing import NamedTuple + +from tierkreis.builder import GraphBuilder from tierkreis.builtins.stubs import ( + conjugate, iadd, igt, itimes, + mod, tkr_tuple, untuple, - mod, - conjugate, ) from tierkreis.controller.data.core import EmptyModel -from tierkreis.builder import GraphBuilder from tierkreis.controller.data.models import TKR @@ -23,14 +24,14 @@ class DoublerOutput(NamedTuple): value: TKR[int] -def typed_doubler(): +def typed_doubler() -> GraphBuilder[TKR[int], TKR[int]]: g = GraphBuilder(TKR[int], TKR[int]) out = g.task(itimes(a=g.const(2), b=g.inputs)) g.outputs(out) return g -def typed_doubler_plus_multi(): +def typed_doubler_plus_multi() -> GraphBuilder[DoublerInput, DoublerOutput]: g = GraphBuilder(DoublerInput, DoublerOutput) mul = g.task(itimes(a=g.inputs.x, b=g.const(2))) out = g.task(iadd(a=mul, b=g.inputs.intercept)) @@ -38,7 +39,7 @@ def typed_doubler_plus_multi(): return g -def typed_doubler_plus(): +def typed_doubler_plus() -> GraphBuilder[DoublerInput, TKR[int]]: g = GraphBuilder(DoublerInput, TKR[int]) mul = g.task(itimes(a=g.inputs.x, b=g.const(2))) out = g.task(iadd(a=mul, b=g.inputs.intercept)) @@ -50,7 +51,7 @@ class TypedEvalOutputs(NamedTuple): typed_eval_output: TKR[int] -def typed_eval(): +def typed_eval() -> GraphBuilder[EmptyModel, TypedEvalOutputs]: g = GraphBuilder(EmptyModel, TypedEvalOutputs) e = g.eval(typed_doubler_plus(), DoublerInput(x=g.const(6), intercept=g.const(0))) g.outputs(TypedEvalOutputs(typed_eval_output=e)) @@ -66,7 +67,7 @@ class LoopBodyOutput(NamedTuple): should_continue: TKR[bool] -def loop_body(): +def loop_body() -> GraphBuilder[LoopBodyInput, LoopBodyOutput]: g = GraphBuilder(LoopBodyInput, LoopBodyOutput) a_plus = g.task(iadd(a=g.inputs.loop_acc, b=g.const(1))) pred = g.task(igt(a=g.const(10), b=a_plus)) @@ -74,21 +75,21 @@ def loop_body(): return g 
-def typed_loop(): +def typed_loop() -> GraphBuilder[EmptyModel, TKR[int]]: g = GraphBuilder(EmptyModel, TKR[int]) loop = g.loop(loop_body(), LoopBodyInput(loop_acc=g.const(6))) g.outputs(loop.loop_acc) return g -def typed_map_simple(): +def typed_map_simple() -> GraphBuilder[TKR[list[int]], TKR[list[int]]]: g = GraphBuilder(TKR[list[int]], TKR[list[int]]) m = g.map(typed_doubler(), g.inputs) g.outputs(m) return g -def typed_map(): +def typed_map() -> GraphBuilder[TKR[list[int]], TKR[list[int]]]: g = GraphBuilder(TKR[list[int]], TKR[list[int]]) ins = g.map(lambda n: DoublerInput(x=n, intercept=g.const(6)), g.inputs) m = g.map(typed_doubler_plus(), ins) @@ -96,7 +97,7 @@ def typed_map(): return g -def typed_destructuring(): +def typed_destructuring() -> GraphBuilder[TKR[list[int]], TKR[list[int]]]: g = GraphBuilder(TKR[list[int]], TKR[list[int]]) ins = g.map(lambda n: DoublerInput(x=n, intercept=g.const(6)), g.inputs) m = g.map(typed_doubler_plus_multi(), ins) @@ -105,7 +106,7 @@ def typed_destructuring(): return g -def tuple_untuple(): +def tuple_untuple() -> GraphBuilder[EmptyModel, TKR[int]]: g = GraphBuilder(EmptyModel, TKR[int]) t = g.task(tkr_tuple(g.const(1), g.const(2))) ut = g.task(untuple(t)) @@ -113,7 +114,7 @@ def tuple_untuple(): return g -def factorial(): +def factorial() -> GraphBuilder[TKR[int], TKR[int]]: g = GraphBuilder(TKR[int], TKR[int]) pred = g.task(igt(g.inputs, g.const(1))) n_minus_one = g.task(iadd(g.const(-1), g.inputs)) @@ -128,7 +129,7 @@ class GCDInput(NamedTuple): b: TKR[int] -def gcd(): +def gcd() -> GraphBuilder[GCDInput, TKR[int]]: g = GraphBuilder(GCDInput, TKR[int]) pred = g.task(igt(g.inputs.b, g.const(0))) @@ -139,14 +140,14 @@ def gcd(): return g -def tkr_conj(): +def tkr_conj() -> GraphBuilder[TKR[complex], TKR[complex]]: g = GraphBuilder(TKR[complex], TKR[complex]) z = g.task(conjugate(g.inputs)) g.outputs(z) return g -def tkr_list_conj(): +def tkr_list_conj() -> GraphBuilder[TKR[list[complex]], TKR[list[complex]]]: g = 
GraphBuilder(TKR[list[complex]], TKR[list[complex]]) zs = g.map(tkr_conj(), g.inputs) g.outputs(zs) diff --git a/tierkreis/tests/errors/failing_worker/main.py b/tierkreis/tests/errors/failing_worker/main.py index 0365613f5..466644bb1 100644 --- a/tierkreis/tests/errors/failing_worker/main.py +++ b/tierkreis/tests/errors/failing_worker/main.py @@ -1,5 +1,7 @@ +# noqa: INP001 import logging -from sys import argv +import sys + from tierkreis import Worker logger = logging.getLogger(__name__) @@ -9,7 +11,8 @@ @worker.task() def fail() -> int: logger.error("Raising an error now...") - raise ValueError("Worker failed!") + msg = "Worker failed!" + raise ValueError(msg) @worker.task() @@ -19,8 +22,8 @@ def wont_fail() -> int: @worker.task() def exit_code_1() -> int: - exit(1) + sys.exit(1) if __name__ == "__main__": - worker.app(argv) + worker.app(sys.argv) diff --git a/tierkreis/tests/errors/test_error.py b/tierkreis/tests/errors/test_error.py index 1c4da439b..0d60cce4a 100644 --- a/tierkreis/tests/errors/test_error.py +++ b/tierkreis/tests/errors/test_error.py @@ -1,7 +1,9 @@ -import pytest from pathlib import Path from uuid import UUID +import pytest + +from tests.errors.failing_worker.stubs import exit_code_1, fail, wont_fail from tierkreis.builder import GraphBuilder from tierkreis.controller import run_graph from tierkreis.controller.data.core import EmptyModel @@ -9,29 +11,28 @@ from tierkreis.controller.data.models import TKR from tierkreis.controller.executor.uv_executor import UvExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage -from tests.errors.failing_worker.stubs import fail, wont_fail, exit_code_1 from tierkreis.exceptions import TierkreisError -def will_fail_graph(): +def will_fail_graph() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.task(fail())) return graph -def wont_fail_graph(): +def wont_fail_graph() -> GraphBuilder[EmptyModel, TKR[int]]: graph = 
GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.task(wont_fail())) return graph -def fail_in_eval(): +def fail_in_eval() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.eval(will_fail_graph(), EmptyModel())) return graph -def non_zero_exit_code(): +def non_zero_exit_code() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.task(exit_code_1())) return graph @@ -44,7 +45,7 @@ def test_raise_error() -> None: storage.clean_graph_files() with pytest.raises(TierkreisError): run_graph(storage, executor, g.get_data(), {}, n_iterations=1000) - assert storage.node_has_error(Loc("-.N0")) + assert storage.node_has_error(Loc("-.N0")) def test_raises_no_error() -> None: @@ -63,7 +64,10 @@ def test_nested_error() -> None: storage.clean_graph_files() with pytest.raises(TierkreisError): run_graph(storage, executor, g.get_data(), {}, n_iterations=1000) - assert (storage.logs_path.parent / "-/errors").exists() + assert (storage.logs_path.parent / "-/logs").exists() + with Path.open((storage.logs_path.parent / "-/logs"), "r") as f: + error_contents = f.read() + assert "-.N1.N0" in error_contents def test_non_zero_exit_code() -> None: @@ -73,4 +77,4 @@ def test_non_zero_exit_code() -> None: storage.clean_graph_files() with pytest.raises(TierkreisError): run_graph(storage, executor, g.get_data(), {}, n_iterations=1000) - assert (storage.logs_path.parent / "-/_error").exists() + assert (storage.logs_path.parent / "-/logs").exists() diff --git a/tierkreis/tests/executor/test_hpc_executor.py b/tierkreis/tests/executor/test_hpc_executor.py index 84b287887..5f47f70ba 100644 --- a/tierkreis/tests/executor/test_hpc_executor.py +++ b/tierkreis/tests/executor/test_hpc_executor.py @@ -1,6 +1,9 @@ from pathlib import Path from uuid import UUID + import pytest + +from tests.executor.stubs import mpi_rank_info from tierkreis.builder import GraphBuilder from tierkreis.controller import 
run_graph from tierkreis.controller.data.graph import GraphData @@ -12,8 +15,6 @@ ) from tierkreis.controller.executor.hpc.slurm import SLURMExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage - -from tests.executor.stubs import mpi_rank_info from tierkreis.storage import read_outputs @@ -28,7 +29,9 @@ def job_spec() -> JobSpec: return JobSpec( job_name="test_job", account="test_usr", - command="--allow-run-as-root /root/.local/bin/uv run /slurm_mpi_worker/main.py ", + command=( + "--allow-run-as-root /root/.local/bin/uv run /slurm_mpi_worker/main.py " + ), resource=ResourceSpec(nodes=2, memory_gb=None), walltime="00:15:00", mpi=MpiSpec(max_proc_per_node=1), @@ -47,7 +50,7 @@ def test_slurm_with_mpi() -> None: do_cleanup=True, ) sbatch = str( - Path(__file__).parent.parent.parent.parent / "infra/slurm_local/sbatch" + Path(__file__).parent.parent.parent.parent / "infra/slurm_local/sbatch", ) executor = SLURMExecutor( spec=job_spec(), diff --git a/tierkreis/tests/idl/__init__.py b/tierkreis/tests/idl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tierkreis/tests/idl/namespace1.py b/tierkreis/tests/idl/namespace1.py index 069b6c5f0..38c12cfe1 100644 --- a/tierkreis/tests/idl/namespace1.py +++ b/tierkreis/tests/idl/namespace1.py @@ -1,6 +1,7 @@ from typing import NamedTuple -from tierkreis.controller.data.models import portmapping + from tierkreis import Worker +from tierkreis.controller.data.models import portmapping from tierkreis.controller.data.types import PType worker = Worker("TestNamespace") diff --git a/tierkreis/tests/idl/test_idl.py b/tierkreis/tests/idl/test_idl.py index f2999d3e9..6ff6fa490 100644 --- a/tierkreis/tests/idl/test_idl.py +++ b/tierkreis/tests/idl/test_idl.py @@ -1,10 +1,12 @@ from pathlib import Path + import pytest + +import tests.idl.namespace1 from tierkreis.exceptions import TierkreisError from tierkreis.idl.models import GenericType -from tierkreis.namespace import Namespace from 
tierkreis.idl.type_symbols import type_symbol -import tests.idl.namespace1 +from tierkreis.namespace import Namespace type_symbols = [ ("uint8", GenericType(int, [])), @@ -13,7 +15,8 @@ ( "Record>", GenericType( - dict, [GenericType(str, []), GenericType(list, [GenericType(str, [])])] + dict, + [GenericType(str, []), GenericType(list, [GenericType(str, [])])], ), ), ( @@ -28,17 +31,17 @@ ), ] type_symbols_for_failure = ["decimal", "unknown", "duration"] -dir = Path(__file__).parent -typespecs = [(dir / "namespace1.tsp", tests.idl.namespace1.expected_namespace)] +current_dir = Path(__file__).parent +typespecs = [(current_dir / "namespace1.tsp", tests.idl.namespace1.expected_namespace)] -@pytest.mark.parametrize("type_symb,expected", type_symbols) -def test_type_t(type_symb: str, expected: type): +@pytest.mark.parametrize(("type_symb", "expected"), type_symbols) +def test_type_t(type_symb: str, expected: type) -> None: assert (expected, "") == type_symbol(type_symb) -@pytest.mark.parametrize("path,expected", typespecs) -def test_namespace(path: Path, expected: Namespace): +@pytest.mark.parametrize(("path", "expected"), typespecs) +def test_namespace(path: Path, expected: Namespace) -> None: namespace = Namespace.from_spec_file(path) assert namespace.stubs() == expected.stubs() @@ -49,6 +52,6 @@ def test_namespace(path: Path, expected: Namespace): @pytest.mark.parametrize("type_symb", type_symbols_for_failure) -def test_parser_fail(type_symb: str): +def test_parser_fail(type_symb: str) -> None: with pytest.raises(TierkreisError): type_symbol(type_symb) diff --git a/tierkreis/tierkreis/__init__.py b/tierkreis/tierkreis/__init__.py index 615e11994..03f08a080 100644 --- a/tierkreis/tierkreis/__init__.py +++ b/tierkreis/tierkreis/__init__.py @@ -1,5 +1,7 @@ +"""Tierkreis main package.""" + +from tierkreis.controller import run_graph from tierkreis.labels import Labels from tierkreis.worker.worker import Worker -from tierkreis.controller import run_graph __all__ = 
["Labels", "Worker", "run_graph"] diff --git a/tierkreis/tierkreis/builder.py b/tierkreis/tierkreis/builder.py index 0d6969e64..cbc01079d 100644 --- a/tierkreis/tierkreis/builder.py +++ b/tierkreis/tierkreis/builder.py @@ -1,18 +1,23 @@ +"""Typed graph builder for Tierkreis workflows.""" + +from __future__ import annotations + +from collections.abc import Callable from dataclasses import dataclass from inspect import isclass -from typing import Any, Callable, NamedTuple, Protocol, overload, runtime_checkable +from typing import Any, NamedTuple, Protocol, overload, runtime_checkable from tierkreis.controller.data.core import EmptyModel +from tierkreis.controller.data.graph import GraphData, ValueRef from tierkreis.controller.data.models import ( TKR, TModel, TNamedModel, dict_from_tmodel, - model_fields, init_tmodel, + model_fields, ) from tierkreis.controller.data.types import PType -from tierkreis.controller.data.graph import GraphData, ValueRef @dataclass @@ -24,27 +29,72 @@ class TList[T: TModel]: @runtime_checkable class Function[Out](TNamedModel, Protocol): + """A worker function type. + + :abstract: + """ + @property - def namespace(self) -> str: ... + def namespace(self) -> str: + """The namespace name. + + :return: The namespace name. + :rtype: str + """ + ... @staticmethod - def out() -> type[Out]: ... + def out() -> type[Out]: + """Return the output type of the function. + + :return: The output type. + :rtype: type[Out] + """ + ... @dataclass class TypedGraphRef[Ins: TModel, Outs: TModel]: + """A typed tierkreis graph. + + :attr graph_ref: The graph reference. + :attr outputs_type: The output type of the graph. + :attr inputs_type: The input type of the graph. + """ + graph_ref: ValueRef outputs_type: type[Outs] inputs_type: type[Ins] class LoopOutput(TNamedModel, Protocol): + """Protocol for loop output models to ensure should continue.""" + @property - def should_continue(self) -> TKR[bool]: ... 
+ def should_continue(self) -> TKR[bool]: + """The loop continuation port. + + :return: The continuation port value. + :rtype: TKR[bool] + """ + ... + + +def script(script_name: str, script_input: TKR[bytes]) -> Function[TKR[bytes]]: + """Add a script to the graph. + + A shell script or binary with a single input and output. + Inputs are provided from the standard input and outputs to the standard output. + :param script_name: The name of the script. + :type script_name: str + :param script_input: The input to the script. + :type script_input: TKR[bytes] + :return: The output of the script. + :rtype: Function[TKR[bytes]] + """ -def script(script_name: str, input: TKR[bytes]) -> Function[TKR[bytes]]: - class exec_script(NamedTuple): + class exec_script(NamedTuple): # noqa: N801 input: TKR[bytes] @staticmethod @@ -55,10 +105,18 @@ def out() -> type[TKR[bytes]]: def namespace(self) -> str: return script_name - return exec_script(input=input) + return exec_script(input=script_input) class GraphBuilder[Inputs: TModel, Outputs: TModel]: + """Class to construct typed workflow graphs. + + :attr data: The underlying graph data. + :attr inputs_type: The input type of the graph. + :attr inputs: The inputs to the graph. + :attr outputs_type: The output type of the graph. + """ + outputs_type: type inputs: Inputs @@ -66,7 +124,7 @@ def __init__( self, inputs_type: type[Inputs] = EmptyModel, outputs_type: type[Outputs] = EmptyModel, - ): + ) -> None: self.data = GraphData() self.inputs_type = inputs_type self.outputs_type = outputs_type @@ -74,37 +132,97 @@ def __init__( self.inputs = init_tmodel(self.inputs_type, inputs) def get_data(self) -> GraphData: + """Return the underlying graph from the builder. + + :return: The graph. + :rtype: GraphData + """ return self.data def ref(self) -> TypedGraphRef[Inputs, Outputs]: + """Return a reference of the typed graph. + + :return: The ref of the typed graph. 
+ :rtype: TypedGraphRef[Inputs, Outputs] + """ return TypedGraphRef((-1, "body"), self.outputs_type, self.inputs_type) - def outputs(self, outputs: Outputs): + def outputs(self, outputs: Outputs) -> None: + """Set output nodes of a graph. + + :param outputs: The output nodes. + :type outputs: Outputs + """ self.data.output(inputs=dict_from_tmodel(outputs)) def const[T: PType](self, value: T) -> TKR[T]: + """Add a constant node to the graph. + + :return: The constant value. + :rtype: TKR[T] + """ idx, port = self.data.const(value) return TKR[T](idx, port) def ifelse[A: PType, B: PType]( - self, pred: TKR[bool], if_true: TKR[A], if_false: TKR[B] + self, + pred: TKR[bool], + if_true: TKR[A], + if_false: TKR[B], ) -> TKR[A] | TKR[B]: + """Add an if-else node to the graph. + + This will be evaluated lazily. + The values can be returned from an eval node or another graph. + + :param pred: The predicate value. + :type pred: TKR[bool] + :param if_true: The value if the predicate is true. + :type if_true: TKR[A] + :param if_false: The value if the predicate is false. + :type if_false: TKR[B] + :return: The outputs of the if-else expression. + :rtype: TKR[A] | TKR[B] + """ idx, port = self.data.if_else( - pred.value_ref(), if_true.value_ref(), if_false.value_ref() + pred.value_ref(), + if_true.value_ref(), + if_false.value_ref(), )("value") return TKR(idx, port) def eifelse[A: PType, B: PType]( - self, pred: TKR[bool], if_true: TKR[A], if_false: TKR[B] + self, + pred: TKR[bool], + if_true: TKR[A], + if_false: TKR[B], ) -> TKR[A] | TKR[B]: + """Add an eager if-else node to the graph. + + This will be evaluated eagerly. + The values can be returned from an eval node or another graph. + + :param pred: The predicate value. + :type pred: TKR[bool] + :param if_true: The value if the predicate is true. + :type if_true: TKR[A] + :param if_false: The value if the predicate is false. + :type if_false: TKR[B] + :return: The outputs of the if-else expression. 
+ :rtype: TKR[A] | TKR[B] + """ idx, port = self.data.eager_if_else( - pred.value_ref(), if_true.value_ref(), if_false.value_ref() + pred.value_ref(), + if_true.value_ref(), + if_false.value_ref(), )("value") return TKR(idx, port) def _graph_const[A: TModel, B: TModel]( - self, graph: "GraphBuilder[A, B]" + self, + graph: GraphBuilder[A, B], ) -> TypedGraphRef[A, B]: + # TODO @philipp-seitz: Turn this into a public method? idx, port = self.data.const(graph.data.model_dump()) return TypedGraphRef[A, B]( graph_ref=(idx, port), @@ -112,48 +230,104 @@ def _graph_const[A: TModel, B: TModel]( inputs_type=graph.inputs_type, ) - def task[Out: TModel](self, f: Function[Out]) -> Out: - name = f"{f.namespace}.{f.__class__.__name__}" - ins = dict_from_tmodel(f) - idx, _ = self.data.func(name, ins)("dummy") - OutModel = f.out() + def task[Out: TModel](self, func: Function[Out]) -> Out: + """Add a worker task node to the graph. + + :param func: The worker function. + :type func: Function[Out] + :return: The outputs of the task. + :rtype: Out + """ + name = f"{func.namespace}.{func.__class__.__name__}" + inputs = dict_from_tmodel(func) + idx, _ = self.data.func(name, inputs)("dummy") + OutModel = func.out() # noqa: N806 outputs = [(idx, x) for x in model_fields(OutModel)] return init_tmodel(OutModel, outputs) @overload - def eval[A: TModel, B: TModel](self, body: TypedGraphRef[A, B], a: A) -> B: ... + def eval[A: TModel, B: TModel]( + self, + body: TypedGraphRef[A, B], + eval_inputs: A, + ) -> B: ... @overload - def eval[A: TModel, B: TModel](self, body: "GraphBuilder[A, B]", a: A) -> B: ... def eval[A: TModel, B: TModel]( - self, body: "GraphBuilder[A,B] | TypedGraphRef", a: Any + self, + body: GraphBuilder[A, B], + eval_inputs: A, + ) -> B: ... + def eval[A: TModel, B: TModel]( + self, + body: GraphBuilder[A, B] | TypedGraphRef, + eval_inputs: Any, ) -> Any: + """Add an evaluation node to the graph. + + This will evaluate a nested graph with the given inputs.
+ + :param body: The graph to evaluate. + :type body: TypedGraphRef[A, B] | GraphBuilder[A, B], + where A are the input type and B the output type of the graph. + :param eval_inputs: The inputs to the graph. + :type eval_inputs: A + :return: The outputs of the evaluation. + :rtype: B + """ if isinstance(body, GraphBuilder): body = self._graph_const(body) - idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(a))("dummy") + idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(eval_inputs))("dummy") outputs = [(idx, x) for x in model_fields(body.outputs_type)] return init_tmodel(body.outputs_type, outputs) @overload def loop[A: TModel, B: LoopOutput]( - self, body: TypedGraphRef[A, B], a: A, name: str | None = None + self, + body: TypedGraphRef[A, B], + loop_inputs: A, + name: str | None = None, ) -> B: ... @overload def loop[A: TModel, B: LoopOutput]( - self, body: "GraphBuilder[A, B]", a: A, name: str | None = None + self, + body: GraphBuilder[A, B], + loop_inputs: A, + name: str | None = None, ) -> B: ... def loop[A: TModel, B: LoopOutput]( self, - body: "TypedGraphRef[A, B] |GraphBuilder[A, B]", - a: A, + body: TypedGraphRef[A, B] | GraphBuilder[A, B], + loop_inputs: A, name: str | None = None, ) -> B: + """Add a loop node to the graph. + + This will loop over the given graph until the `should_continue` output is false. + To trace intermediate values, use the name attribute in conjunction with + read_loop_trace. + + :param body: The graph to loop. + :type body: TypedGraphRef[A, B] | GraphBuilder[A, B], + where A are the input type and B the output type of the graph. + :param loop_inputs: The inputs to the loop graph. + :type loop_inputs: A + :param name: An optional name for the loop. + :type name: str | None + :return: The outputs of the loop. 
+ :rtype: B + """ if isinstance(body, GraphBuilder): body = self._graph_const(body) - g = body.graph_ref - idx, _ = self.data.loop(g, dict_from_tmodel(a), "should_continue", name)( - "dummy" + graph = body.graph_ref + idx, _ = self.data.loop( + graph, + dict_from_tmodel(loop_inputs), + "should_continue", + name, + )( + "dummy", ) outputs = [(idx, x) for x in model_fields(body.outputs_type)] return init_tmodel(body.outputs_type, outputs) @@ -164,27 +338,33 @@ def _unfold_list[T: PType](self, ref: TKR[list[T]]) -> TList[TKR[T]]: return TList(TKR[T](idx, "*")) def _fold_list[T: PType](self, refs: TList[TKR[T]]) -> TKR[list[T]]: - value_ref = (refs._value.node_index, refs._value.port_id) + value_ref = (refs._value.node_index, refs._value.port_id) # noqa: SLF001 idx, _ = self.data.func("builtins.fold_values", {"values_glob": value_ref})( - "dummy" + "dummy", ) return TKR[list[T]](idx, "value") def _map_fn_single_in[A: PType, B: TModel]( - self, aes: TKR[list[A]], body: Callable[[TKR[A]], B] - ) -> "TList[B]": - tlist = self._unfold_list(aes) - return TList(body(TKR(tlist._value.node_index, "*"))) + self, + map_inputs: TKR[list[A]], + body: Callable[[TKR[A]], B], + ) -> TList[B]: + tlist = self._unfold_list(map_inputs) + return TList(body(TKR(tlist._value.node_index, "*"))) # noqa: SLF001 def _map_fn_single_out[A: TModel, B: PType]( - self, aes: TList[A], body: Callable[[A], TKR[B]] + self, + map_inputs: TList[A], + body: Callable[[A], TKR[B]], ) -> TKR[list[B]]: - return self._fold_list(TList(body(aes._value))) + return self._fold_list(TList(body(map_inputs._value))) # noqa: SLF001 def _map_graph_full[A: TModel, B: TModel]( - self, aes: TList[A], body: TypedGraphRef[A, B] + self, + map_inputs: TList[A], + body: TypedGraphRef[A, B], ) -> TList[B]: - ins = dict_from_tmodel(aes._value) + ins = dict_from_tmodel(map_inputs._value) # noqa: SLF001 idx, _ = self.data.map(body.graph_ref, ins)("x") refs = [(idx, s + "-*") for s in model_fields(body.outputs_type)] @@ -194,23 
+374,25 @@ def _map_graph_full[A: TModel, B: TModel]( def map[A: PType, B: TNamedModel]( self, body: ( - Callable[[TKR[A]], B] | TypedGraphRef[TKR[A], B] | "GraphBuilder[TKR[A], B]" + Callable[[TKR[A]], B] | TypedGraphRef[TKR[A], B] | GraphBuilder[TKR[A], B] ), - aes: TKR[list[A]], + map_inputs: TKR[list[A]], ) -> TList[B]: ... @overload def map[A: TNamedModel, B: PType]( self, body: ( - Callable[[A], TKR[B]] | TypedGraphRef[A, TKR[B]] | "GraphBuilder[A, TKR[B]]" + Callable[[A], TKR[B]] | TypedGraphRef[A, TKR[B]] | GraphBuilder[A, TKR[B]] ), - aes: TList[A], + map_inputs: TList[A], ) -> TKR[list[B]]: ... @overload def map[A: TNamedModel, B: TNamedModel]( - self, body: TypedGraphRef[A, B] | "GraphBuilder[A, B]", aes: TList[A] + self, + body: TypedGraphRef[A, B] | GraphBuilder[A, B], + map_inputs: TList[A], ) -> TList[B]: ... @overload @@ -219,30 +401,42 @@ def map[A: PType, B: PType]( body: ( Callable[[TKR[A]], TKR[B]] | TypedGraphRef[TKR[A], TKR[B]] - | "GraphBuilder[TKR[A], TKR[B]]" + | GraphBuilder[TKR[A], TKR[B]] ), - aes: TKR[list[A]], + map_inputs: TKR[list[A]], ) -> TKR[list[B]]: ... def map( - self, body: TypedGraphRef | Callable | "GraphBuilder", aes: TKR | TList + self, + body: TypedGraphRef | Callable | GraphBuilder, + map_inputs: TKR | TList, ) -> Any: + """Add a map node to the graph. + + :param body: The graph to map over. + :type body: TypedGraphRef | Callable | GraphBuilder + :param map_inputs: The values to map over. + :type map_inputs: TKR | TList + :return: The outputs of the map. 
+ :rtype: Any + """ if isinstance(body, GraphBuilder): body = self._graph_const(body) if isinstance(body, Callable): - if isinstance(aes, TList): - return self._map_fn_single_out(aes, body) - elif isinstance(aes, TKR): - return self._map_fn_single_in(aes, body) + if isinstance(map_inputs, TList): + return self._map_fn_single_out(map_inputs, body) + if isinstance(map_inputs, TKR): + return self._map_fn_single_in(map_inputs, body) - if isinstance(aes, TKR): - aes = self._unfold_list(aes) + if isinstance(map_inputs, TKR): + map_inputs = self._unfold_list(map_inputs) - out = self._map_graph_full(aes, body) + out = self._map_graph_full(map_inputs, body) if not isclass(body.outputs_type) or not issubclass( - body.outputs_type, TNamedModel + body.outputs_type, + TNamedModel, ): out = self._fold_list(out) diff --git a/tierkreis/tierkreis/builtins/__init__.py b/tierkreis/tierkreis/builtins/__init__.py index e69de29bb..19467cbcc 100644 --- a/tierkreis/tierkreis/builtins/__init__.py +++ b/tierkreis/tierkreis/builtins/__init__.py @@ -0,0 +1 @@ +"""Built-in Tierkreis worker and stubs for basic operations.""" diff --git a/tierkreis/tierkreis/builtins/main.py b/tierkreis/tierkreis/builtins/main.py index aace1d8ec..6a735cbee 100644 --- a/tierkreis/tierkreis/builtins/main.py +++ b/tierkreis/tierkreis/builtins/main.py @@ -1,11 +1,15 @@ +"""Built-in Tierkreis tasks for basic operations.""" + +import statistics +from collections.abc import Sequence from logging import getLogger from pathlib import Path from random import randint -import statistics from sys import argv from time import sleep -from typing import NamedTuple, Sequence +from typing import NamedTuple +from tierkreis import Worker from tierkreis.controller.data.location import WorkerCallArgs from tierkreis.controller.data.models import portmapping from tierkreis.controller.data.types import ( @@ -13,10 +17,8 @@ bytes_from_ptype, ptype_from_bytes, ) -from tierkreis.worker.worker import TierkreisWorkerError from 
tierkreis.worker.storage.protocol import WorkerStorage -from tierkreis import Worker - +from tierkreis.worker.worker import TierkreisWorkerError logger = getLogger(__name__) @@ -25,151 +27,405 @@ @worker.task() def iadd(a: int, b: int) -> int: - logger.debug(f"iadd {a} {b}") + """Add two integers a+b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: The sum of the two integers. + :rtype: int + """ + logger.debug("iadd %s %s", a, b) return a + b @worker.task() -def add(a: int | float, b: int | float) -> int | float: +def add(a: float, b: float) -> int | float: + """Add two float like values a+b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: The sum of the two values. + :rtype: int | float + """ return a + b @worker.task() def isubtract(a: int, b: int) -> int: + """Subtract two integers a-b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: The difference of the two integers. + :rtype: int + """ return a - b @worker.task() -def subtract(a: int | float, b: int | float) -> int | float: +def subtract(a: float, b: float) -> int | float: + """Subtract two float like values a-b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: The difference of the two values. + :rtype: int | float + """ return a - b @worker.task() def itimes(a: int, b: int) -> int: - logger.debug(f"itimes {a} {b}") + """Multiply two integers a*b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: The product of the two integers. 
+ :rtype: int + """ + logger.debug("itimes %s %s", a, b) return a * b @worker.task() -def times(a: int | float, b: int | float) -> int | float: +def times(a: float, b: float) -> int | float: + """Multiply two float like values a*b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: The product of the two values. + :rtype: int | float + """ return a * b @worker.task() -def divide(a: int | float, b: int | float) -> float: +def divide(a: float, b: float) -> float: + """Divide two float like values a/b. + + :param a: The dividend. + :type a: float + :param b: The divisor. + :type b: float + :return: The quotient of the two values. + :rtype: float + """ return a / b @worker.task() def idivide(a: int, b: int) -> int: + """Integer division of two integers a//b. + + :param a: The dividend. + :type a: int + :param b: The divisor. + :type b: int + :return: The integer quotient of the two integers. + :rtype: int + """ return a // b @worker.task() def igt(a: int, b: int) -> bool: - logger.debug(f"igt {a} {b}") + """Check if integer a is greater than integer b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: True if a > b, False otherwise. + :rtype: bool + """ + logger.debug("igt %s %s", a, b) return a > b @worker.task() -def gt(a: int | float, b: int | float) -> bool: +def gt(a: float, b: float) -> bool: + """Check if value a is greater than value b. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a > b, False otherwise. + :rtype: bool + """ return a > b +@worker.task() +def lt(a: float, b: float) -> bool: + """Check if value a is less than value b. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a < b, False otherwise. 
+ :rtype: bool + """ + return a < b + + @worker.task() def conjugate(z: complex) -> complex: + """Return the complex conjugate of z. + + :param z: The complex number. + :type z: complex + :return: The complex conjugate of z. + :rtype: complex + """ return z.conjugate() @worker.task() -def eq(a: int | float, b: int | float) -> bool: +def eq(a: float, b: float) -> bool: + """Check if two float like values are equal. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a == b, False otherwise. + :rtype: bool + """ return a == b @worker.task() -def neq(a: int | float, b: int | float) -> bool: +def neq(a: float, b: float) -> bool: + """Check if two float like values are not equal. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a != b, False otherwise. + :rtype: bool + """ return a != b @worker.task() def ipow(a: int, b: int) -> int: + """Raise integer a to the power of integer b. + + :param a: The base integer. + :type a: int + :param b: The exponent integer. + :type b: int + :return: The result of a**b. + :rtype: int + """ return a**b @worker.task() -def pow(a: int | float, b: int | float) -> int | float: +def tkr_pow(a: float, b: float) -> int | float: + """Raise value a to the power of value b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The base value. + :type a: float + :param b: The exponent value. + :type b: float + :return: The result of a**b. + :rtype: int | float + """ return a**b @worker.task() -def tkr_abs(a: int | float) -> int | float: +def tkr_abs(a: float) -> int | float: + """Return the absolute value of a float like value. + + :param a: The value. + :type a: float + :return: The absolute value of a. + :rtype: int | float + """ return abs(a) @worker.task() -def tkr_round(a: float | int) -> int: +def tkr_round(a: float) -> int: + """Round a float to the nearest integer. 
+ + :param a: The float value to round. + :type a: float + :return: The rounded integer. + :rtype: int + """ return round(a) @worker.task() -def neg(a: bool) -> bool: +def neg(*, a: bool) -> bool: + """Negate a boolean value. + + :param a: The boolean value. + :type a: bool + :return: The negated boolean value. + :rtype: bool + """ return not a @worker.task() -def trk_and(a: bool, b: bool) -> bool: - logger.debug(f"and {a} {b}") +def tkr_and(*, a: bool, b: bool) -> bool: + """Return the logical AND of two boolean values. + + :param a: The first boolean value. + :type a: bool + :param b: The second boolean value. + :type b: bool + :return: The logical AND of a and b. + :rtype: bool + """ + logger.debug("and %s %s", a, b) return a and b @worker.task() -def trk_or(a: bool, b: bool) -> bool: - logger.debug(f"and {a} {b}") +def tkr_or(*, a: bool, b: bool) -> bool: + """Return the logical OR of two boolean values. + + :param a: The first boolean value. + :type a: bool + :param b: The second boolean value. + :type b: bool + :return: The logical OR of a and b. + :rtype: bool + """ + logger.debug("and %s %s", a, b) return a or b @worker.task() def tkr_id[T: PType](value: T) -> T: - logger.debug(f"id {value}") + """Return the input value unchanged (identity function). + + :param value: The value to return. + :type value: T + :return: The same value. + :rtype: T + """ + logger.debug("id %s", value) return value @worker.task() -def append[T](v: list[T], a: T) -> list[T]: # noqa: E741 +def append[T](v: list[T], a: T) -> list[T]: + """Append an element to a list and return the modified list. + + :param v: The list to append to. + :type v: list[T] + :param a: The element to append. + :type a: T + :return: The list with the element appended. 
+ :rtype: list[T] + """ v.append(a) return v @portmapping class Headed[T: PType](NamedTuple): + """A tuple containing a head element and the rest of the list.""" + head: T rest: list[T] @worker.task() -def head[T: PType](v: list[T]) -> Headed[T]: # noqa: E741 +def head[T: PType](v: list[T]) -> Headed[T]: + """Return the first element and remaining elements of a list. + + :param v: The list. + :type v: list[T] + :return: A Headed tuple containing the first element and the rest of the list. + :rtype: Headed[T] + """ head, rest = v[0], v[1:] return Headed(head=head, rest=rest) @worker.task() def tkr_len[A](v: list[A]) -> int: + """Return the length of a list. + + :param v: The list. + :type v: list[A] + :return: The number of elements in the list. + :rtype: int + """ logger.info("len: %s", v) return len(v) @worker.task() def str_eq(a: str, b: str) -> bool: + """Check if two strings are equal. + + :param a: The first string. + :type a: str + :param b: The second string. + :type b: str + :return: True if the strings are equal, False otherwise. + :rtype: bool + """ return a == b @worker.task() def str_neq(a: str, b: str) -> bool: + """Check if two strings are not equal. + + :param a: The first string. + :type a: str + :param b: The second string. + :type b: str + :return: True if the strings are not equal, False otherwise. + :rtype: bool + """ return a != b @worker.primitive_task() def fold_values(args: WorkerCallArgs, storage: WorkerStorage) -> None: + """Fold multiple values from storage into a single list. + + Reads values from storage matching a glob pattern (values_glob) + and combines them into a single list output at the specified output path. + + :param args: The worker call arguments containing the glob pattern and output path. + :type args: WorkerCallArgs + :param storage: The worker storage for reading and writing values. 
+ :type storage: WorkerStorage + """ values_glob = storage.glob(str(args.inputs["values_glob"])) values_glob.sort(key=lambda x: int(Path(x).name.split("-")[-1])) bs = [storage.read_input(Path(value)) for value in values_glob] @@ -179,50 +435,106 @@ def fold_values(args: WorkerCallArgs, storage: WorkerStorage) -> None: @worker.primitive_task() def unfold_values(args: WorkerCallArgs, storage: WorkerStorage) -> None: + """Unfold a single list value into multiple individual values in storage. + + Reads a list from storage and writes each element to a separate storage location. + + :param args: The worker call arguments containing input value and output directory. + :type args: WorkerCallArgs + :param storage: The worker storage for reading and writing values. + :type storage: WorkerStorage + :raises TierkreisWorkerError: If the input is not a list or sequence. + """ value_list = ptype_from_bytes(storage.read_input(args.inputs["value"])) match value_list: case list() | Sequence(): for i, v in enumerate(value_list): storage.write_output(args.output_dir / str(i), bytes_from_ptype(v)) case _: - raise TierkreisWorkerError(f"Expected list found {value_list}") + msg = f"Expected list found {value_list}" + raise TierkreisWorkerError(msg) @worker.task() def concat(lhs: str, rhs: str) -> str: + """Concatenate two strings lhs+rhs. + + :param lhs: The first string. + :type lhs: str + :param rhs: The second string. + :type rhs: str + :return: The concatenated string. + :rtype: str + """ return lhs + rhs @worker.task() def tkr_zip[U, V](a: list[U], b: list[V]) -> list[tuple[U, V]]: - return list(zip(a, b)) + """Zip two lists together into a list of tuples. + + :param a: The first list. + :type a: list[U] + :param b: The second list. + :type b: list[V] + :return: A list of tuples pairing elements from both lists. 
+ :rtype: list[tuple[U, V]] + """ + return list(zip(a, b, strict=False)) @portmapping class Unzipped[U: PType, V: PType](NamedTuple): + """A tuple containing two lists resulting from unzipping.""" + a: list[U] b: list[V] @worker.task() def unzip[U: PType, V: PType](value: list[tuple[U, V]]) -> Unzipped[U, V]: - value_a, value_b = map(list, zip(*value)) + """Unzip a list of tuples into two separate lists. + + :param value: The list of tuples to unzip. + :type value: list[tuple[U, V]] + :return: An Unzipped tuple containing two lists. + :rtype: Unzipped[U, V] + """ + value_a, value_b = map(list, zip(*value, strict=False)) return Unzipped(a=value_a, b=value_b) @worker.task() def tkr_tuple[U, V](a: U, b: V) -> tuple[U, V]: + """Create a tuple from two values. + + :param a: The first value. + :type a: U + :param b: The second value. + :type b: V + :return: A tuple containing both values. + :rtype: tuple[U, V] + """ return (a, b) @portmapping class Untupled[U: PType, V: PType](NamedTuple): + """A tuple containing two unpacked values.""" + a: U b: V @worker.task() def untuple[U: PType, V: PType](value: tuple[U, V]) -> Untupled[U, V]: + """Unpack a tuple of two elements into separate values. + + :param value: The tuple to unpack. + :type value: tuple[U, V] + :return: An Untupled tuple containing the two unpacked values. + :rtype: Untupled[U, V] + """ logger.info("untuple: %s", value) value_a, value_b = value return Untupled(a=value_a, b=value_b) @@ -230,78 +542,191 @@ def untuple[U: PType, V: PType](value: tuple[U, V]) -> Untupled[U, V]: @worker.task() def mean(values: list[float]) -> float: + """Calculate the arithmetic mean of a list of floats. + + :param values: The list of float values. + :type values: list[float] + :return: The mean of the values. + :rtype: float + """ return statistics.mean(values) @worker.task() def mod(a: int, b: int) -> int: + """Return the modulo of two integers a % b. + + :param a: The dividend. + :type a: int + :param b: The divisor. 
+ :type b: int + :return: The remainder of a divided by b. + :rtype: int + """ return a % b @worker.task() def rand_int(a: int, b: int) -> int: - return randint(a, b) + """Return a random integer between a and b (inclusive). + + :param a: The lower bound (inclusive). + :type a: int + :param b: The upper bound (inclusive). + :type b: int + :return: A random integer between a and b. + :rtype: int + """ + return randint(a, b) # noqa: S311 @worker.task() def tkr_sleep(delay_seconds: float) -> bool: + """Sleep for a specified number of seconds. + + :param delay_seconds: The number of seconds to sleep. + :type delay_seconds: float + :return: True after the sleep completes. + :rtype: bool + """ sleep(delay_seconds) return True @worker.task() def tkr_encode(string: str) -> bytes: + """Encode a string to bytes using UTF-8 encoding. + + :param string: The string to encode. + :type string: str + :return: The UTF-8 encoded bytes. + :rtype: bytes + """ return string.encode() @worker.task() -def tkr_decode(bytes: bytes) -> str: - return bytes.decode() +def tkr_decode(value_bytes: bytes) -> str: + """Decode bytes to a string using UTF-8 decoding. + + :param value_bytes: The bytes to decode. + :type value_bytes: bytes + :return: The decoded string. + :rtype: str + """ + return value_bytes.decode() @worker.task() def tkr_all[T: PType](values: Sequence[T]) -> bool: + """Check if all elements in a sequence are truthy. + + :param values: The sequence of values. + :type values: Sequence[T] + :return: True if all elements are truthy, False otherwise. + :rtype: bool + """ return all(values) @worker.task() def tkr_any[T: PType](values: Sequence[T]) -> bool: + """Check if any element in a sequence is truthy. + + :param values: The sequence of values. + :type values: Sequence[T] + :return: True if any element is truthy, False otherwise. + :rtype: bool + """ return any(values) @worker.task() def tkr_reversed[T: PType](values: list[T]) -> list[T]: + """Return a reversed copy of a list. 
+ + :param values: The list to reverse. + :type values: list[T] + :return: A new list with elements in reverse order. + :rtype: list[T] + """ return list(reversed(values)) @worker.task() def tkr_extend[T: PType](first: list[T], second: list[T]) -> list[T]: + """Extend a list with elements from another list. + + :param first: The list to extend. + :type first: list[T] + :param second: The list of elements to add. + :type second: list[T] + :return: The extended list. + :rtype: list[T] + """ first.extend(second) return first @worker.task() def concat_lists[U: PType, V: PType](first: list[U], second: list[V]) -> list[U | V]: + """Concatenate two lists of potentially different types. + + :param first: The first list. + :type first: list[U] + :param second: The second list. + :type second: list[V] + :return: A concatenated list containing elements from both lists. + :rtype: list[U | V] + """ return first + second @worker.task() -def tkr_str(value: int | float | bool) -> str: +def tkr_str(*, value: float | bool) -> str: + """Convert a float or bool value to a string. + + :param value: The value to convert. + :type value: float | bool + :return: The string representation of the value. + :rtype: str + """ return str(value) @worker.task() -def tkr_int(value: int | float | bool | str) -> int: +def tkr_int(*, value: float | bool | str) -> int: + """Convert a float, bool, or string value to an integer. + + :param value: The value to convert. + :type value: float | bool | str + :return: The integer representation of the value. + :rtype: int + """ return int(value) @worker.task() def sum_list(values: list[int | float]) -> int | float: + """Sum all elements in a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The sum of all elements. + :rtype: int | float + """ return sum(values) @worker.task() def prod_list(values: list[int | float]) -> int | float: + """Calculate the product of all elements in a list of numbers. 
+ + :param values: The list of numeric values. + :type values: list[int | float] + :return: The product of all elements. + :rtype: int | float + """ prod = 1 for v in values: prod *= v @@ -310,26 +735,61 @@ def prod_list(values: list[int | float]) -> int | float: @worker.task() def max_item(values: list[int | float]) -> int | float: + """Return the maximum element from a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The maximum value in the list. + :rtype: int | float + """ return max(values) @worker.task() def min_item(values: list[int | float]) -> int | float: + """Return the minimum element from a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The minimum value in the list. + :rtype: int | float + """ return min(values) @worker.task() def sort_number_list(values: list[int | float]) -> list[int | float]: + """Sort a list of numbers in ascending order. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: A sorted list of numeric values. + :rtype: list[int | float] + """ return sorted(values) @worker.task() def sort_string_list(values: list[str]) -> list[str]: + """Sort a list of strings in ascending order. + + :param values: The list of strings. + :type values: list[str] + :return: A sorted list of strings. + :rtype: list[str] + """ return sorted(values) @worker.task() def flatten[T: PType](values: list[list[T]]) -> list[T]: + """Flatten a list of lists into a single list. + + :param values: The list of lists to flatten. + :type values: list[list[T]] + :return: A flattened list containing all elements. + :rtype: list[T] + """ out = [] for sub in values: out.extend(sub) @@ -338,11 +798,29 @@ def flatten[T: PType](values: list[list[T]]) -> list[T]: @worker.task() def take[T: PType](values: list[T], n: int) -> list[T]: + """Return the first n elements of a list. + + :param values: The list. 
+ :type values: list[T] + :param n: The number of elements to take. + :type n: int + :return: A list containing the first n elements. + :rtype: list[T] + """ return values[:n] @worker.task() def drop[T: PType](values: list[T], n: int) -> list[T]: + """Drop the first n elements of a list and return the rest. + + :param values: The list. + :type values: list[T] + :param n: The number of elements to drop. + :type n: int + :return: A list with the first n elements removed. + :rtype: list[T] + """ return values[n:] diff --git a/tierkreis/tierkreis/builtins/stubs.py b/tierkreis/tierkreis/builtins/stubs.py index 2577a2bd5..58210dade 100644 --- a/tierkreis/tierkreis/builtins/stubs.py +++ b/tierkreis/tierkreis/builtins/stubs.py @@ -1,32 +1,34 @@ """Code generated from builtins namespace. Please do not edit.""" -from typing import NamedTuple, Sequence, Union +from collections.abc import Sequence +from typing import NamedTuple, Union + from tierkreis.controller.data.models import TKR from tierkreis.controller.data.types import PType class Headed[T: PType](NamedTuple): - head: TKR[T] # noqa: F821 # fmt: skip - rest: TKR[list[T]] # noqa: F821 # fmt: skip + head: TKR[T] # fmt: skip + rest: TKR[list[T]] # fmt: skip class Untupled[U: PType, V: PType](NamedTuple): - a: TKR[U] # noqa: F821 # fmt: skip - b: TKR[V] # noqa: F821 # fmt: skip + a: TKR[U] # fmt: skip + b: TKR[V] # fmt: skip class Unzipped[U: PType, V: PType](NamedTuple): - a: TKR[list[U]] # noqa: F821 # fmt: skip - b: TKR[list[V]] # noqa: F821 # fmt: skip + a: TKR[list[U]] # fmt: skip + b: TKR[list[V]] # fmt: skip class iadd(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -34,12 +36,12 @@ def 
namespace(self) -> str: class add(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -47,12 +49,12 @@ def namespace(self) -> str: class isubtract(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -60,12 +62,12 @@ def namespace(self) -> str: class subtract(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -73,12 +75,12 @@ def namespace(self) -> str: class itimes(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -86,12 +88,12 @@ def namespace(self) -> str: class times(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: 
TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -99,12 +101,12 @@ def namespace(self) -> str: class divide(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[float]]: # noqa: F821 # fmt: skip - return TKR[float] # noqa: F821 # fmt: skip + def out() -> type[TKR[float]]: # fmt: skip + return TKR[float] # fmt: skip @property def namespace(self) -> str: @@ -112,12 +114,12 @@ def namespace(self) -> str: class idivide(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -125,12 +127,12 @@ def namespace(self) -> str: class igt(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -138,12 +140,25 @@ def namespace(self) -> str: class gt(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip + + @staticmethod + def out() -> type[TKR[bool]]: 
# fmt: skip + return TKR[bool] # fmt: skip + + @property + def namespace(self) -> str: + return "builtins" + + +class lt(NamedTuple): + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -151,11 +166,11 @@ def namespace(self) -> str: class conjugate(NamedTuple): - z: TKR[complex] # noqa: F821 # fmt: skip + z: TKR[complex] # fmt: skip @staticmethod - def out() -> type[TKR[complex]]: # noqa: F821 # fmt: skip - return TKR[complex] # noqa: F821 # fmt: skip + def out() -> type[TKR[complex]]: # fmt: skip + return TKR[complex] # fmt: skip @property def namespace(self) -> str: @@ -163,12 +178,12 @@ def namespace(self) -> str: class eq(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -176,12 +191,12 @@ def namespace(self) -> str: class neq(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -189,25 +204,25 @@ def namespace(self) -> str: class ipow(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # 
fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: return "builtins" -class pow(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip +class tkr_pow(NamedTuple): + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -215,11 +230,11 @@ def namespace(self) -> str: class tkr_abs(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -227,11 +242,11 @@ def namespace(self) -> str: class tkr_round(NamedTuple): - a: TKR[Union[float, int]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -239,37 +254,37 @@ def namespace(self) -> str: class neg(NamedTuple): - a: TKR[bool] # noqa: F821 # fmt: skip + a: TKR[bool] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: return "builtins" -class trk_and(NamedTuple): - a: TKR[bool] # noqa: F821 # fmt: skip - b: TKR[bool] # noqa: F821 # fmt: skip +class 
tkr_and(NamedTuple): + a: TKR[bool] # fmt: skip + b: TKR[bool] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: return "builtins" -class trk_or(NamedTuple): - a: TKR[bool] # noqa: F821 # fmt: skip - b: TKR[bool] # noqa: F821 # fmt: skip +class tkr_or(NamedTuple): + a: TKR[bool] # fmt: skip + b: TKR[bool] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -277,11 +292,11 @@ def namespace(self) -> str: class tkr_id[T: PType](NamedTuple): - value: TKR[T] # noqa: F821 # fmt: skip + value: TKR[T] # fmt: skip @staticmethod - def out() -> type[TKR[T]]: # noqa: F821 # fmt: skip - return TKR[T] # noqa: F821 # fmt: skip + def out() -> type[TKR[T]]: # fmt: skip + return TKR[T] # fmt: skip @property def namespace(self) -> str: @@ -289,12 +304,12 @@ def namespace(self) -> str: class append[T: PType](NamedTuple): - v: TKR[list[T]] # noqa: F821 # fmt: skip - a: TKR[T] # noqa: F821 # fmt: skip + v: TKR[list[T]] # fmt: skip + a: TKR[T] # fmt: skip @staticmethod - def out() -> type[TKR[list[T]]]: # noqa: F821 # fmt: skip - return TKR[list[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[T]]]: # fmt: skip + return TKR[list[T]] # fmt: skip @property def namespace(self) -> str: @@ -302,11 +317,11 @@ def namespace(self) -> str: class head[T: PType](NamedTuple): - v: TKR[list[T]] # noqa: F821 # fmt: skip + v: TKR[list[T]] # fmt: skip @staticmethod - def out() -> type[Headed[T]]: # noqa: F821 # fmt: skip - return Headed[T] # noqa: F821 # fmt: skip + def out() -> type[Headed[T]]: # fmt: skip + return Headed[T] # fmt: skip @property def namespace(self) -> str: @@ -314,11 +329,11 @@ def namespace(self) -> 
str: class tkr_len[A: PType](NamedTuple): - v: TKR[list[A]] # noqa: F821 # fmt: skip + v: TKR[list[A]] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -326,12 +341,12 @@ def namespace(self) -> str: class str_eq(NamedTuple): - a: TKR[str] # noqa: F821 # fmt: skip - b: TKR[str] # noqa: F821 # fmt: skip + a: TKR[str] # fmt: skip + b: TKR[str] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -339,12 +354,12 @@ def namespace(self) -> str: class str_neq(NamedTuple): - a: TKR[str] # noqa: F821 # fmt: skip - b: TKR[str] # noqa: F821 # fmt: skip + a: TKR[str] # fmt: skip + b: TKR[str] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -352,12 +367,12 @@ def namespace(self) -> str: class concat(NamedTuple): - lhs: TKR[str] # noqa: F821 # fmt: skip - rhs: TKR[str] # noqa: F821 # fmt: skip + lhs: TKR[str] # fmt: skip + rhs: TKR[str] # fmt: skip @staticmethod - def out() -> type[TKR[str]]: # noqa: F821 # fmt: skip - return TKR[str] # noqa: F821 # fmt: skip + def out() -> type[TKR[str]]: # fmt: skip + return TKR[str] # fmt: skip @property def namespace(self) -> str: @@ -365,12 +380,12 @@ def namespace(self) -> str: class tkr_zip[U: PType, V: PType](NamedTuple): - a: TKR[list[U]] # noqa: F821 # fmt: skip - b: TKR[list[V]] # noqa: F821 # fmt: skip + a: TKR[list[U]] # fmt: skip + b: TKR[list[V]] # fmt: skip @staticmethod - def out() -> type[TKR[list[tuple[U, V]]]]: # noqa: F821 # fmt: skip - return TKR[list[tuple[U, 
V]]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[tuple[U, V]]]]: # fmt: skip + return TKR[list[tuple[U, V]]] # fmt: skip @property def namespace(self) -> str: @@ -378,11 +393,11 @@ def namespace(self) -> str: class unzip[U: PType, V: PType](NamedTuple): - value: TKR[list[tuple[U, V]]] # noqa: F821 # fmt: skip + value: TKR[list[tuple[U, V]]] # fmt: skip @staticmethod - def out() -> type[Unzipped[U, V]]: # noqa: F821 # fmt: skip - return Unzipped[U, V] # noqa: F821 # fmt: skip + def out() -> type[Unzipped[U, V]]: # fmt: skip + return Unzipped[U, V] # fmt: skip @property def namespace(self) -> str: @@ -390,12 +405,12 @@ def namespace(self) -> str: class tkr_tuple[U: PType, V: PType](NamedTuple): - a: TKR[U] # noqa: F821 # fmt: skip - b: TKR[V] # noqa: F821 # fmt: skip + a: TKR[U] # fmt: skip + b: TKR[V] # fmt: skip @staticmethod - def out() -> type[TKR[tuple[U, V]]]: # noqa: F821 # fmt: skip - return TKR[tuple[U, V]] # noqa: F821 # fmt: skip + def out() -> type[TKR[tuple[U, V]]]: # fmt: skip + return TKR[tuple[U, V]] # fmt: skip @property def namespace(self) -> str: @@ -403,11 +418,11 @@ def namespace(self) -> str: class untuple[U: PType, V: PType](NamedTuple): - value: TKR[tuple[U, V]] # noqa: F821 # fmt: skip + value: TKR[tuple[U, V]] # fmt: skip @staticmethod - def out() -> type[Untupled[U, V]]: # noqa: F821 # fmt: skip - return Untupled[U, V] # noqa: F821 # fmt: skip + def out() -> type[Untupled[U, V]]: # fmt: skip + return Untupled[U, V] # fmt: skip @property def namespace(self) -> str: @@ -415,11 +430,11 @@ def namespace(self) -> str: class mean(NamedTuple): - values: TKR[list[float]] # noqa: F821 # fmt: skip + values: TKR[list[float]] # fmt: skip @staticmethod - def out() -> type[TKR[float]]: # noqa: F821 # fmt: skip - return TKR[float] # noqa: F821 # fmt: skip + def out() -> type[TKR[float]]: # fmt: skip + return TKR[float] # fmt: skip @property def namespace(self) -> str: @@ -427,12 +442,12 @@ def namespace(self) -> str: class mod(NamedTuple): - a: 
TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -440,12 +455,12 @@ def namespace(self) -> str: class rand_int(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[int] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -453,11 +468,11 @@ def namespace(self) -> str: class tkr_sleep(NamedTuple): - delay_seconds: TKR[float] # noqa: F821 # fmt: skip + delay_seconds: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -465,11 +480,11 @@ def namespace(self) -> str: class tkr_encode(NamedTuple): - string: TKR[str] # noqa: F821 # fmt: skip + string: TKR[str] # fmt: skip @staticmethod - def out() -> type[TKR[bytes]]: # noqa: F821 # fmt: skip - return TKR[bytes] # noqa: F821 # fmt: skip + def out() -> type[TKR[bytes]]: # fmt: skip + return TKR[bytes] # fmt: skip @property def namespace(self) -> str: @@ -477,11 +492,11 @@ def namespace(self) -> str: class tkr_decode(NamedTuple): - bytes: TKR[bytes] # noqa: F821 # fmt: skip + value_bytes: TKR[bytes] # fmt: skip @staticmethod - def out() -> type[TKR[str]]: # noqa: F821 # fmt: skip - return TKR[str] # noqa: F821 # fmt: skip + def out() -> type[TKR[str]]: # fmt: skip + return TKR[str] # fmt: skip @property def namespace(self) -> str: @@ -489,11 +504,11 @@ def namespace(self) -> 
str: class tkr_all[T: PType](NamedTuple): - values: TKR[Sequence[T]] # noqa: F821 # fmt: skip + values: TKR[Sequence[T]] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -501,11 +516,11 @@ def namespace(self) -> str: class tkr_any[T: PType](NamedTuple): - values: TKR[Sequence[T]] # noqa: F821 # fmt: skip + values: TKR[Sequence[T]] # fmt: skip @staticmethod - def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: @@ -513,11 +528,11 @@ def namespace(self) -> str: class tkr_reversed[T: PType](NamedTuple): - values: TKR[list[T]] # noqa: F821 # fmt: skip + values: TKR[list[T]] # fmt: skip @staticmethod - def out() -> type[TKR[list[T]]]: # noqa: F821 # fmt: skip - return TKR[list[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[T]]]: # fmt: skip + return TKR[list[T]] # fmt: skip @property def namespace(self) -> str: @@ -525,12 +540,12 @@ def namespace(self) -> str: class tkr_extend[T: PType](NamedTuple): - first: TKR[list[T]] # noqa: F821 # fmt: skip - second: TKR[list[T]] # noqa: F821 # fmt: skip + first: TKR[list[T]] # fmt: skip + second: TKR[list[T]] # fmt: skip @staticmethod - def out() -> type[TKR[list[T]]]: # noqa: F821 # fmt: skip - return TKR[list[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[T]]]: # fmt: skip + return TKR[list[T]] # fmt: skip @property def namespace(self) -> str: @@ -538,12 +553,12 @@ def namespace(self) -> str: class concat_lists[U: PType, V: PType](NamedTuple): - first: TKR[list[U]] # noqa: F821 # fmt: skip - second: TKR[list[V]] # noqa: F821 # fmt: skip + first: TKR[list[U]] # fmt: skip + second: TKR[list[V]] # fmt: skip @staticmethod - def out() -> type[TKR[list[Union[U, 
V]]]]: # noqa: F821 # fmt: skip - return TKR[list[Union[U, V]]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[Union[U, V]]]]: # fmt: skip + return TKR[list[Union[U, V]]] # fmt: skip @property def namespace(self) -> str: @@ -551,11 +566,11 @@ def namespace(self) -> str: class tkr_str(NamedTuple): - value: TKR[Union[int, float, bool]] # noqa: F821 # fmt: skip + value: TKR[Union[float, bool]] # fmt: skip @staticmethod - def out() -> type[TKR[str]]: # noqa: F821 # fmt: skip - return TKR[str] # noqa: F821 # fmt: skip + def out() -> type[TKR[str]]: # fmt: skip + return TKR[str] # fmt: skip @property def namespace(self) -> str: @@ -563,11 +578,11 @@ def namespace(self) -> str: class tkr_int(NamedTuple): - value: TKR[Union[int, float, bool, str]] # noqa: F821 # fmt: skip + value: TKR[Union[float, bool, str]] # fmt: skip @staticmethod - def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip - return TKR[int] # noqa: F821 # fmt: skip + def out() -> type[TKR[int]]: # fmt: skip + return TKR[int] # fmt: skip @property def namespace(self) -> str: @@ -575,11 +590,11 @@ def namespace(self) -> str: class sum_list(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -587,11 +602,11 @@ def namespace(self) -> str: class prod_list(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> 
str: @@ -599,11 +614,11 @@ def namespace(self) -> str: class max_item(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -611,11 +626,11 @@ def namespace(self) -> str: class min_item(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -623,11 +638,11 @@ def namespace(self) -> str: class sort_number_list(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[list[Union[int, float]]]]: # noqa: F821 # fmt: skip - return TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[Union[int, float]]]]: # fmt: skip + return TKR[list[Union[int, float]]] # fmt: skip @property def namespace(self) -> str: @@ -635,11 +650,11 @@ def namespace(self) -> str: class sort_string_list(NamedTuple): - values: TKR[list[str]] # noqa: F821 # fmt: skip + values: TKR[list[str]] # fmt: skip @staticmethod - def out() -> type[TKR[list[str]]]: # noqa: F821 # fmt: skip - return TKR[list[str]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[str]]]: # fmt: skip + return TKR[list[str]] # fmt: skip @property def namespace(self) -> str: @@ -647,11 +662,11 @@ def namespace(self) -> str: class flatten[T: PType](NamedTuple): - values: TKR[list[list[T]]] # 
noqa: F821 # fmt: skip + values: TKR[list[list[T]]] # fmt: skip @staticmethod - def out() -> type[TKR[list[T]]]: # noqa: F821 # fmt: skip - return TKR[list[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[T]]]: # fmt: skip + return TKR[list[T]] # fmt: skip @property def namespace(self) -> str: @@ -659,12 +674,12 @@ def namespace(self) -> str: class take[T: PType](NamedTuple): - values: TKR[list[T]] # noqa: F821 # fmt: skip - n: TKR[int] # noqa: F821 # fmt: skip + values: TKR[list[T]] # fmt: skip + n: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[list[T]]]: # noqa: F821 # fmt: skip - return TKR[list[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[T]]]: # fmt: skip + return TKR[list[T]] # fmt: skip @property def namespace(self) -> str: @@ -672,12 +687,12 @@ def namespace(self) -> str: class drop[T: PType](NamedTuple): - values: TKR[list[T]] # noqa: F821 # fmt: skip - n: TKR[int] # noqa: F821 # fmt: skip + values: TKR[list[T]] # fmt: skip + n: TKR[int] # fmt: skip @staticmethod - def out() -> type[TKR[list[T]]]: # noqa: F821 # fmt: skip - return TKR[list[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[T]]]: # fmt: skip + return TKR[list[T]] # fmt: skip @property def namespace(self) -> str: diff --git a/tierkreis/tierkreis/cli/__init__.py b/tierkreis/tierkreis/cli/__init__.py index e69de29bb..0c81fcc6b 100644 --- a/tierkreis/tierkreis/cli/__init__.py +++ b/tierkreis/tierkreis/cli/__init__.py @@ -0,0 +1 @@ +"""The Tierkreis CLI.""" diff --git a/tierkreis/tierkreis/cli/run_workflow.py b/tierkreis/tierkreis/cli/run_workflow.py index 55c32c2f6..2a6ec1d70 100644 --- a/tierkreis/tierkreis/cli/run_workflow.py +++ b/tierkreis/tierkreis/cli/run_workflow.py @@ -1,41 +1,67 @@ -from pathlib import Path -import uuid +"""Implementation to run a workflow.""" + import logging +import uuid +from pathlib import Path from tierkreis.controller import run_graph from tierkreis.controller.data.graph import GraphData -from 
tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import PType -from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.executor.uv_executor import UvExecutor +from tierkreis.controller.storage.filestorage import ControllerFileStorage +from tierkreis.storage import read_outputs logger = logging.getLogger(__name__) -def run_workflow( +def run_workflow( # noqa: PLR0913 graph: GraphData, inputs: dict[str, PType], name: str | None = None, run_id: int | None = None, log_level: int | str = logging.INFO, registry_path: Path | None = None, + *, print_output: bool = False, - use_uv_worker: bool = False, + use_uv_executor: bool = False, n_iterations: int = 10000, polling_interval_seconds: float = 0.1, ) -> None: - """Run a workflow.""" + """Run a workflow. + + Wrapper for :py:func:`tierkreis.controller.run_graph.run_graph` to run a workflow. + Adds some sensible defaults. + + :param graph: The graph to run. + :type graph: GraphData + :param inputs: The inputs to the workflow. 
+ :type inputs: dict[str, PType] + :param name: The name of the workflow, defaults to None + :type name: str | None, optional + :param run_id: The run ID of the workflow, defaults to None + :type run_id: int | None, optional + :param log_level: The log level for the workflow, defaults to logging.INFO + :type log_level: int | str, optional + :param registry_path: The worker registry, defaults to Path(__file__).parent + :type registry_path: Path | None, optional + :param print_output: Whether to print final outputs, defaults to False + :type print_output: bool, optional + :param use_uv_executor: Use the UV executor instead of ShellExecutor + , defaults to False + :type use_uv_executor: bool, optional + :param n_iterations: The maximum number of iterations, defaults to 10000 + :type n_iterations: int, optional + :param polling_interval_seconds: The controller tickrate, defaults to 0.1 + :type polling_interval_seconds: float, optional + """ logger.setLevel(log_level) - if run_id is None: - workflow_id = uuid.uuid4() - else: - workflow_id = uuid.UUID(int=run_id) + workflow_id = uuid.uuid4() if run_id is None else uuid.UUID(int=run_id) logger.info("Workflow ID is %s", workflow_id) storage = ControllerFileStorage(workflow_id, name=name, do_cleanup=True) if registry_path is None: registry_path = Path(__file__).parent - if use_uv_worker: + if use_uv_executor: executor = UvExecutor(registry_path=registry_path, logs_path=storage.logs_path) else: executor = ShellExecutor(registry_path, storage.workflow_dir) @@ -51,6 +77,9 @@ def run_workflow( polling_interval_seconds, ) if print_output: - all_outputs = graph.nodes[graph.output_idx()].inputs - for output in all_outputs: - print(f"{output}: {storage.read_output(Loc(), output)}") + all_outputs = read_outputs(graph, storage) + if isinstance(all_outputs, dict): + for output_name, output_value in all_outputs.items(): + print(f"{output_name}: {output_value!r}") # noqa: T201 + else: + print(f"value: {all_outputs!r}") # noqa: T201 diff 
--git a/tierkreis/tierkreis/cli/tkr.py b/tierkreis/tierkreis/cli/tkr.py index 7edc5eb7b..2ede61c45 100644 --- a/tierkreis/tierkreis/cli/tkr.py +++ b/tierkreis/tierkreis/cli/tkr.py @@ -1,3 +1,5 @@ +"""Tierkreis CLI main entrypoint.""" + from __future__ import annotations import argparse @@ -6,27 +8,44 @@ import logging import sys from pathlib import Path -from typing import Any, Callable +from typing import TYPE_CHECKING from tierkreis.cli.run_workflow import run_workflow from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType, ptype_from_bytes from tierkreis.exceptions import TierkreisError +if TYPE_CHECKING: + import types + from collections.abc import Callable + +logger = logging.getLogger(__name__) + -def _import_from_path(module_name: str, file_path: str) -> Any: - spec = importlib.util.spec_from_file_location(module_name, file_path) # type: ignore - module = importlib.util.module_from_spec(spec) # type: ignore +def _import_from_path(module_name: str, file_path: str) -> types.ModuleType: + """Import a graph when supplied as a path to a python file.""" + spec = importlib.util.spec_from_file_location(module_name, file_path) # type: ignore[no-untyped-call] + module = importlib.util.module_from_spec(spec) # type: ignore[no-untyped-call] sys.modules[module_name] = module spec.loader.exec_module(module) return module def load_graph(graph_input: str) -> GraphData: + """Load a graph from an argument string. + + Loads a graph similar to how python runs modules with "-m" + + :param graph_input: The argument string specifying the graph. + :type graph_input: str + :raises TierkreisError: If the argument string is invalid. + :return: The loaded graph data. 
+ :rtype: GraphData + """ if ":" not in graph_input: - raise TierkreisError(f"Invalid argument: {graph_input}") + msg = f"Invalid argument: {graph_input}" + raise TierkreisError(msg) module_name, function_name = graph_input.split(":") - print(f"Loading graph from module '{module_name}' and function '{function_name}'") if ".py" in module_name: module = _import_from_path("graph_module", module_name) else: @@ -37,15 +56,17 @@ def load_graph(graph_input: str) -> GraphData: def _load_inputs(input_files: list[str]) -> dict[str, PType]: + """Load the inputs to a graph.""" if len(input_files) == 1 and input_files[0].endswith(".json"): - with open(input_files[0], "r") as fh: + with Path.open(Path(input_files[0])) as fh: return {k: json.dumps(v).encode() for k, v in json.load(fh).items()} inputs = {} for input_file in input_files: if ":" not in input_file: - raise TierkreisError(f"Invalid argument: {input_file}") + msg = f"Invalid argument: {input_file}" + raise TierkreisError(msg) key, value = input_file.split(":") - with open(value, "rb") as fh: + with Path.open(Path(value), "rb") as fh: inputs[key] = ptype_from_bytes(fh.read()) return inputs @@ -53,21 +74,31 @@ def _load_inputs(input_files: list[str]) -> dict[str, PType]: def parse_args( main_parser: argparse._SubParsersAction[argparse.ArgumentParser], ) -> argparse.ArgumentParser: + """Parse the arguments for the 'run' subcommand. + + :param main_parser: The main parser to add the subcommand to. + :type main_parser: argparse._SubParsersAction[argparse.ArgumentParser] + :return: The parser for the 'run' subcommand. 
+ :rtype: argparse.ArgumentParser + """ parser = main_parser.add_parser( name="run", description="Tierkreis: a workflow engine for quantum HPC.", ) graph = parser.add_mutually_exclusive_group(required=True) graph.add_argument( - "-f", "--from-file", type=Path, help="Load a graph from a .json file" + "-f", + "--from-file", + type=Path, + help="Load a graph from a .json file", ) graph.add_argument( "-g", "--graph-location", help="Fully qualifying name of a Callable () -> GraphData. " - + "Example: tierkreis.cli.sample_graph:simple_eval" - + "Or a path to a python file and function." - + "Example: examples/hello_world/hello_world_graph.py:hello_graph", + "Example: tierkreis.cli.sample_graph:simple_eval" + "Or a path to a python file and function." + "Example: examples/hello_world/hello_world_graph.py:hello_graph", type=str, ) parser.add_argument( @@ -76,10 +107,13 @@ def parse_args( nargs="*", help="Graph inputs:" "Either a single .json file or a key value list port1:path1 port2:path2" - + "where path is a binary file.", + "where path is a binary file.", ) parser.add_argument( - "--run-id", default=None, type=int, help="Set a workflow run id" + "--run-id", + default=None, + type=int, + help="Set a workflow run id", ) parser.add_argument("--name", default=None, type=str, help="Set a workflow name") parser.add_argument( @@ -91,7 +125,10 @@ def parse_args( ) parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( - "--registry-path", default=None, type=Path, help="Location of executable tasks." + "--registry-path", + default=None, + type=Path, + help="Location of executable tasks.", ) parser.add_argument( "-o", @@ -126,27 +163,28 @@ def parse_args( return parser -def run_workflow_args(args: argparse.Namespace): +def run_workflow_args(args: argparse.Namespace) -> None: + """Run a tierkreis workflow according to the run command. + + :param args: The arguments parsed from tkr run. 
class TierkreisCli:
    """Entry-point wrapper for the ``tkr`` command line interface."""

    @staticmethod
    def add_subcommand(
        main_parser: argparse._SubParsersAction[argparse.ArgumentParser],
    ) -> None:
        """Register the 'run' subcommand on the given subparser action."""
        run_parser = parse_args(main_parser)
        run_parser.set_defaults(func=TierkreisCli.execute)

    @staticmethod
    def execute(args: argparse.Namespace) -> None:
        """Dispatch the parsed arguments to the workflow runner.

        :param args: The parsed arguments.
        :type args: argparse.Namespace
        """
        run_workflow_args(args)
def format_typed_arg(typed_arg: TypedArg, *, is_portmapping: bool) -> str:
    """Format a typed argument to a string.

    :param typed_arg: The typed argument.
    :type typed_arg: TypedArg
    :param is_portmapping: Whether the argument is a portmapping.
    :type is_portmapping: bool
    :return: The formatted string representation of the typed argument.
    :rtype: str
    """
    type_str = format_generic_type(
        typed_arg.t,
        include_bound=False,
        is_tkr=not is_portmapping,
    )
    # Struct types inside a portmapping are forward-referenced (quoted).
    should_quote = typed_arg.t.included_structs() and is_portmapping
    type_str = f'"{type_str}"' if should_quote else type_str
    default_str = " | None = None " if typed_arg.has_default else ""
    return f"{typed_arg.name}: {type_str}{default_str} {NO_QA_STR}"


def format_model(model: Model) -> str:
    """Format a model to a string.

    :param model: The model to format.
    :type model: Model
    :return: The formatted string representation of the model.
    :rtype: str
    """
    is_portmapping = model.is_portmapping
    outs = [format_typed_arg(x, is_portmapping=not is_portmapping) for x in model.decls]
    outs.sort()
    outs_str = "\n    ".join(outs)
    bases = ["NamedTuple"] if is_portmapping else ["Struct", "Protocol"]
    bases_str = ", ".join(bases)
    generic_type_str = format_generic_type(model.t, include_bound=True, is_tkr=False)
    return f"""
class {generic_type_str}({bases_str}):
    {outs_str}
"""
+ :rtype: str + """ + ins = [format_typed_arg(x, is_portmapping=False) for x in fn.args] ins_str = "\n ".join(ins) class_name = format_generic_type( - fn.return_type, False, not fn.return_type_is_portmapping + fn.return_type, + include_bound=False, + is_tkr=not fn.return_type_is_portmapping, ) bases = ["NamedTuple"] - return f"""class {format_generic_type(fn.name, True, False)}({", ".join(bases)}): + class_def = format_generic_type(fn.name, include_bound=True, is_tkr=False) + bases_str = ", ".join(bases) + + return f"""class {class_def}({bases_str}): {ins_str} @staticmethod diff --git a/tierkreis/tierkreis/consts.py b/tierkreis/tierkreis/consts.py index b4ef9c584..27f3a2926 100644 --- a/tierkreis/tierkreis/consts.py +++ b/tierkreis/tierkreis/consts.py @@ -1,5 +1,6 @@ -from pathlib import Path +"""Tierkreis constant definitions.""" +from pathlib import Path PACKAGE_PATH = Path(__file__).parent.parent TESTS_PATH = PACKAGE_PATH / "tests" diff --git a/tierkreis/tierkreis/controller/__init__.py b/tierkreis/tierkreis/controller/__init__.py index f39336509..ff28fcb35 100644 --- a/tierkreis/tierkreis/controller/__init__.py +++ b/tierkreis/tierkreis/controller/__init__.py @@ -1,5 +1,6 @@ import logging from time import sleep +from typing import TYPE_CHECKING from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import Eval, GraphData @@ -7,10 +8,12 @@ from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes from tierkreis.controller.executor.protocol import ControllerExecutor from tierkreis.controller.start import NodeRunData, start, start_nodes -from tierkreis.logger_setup import set_tkr_logger from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.controller.storage.walk import walk_node -from tierkreis.controller.data.core import PortID, ValueRef +from tierkreis.logger_setup import set_tkr_logger + +if TYPE_CHECKING: + from tierkreis.controller.data.core import PortID, ValueRef from 
tierkreis.exceptions import TierkreisError root_loc = Loc("") @@ -24,6 +27,7 @@ def run_graph( graph_inputs: dict[str, PType] | PType, n_iterations: int = 10000, polling_interval_seconds: float = 0.01, + *, enable_logging: bool = True, ) -> None: if isinstance(g, GraphBuilder): @@ -31,9 +35,9 @@ def run_graph( if not isinstance(graph_inputs, dict): graph_inputs = {"value": graph_inputs} - remaining_inputs = g.remaining_inputs({k for k in graph_inputs.keys()}) + remaining_inputs = g.remaining_inputs(set(graph_inputs.keys())) if len(remaining_inputs) > 0: - logger.warning(f"Some inputs were not provided: {remaining_inputs}") + logger.warning("Some inputs were not provided: %s", remaining_inputs) storage.write_metadata(Loc("")) if enable_logging: @@ -48,7 +52,7 @@ def run_graph( k: (-1, k) for k, _ in graph_inputs.items() if k != "body" } node_run_data = NodeRunData(Loc(), Eval((-1, "body"), inputs), []) - start(storage, executor, node_run_data, enable_logging) + start(storage, executor, node_run_data) resume_graph(storage, executor, n_iterations, polling_interval_seconds) @@ -64,23 +68,21 @@ def resume_graph( for _ in range(n_iterations): walk_results = walk_node(storage, Loc(), graph.output_idx(), graph) if walk_results.errored != []: - # TODO: add to base class after storage refactor - (storage.logs_path.parent / "-" / "_error").touch() node_errors = "\n".join(x for x in walk_results.errored) storage.write_node_errors(Loc(), node_errors) - print("\n\nGraph finished with errors.\n\n") - + logger.error("\n\nGraph finished with errors.\n\n") for error_loc in walk_results.errored: - print(storage.read_errors(error_loc)) - print(f"Node: '{error_loc}' encountered an error.") - print( - f"Stderr information is available at {storage._worker_logs_path(error_loc)}." 
+ logger.error(storage.read_errors(error_loc)) + logger.error("Node: '%s' encountered an error.", error_loc) + logger.error( + "Stderr information is available at %s.", + storage._worker_logs_path(error_loc), # noqa: SLF001 ) - print("\n\n") - print("--- Tierkreis graph errors above this line. ---\n\n") - raise TierkreisError("Graph encountered errors") + logger.error("--- Tierkreis graph errors above this line. ---") + msg = "Graph encountered errors" + raise TierkreisError(msg) start_nodes(storage, executor, walk_results.inputs_ready) if storage.is_node_finished(Loc()): diff --git a/tierkreis/tierkreis/controller/consts.py b/tierkreis/tierkreis/controller/consts.py index 65bf1757d..9df05b531 100644 --- a/tierkreis/tierkreis/controller/consts.py +++ b/tierkreis/tierkreis/controller/consts.py @@ -1,5 +1,6 @@ -import os +"""Controller constants.""" + from pathlib import Path BODY_PORT = "body" -PACKAGE_PATH = Path(os.path.dirname(os.path.realpath(__file__))) +PACKAGE_PATH = Path(__file__).resolve().parent diff --git a/tierkreis/tierkreis/controller/data/__init__.py b/tierkreis/tierkreis/controller/data/__init__.py index e69de29bb..2ba1472be 100644 --- a/tierkreis/tierkreis/controller/data/__init__.py +++ b/tierkreis/tierkreis/controller/data/__init__.py @@ -0,0 +1 @@ +"""Core data structures for typing and constructing graphs.""" diff --git a/tierkreis/tierkreis/controller/data/core.py b/tierkreis/tierkreis/controller/data/core.py index 1aa3422f7..42f80a62a 100644 --- a/tierkreis/tierkreis/controller/data/core.py +++ b/tierkreis/tierkreis/controller/data/core.py @@ -1,8 +1,16 @@ +"""Core types in tierkreis. + +- PortID = str, name of an output on a node +- NodeIndex = int, index of node in the graph list +- ValueRef = tuple[NodeIndex, PortID] reference of a value; + uniquely identified by the node and its output. 
+""" + +from collections.abc import Callable from dataclasses import dataclass from typing import ( Annotated, Any, - Callable, Literal, NamedTuple, Protocol, @@ -12,14 +20,14 @@ runtime_checkable, ) - PortID = str NodeIndex = int ValueRef = tuple[NodeIndex, PortID] SerializationFormat = Literal["bytes", "json", "unknown"] -class EmptyModel(NamedTuple): ... +class EmptyModel(NamedTuple): + """A model without content.""" @runtime_checkable @@ -27,33 +35,82 @@ class RestrictedNamedTuple[T](Protocol): """A NamedTuple whose members are restricted to being of type T.""" def _asdict(self) -> dict[str, T]: ... - def __getitem__(self, key: SupportsIndex, /) -> T: ... + def __getitem__(self, key: SupportsIndex, /) -> T: + """Access the indexed element as in a tuple.""" + ... @dataclass class Serializer: + """Serializer for tkr values. + + :fields: + serializer (Callable[[Any], Any]): A function taking a value producing a + serialized version of it. + serialization_method (Literal): Indicator of serializer type of + ["bytes", "json", "unknown"], defaults to "bytes". + """ + serializer: Callable[[Any], Any] serialization_method: SerializationFormat = "bytes" @dataclass class Deserializer: + """Serializer for tkr values. + + :fields: + serializer (Callable[[Any], Any]): A function taking a serialized + value producing a deserialized value. + serialization_method (Literal): Indicator of serializer type of + ["bytes", "json", "unknown"], defaults to "bytes". + """ + deserializer: Callable[[Any], Any] serialization_method: SerializationFormat = "bytes" def get_t_from_args[T](t: type[T], hint: type | None) -> T | None: + """Get the possible type generic T from a type. + + :return: The generic hint T if it exists on the value. + Either from its annotation or type hint. 
+ :rtype: T | None + """ if hint is None or get_origin(hint) is not Annotated: return None for arg in get_args(hint): if isinstance(arg, t): return arg + return None + +def get_serializer(hint: type | None) -> Serializer | None: + """Get the serializer for an annotated type. -def get_serializer(hint: type | None): + This is relevant for annotated worker types + AnnotatedType = Annotated[BaseType, ser, deser] + worker_fn: AnnotatedType -> AnnotatedType + + :param hint: The type to get the serializer for. + :type hint: type | None + :return: The deserializer if one is annotated. + :rtype: Deserializer | None + """ return get_t_from_args(Serializer, hint) -def get_deserializer(hint: type | None): +def get_deserializer(hint: type | None) -> Deserializer | None: + """Get the deserializer for an annotated type. + + This is relevant for annotated worker types + AnnotatedType = Annotated[BaseType, ser, deser] + worker_fn: AnnotatedType -> AnnotatedType + + :param hint: The type to get the deserializer for. + :type hint: type | None + :return: The deserializer if one is annotated. + :rtype: Deserializer | None + """ return get_t_from_args(Deserializer, hint) diff --git a/tierkreis/tierkreis/controller/data/graph.py b/tierkreis/tierkreis/controller/data/graph.py index 624dbd572..d2611588d 100644 --- a/tierkreis/tierkreis/controller/data/graph.py +++ b/tierkreis/tierkreis/controller/data/graph.py @@ -1,10 +1,22 @@ -from dataclasses import dataclass, field +"""Graph and node definitions. + +(Computational) graphs are the underlying data structure for workflows in tierkreis. +A Graph is comprised on nodes (atomic operations) and edges (their values). +Nodes have named inputs referencing a previously computed value in the graph; +and named outputs referencing an id to look for the respective value. +Inputs and outputs are called ports. +The graph is constructed by mapping inputs off a node (by name) to the +outputs of a previous node. 
+""" + import logging -from typing import Any, Callable, Literal, assert_never +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Literal, assert_never + from pydantic import BaseModel, RootModel -from tierkreis.controller.data.core import PortID -from tierkreis.controller.data.core import NodeIndex -from tierkreis.controller.data.core import ValueRef + +from tierkreis.controller.data.core import NodeIndex, PortID, ValueRef from tierkreis.controller.data.location import Loc, OutputLoc from tierkreis.controller.data.types import PType, ptype_from_bytes from tierkreis.exceptions import TierkreisError @@ -14,79 +26,183 @@ @dataclass class Func: + """A function node. + + Defines a task which is run by a worker on an executor. + + :fields: + function_name (str): The function to run. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + outputs (dict[PortID, NodeIndex]): Typically not used on functions, + for the sake of simplifying NodeData. + """ + function_name: str inputs: dict[PortID, ValueRef] - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) type: Literal["function"] = field(default="function") @dataclass class Eval: + """An eval node. + + Evaluates a nested graph. + Necessary for higher order operations. + + :fields: + graph (ValueRef): The reference to a nested graph body. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the nested graph. + """ + graph: ValueRef inputs: dict[PortID, ValueRef] - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) type: Literal["eval"] = field(default="eval") @dataclass class Loop: + """A loop node. + + Evaluates a nested graph iteratively. 
+ Inputs are updated from the previous iteration. + Loops until continue_port value evaluates to false. + + :fields: + body (ValueRef): The reference to a nested graph body. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + continue_port: PortID: A named boolean port as stopping criterion. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the nested graph. + name (str | None): Used as debug data for loop tracing. + """ + body: ValueRef inputs: dict[PortID, ValueRef] continue_port: PortID # The port that specifies if the loop should continue. - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) type: Literal["loop"] = field(default="loop") name: str | None = None @dataclass class Map: + """A map node. + + Evaluates a nested graph concurrently for a set of values on one port. + Maps have a * input which indicates the value to map over. + Typically this is done by fold map a b c unfold [...] where a b c are the arbitrary + but fixed inputs of the map. + + :fields: + body (ValueRef): The reference to a nested graph body. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + outputs (dict[PortID, NodeIndex]): Typically not used on functions, + for the sake of simplifying NodeData. + """ + body: ValueRef inputs: dict[PortID, ValueRef] - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) type: Literal["map"] = field(default="map") @dataclass class Const: + """A constant node. + + :fields: + value (Any): The constant value + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + Typically "value" or "body" + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the nested graphs. 
+ """ + value: Any - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) - inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) + inputs: dict[PortID, ValueRef] = field(default_factory=dict) type: Literal["const"] = field(default="const") @dataclass class IfElse: + """A lazy if else node. + + :fields: + pred (ValueRef): Ref to a boolean value dictating which branch to evaluate. + if_true (ValueRef): Branch to evaluate when pred is true. + if_false (ValueRef): Branch to evaluate when pred is false. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + Typically pred and values for the branches. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the branches. + """ + pred: ValueRef if_true: ValueRef if_false: ValueRef - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) - inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) + inputs: dict[PortID, ValueRef] = field(default_factory=dict) type: Literal["ifelse"] = field(default="ifelse") @dataclass class EagerIfElse: + """An eager if else node. + + :fields: + pred (ValueRef): Ref to a boolean value dictating which value to forward. + if_true (ValueRef): Branch to forward when pred is true. + if_false (ValueRef): Branch to forward when pred is false. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + Typically pred and values for the branches. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the branches. 
+ """ + pred: ValueRef if_true: ValueRef if_false: ValueRef - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) - inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) + inputs: dict[PortID, ValueRef] = field(default_factory=dict) type: Literal["eifelse"] = field(default="eifelse") @dataclass class Input: + """An input node. + + :fields: + name (str): The name of the input value. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values, + typically a single element. + outputs (dict[PortID, NodeIndex]): Typically not used on inputs, + for the sake of simplifying NodeData. + """ + name: str - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) - inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) + inputs: dict[PortID, ValueRef] = field(default_factory=dict) type: Literal["input"] = field(default="input") @dataclass class Output: + """An output node. + + :fields: + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values, + typically a single element (e.g. computation -> output) + outputs (dict[PortID, NodeIndex]): Typically only forwards itself. + """ + inputs: dict[PortID, ValueRef] - outputs: dict[PortID, NodeIndex] = field(default_factory=lambda: {}) + outputs: dict[PortID, NodeIndex] = field(default_factory=dict) type: Literal["output"] = field(default="output") @@ -95,6 +211,24 @@ class Output: class GraphData(BaseModel): + """The model of a computational graph. + + Encapsulates the entire computation. + Nodes are stored in a list, where the NodeIndex points to a unique node. + Graphs have a single output which can be a Struct of multiple fields. + + :fields: + nodes (list[NodeDef]): The list of nodes in a graph. + fixed_inputs (dict[PortID, OutputLoc]): A dict of fixed inputs for the graph. + They have values defined at construction time. 
+ graph_inputs: (set[PortID]): A set of user defined inputs at runtime. + graph_output_idx (NodeIndex | None): The index of the output node. + Graphs must have exactly one output to run. + named_nodes (dict[str, NodeIndex]): Mapping of node names to their index in the + list. This is used for debug information. + + """ + nodes: list[NodeDef] = [] fixed_inputs: dict[PortID, OutputLoc] = {} graph_inputs: set[PortID] = set() @@ -102,19 +236,55 @@ class GraphData(BaseModel): named_nodes: dict[str, NodeIndex] = {} def input(self, name: str) -> ValueRef: + """Add an input name. + + :param name: The name of the input. + :type name: str + :return: The reference to that value. + :rtype: ValueRef + """ return self.add(Input(name))(name) def const(self, value: PType) -> ValueRef: + """Add a constant value. + + :param value: The value to add. + :type value: PType + :return: The reference to that value. + :rtype: ValueRef + """ return self.add(Const(value))("value") def func( - self, function_name: str, inputs: dict[PortID, ValueRef] + self, + function_name: str, + inputs: dict[PortID, ValueRef], ) -> Callable[[PortID], ValueRef]: + """Add a funciton node (task). + + :param function_name: The name of the function. + :type function_name: str + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Func(function_name, inputs)) def eval( - self, graph: ValueRef, inputs: dict[PortID, ValueRef] + self, + graph: ValueRef, + inputs: dict[PortID, ValueRef], ) -> Callable[[PortID], ValueRef]: + """Add an eval node. + + :param graph: The nested graph to evaluate. + :type graph: ValueRef + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :return: A function returning index given an output. 
+ :rtype: Callable[[PortID], ValueRef] + """ return self.add(Eval(graph, inputs)) def loop( @@ -124,6 +294,19 @@ def loop( continue_port: PortID, name: str | None = None, ) -> Callable[[PortID], ValueRef]: + """Add a loop node. + + :param body: The graph to loop over. + :type body: ValueRef + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :param continue_port: The termination criterion port. + :type continue_port: PortID + :param name: Name of the loop for tracing, defaults to None + :type name: str | None, optional + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Loop(body, inputs, continue_port, name=name)) def map( @@ -131,25 +314,83 @@ def map( body: ValueRef, inputs: dict[PortID, ValueRef], ) -> Callable[[PortID], ValueRef]: + """Add a map node. + + :param body: The graph to map over. + :type body: ValueRef + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Map(body, inputs)) - def if_else(self, pred: ValueRef, if_true: ValueRef, if_false: ValueRef): + def if_else( + self, + pred: ValueRef, + if_true: ValueRef, + if_false: ValueRef, + ) -> Callable[[PortID], ValueRef]: + """Add an lazy if else node. + + :param pred: The reference to conditional value. + :type pred: ValueRef + :param if_true: The graph/value for the true branch. + :type if_true: ValueRef + :param if_false: The graph/value for the false branch. + :type if_false: ValueRef + :return: A function returning index given an output. 
+ :rtype: Callable[[PortID], ValueRef] + """ return self.add(IfElse(pred, if_true, if_false)) - def eager_if_else(self, pred: ValueRef, if_true: ValueRef, if_false: ValueRef): + def eager_if_else( + self, + pred: ValueRef, + if_true: ValueRef, + if_false: ValueRef, + ) -> Callable[[PortID], ValueRef]: + """Add an eager if else node. + + :param pred: The reference to conditional value. + :type pred: ValueRef + :param if_true: The graph/value for the true branch. + :type if_true: ValueRef + :param if_false: The graph/value for the false branch. + :type if_false: ValueRef + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(EagerIfElse(pred, if_true, if_false)) def output(self, inputs: dict[PortID, ValueRef]) -> None: + """Add an output node. + + Computation -> output. + + :param inputs: The inputs of the outup node. + :type inputs: dict[PortID, ValueRef] + """ self.add(Output(inputs)) def add(self, node: NodeDef) -> Callable[[PortID], ValueRef]: + """Add a node to the graph. + + :param node: The node to add. + :type node: NodeDef + :raises TierkreisError: If multiple outputs are added. + :return: A function given the output name of a node returns + the index of the node it corresponds to. + :rtype: Callable[[PortID], ValueRef] + """ idx = len(self.nodes) self.nodes.append(node) match node.type: case "output": if self.graph_output_idx is not None: + msg = f"Graph already has output at index {self.graph_output_idx}" raise TierkreisError( - f"Graph already has output at index {self.graph_output_idx}" + msg, ) self.graph_output_idx = idx @@ -173,23 +414,43 @@ def add(self, node: NodeDef) -> Callable[[PortID], ValueRef]: return lambda k: (idx, k) def output_idx(self) -> NodeIndex: + """Find the index of the graph output node. + + :raises TierkreisError: If the graph has no output. + :raises TierkreisError: It the node at the index is not an output. 
+ :return: The index for the output node in self.nodes + :rtype: NodeIndex + """ idx = self.graph_output_idx if idx is None: - raise TierkreisError("Graph has no output index.") + msg = "Graph has no output index." + raise TierkreisError(msg) node = self.nodes[idx] if node.type != "output": - raise TierkreisError(f"Expected output node at {idx} found {node}") + msg = f"Expected output node at {idx} found {node}" + raise TierkreisError(msg) return idx def remaining_inputs(self, provided_inputs: set[PortID]) -> set[PortID]: + """Find the inputs for which no values are provided. + + :param provided_inputs: The list of already provided inputs. + :type provided_inputs: set[PortID] + :raises TierkreisError: If provided inputs would overwrite fixed inputs. + :return: A set of input names which don't have an associated value. + :rtype: set[PortID] + """ fixed_inputs = set(self.fixed_inputs.keys()) if fixed_inputs & provided_inputs: - raise TierkreisError( + msg = ( f"Fixed inputs {fixed_inputs}" f" should not intersect provided inputs {provided_inputs}." ) + raise TierkreisError( + msg, + ) actual_inputs = fixed_inputs.union(provided_inputs) return self.graph_inputs - actual_inputs @@ -199,15 +460,35 @@ def graph_node_from_loc( node_location: Loc, graph: GraphData, ) -> tuple[NodeDef, GraphData]: - """Assumes the first part of a loc can be found in current graph""" + """Find the node definition and graph of a nested graph given a loc. + + Nested graphs nodes are not indexed in their parent as their are + represented by a single node. E.g. g_1.eval(const(g_2)) will only produce a single + index although g_2 can contain many nodes. + Locs on the other hand contain this information e.g -.N0.L0.N-1 is a virtual eval + node. + This functions recursively steps trough nested graph definitions like this to find + a graph according to a flat loc. + Assumes the first part of a loc can be found in current graph. + + :param node_location: The loc to search for. 
+ :type node_location: Loc + :param graph: The current graph to search in. + :type graph: GraphData + :raises TierkreisError: On an empty graph of a malformed Loc + :return: The node containing a graph and the graph itself. + :rtype: tuple[NodeDef, GraphData] + """ if len(graph.nodes) == 0: - raise TierkreisError("Cannot convert location to node. Reason: Empty Graph") + msg = "Cannot convert location to node. Reason: Empty Graph" + raise TierkreisError(msg) if node_location == "-": return Eval((-1, "body"), {}), graph step, remaining_location = node_location.pop_first() if isinstance(step, str): - raise TierkreisError("Cannot convert location: Reason: Malformed Loc") + msg = "Cannot convert location: Reason: Malformed Loc" + raise TierkreisError(msg) (_, node_id) = step if node_id == -1: return Eval((-1, "body"), {}), graph @@ -221,7 +502,7 @@ def graph_node_from_loc( case "loop" | "map": graph = _unwrap_graph(graph.nodes[node.body[0]], node.type) _, remaining_location = remaining_location.pop_first() # Remove the M0/L0 - if len(remaining_location.steps()) < 2: + if len(remaining_location.steps()) <= 1: return Eval((-1, "body"), node.inputs, node.outputs), graph node, graph = graph_node_from_loc(remaining_location, graph) @@ -236,9 +517,12 @@ def graph_node_from_loc( def _unwrap_graph(node: NodeDef, node_type: str) -> GraphData: """Safely unwraps a const nodes GraphData.""" if not isinstance(node, Const): - raise TierkreisError( + msg = ( f"Cannot convert location to node. Reason: {node_type} does not wrap const" ) + raise TierkreisError( + msg, + ) match node.value: case GraphData() as graph: return graph @@ -248,6 +532,7 @@ def _unwrap_graph(node: NodeDef, node_type: str) -> GraphData: return GraphData(**data) case _: + msg = "Cannot convert location to node. Reason: const value is not a graph" raise TierkreisError( - "Cannot convert location to node. 
Reason: const value is not a graph" + msg, ) diff --git a/tierkreis/tierkreis/controller/data/location.py b/tierkreis/tierkreis/controller/data/location.py index 1e5bbd8b3..86cec1354 100644 --- a/tierkreis/tierkreis/controller/data/location.py +++ b/tierkreis/tierkreis/controller/data/location.py @@ -1,13 +1,11 @@ from logging import getLogger from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, assert_never from pydantic import BaseModel, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from tierkreis.controller.data.core import PortID -from typing_extensions import assert_never -from tierkreis.controller.data.core import NodeIndex +from tierkreis.controller.data.core import NodeIndex, PortID from tierkreis.exceptions import TierkreisError logger = getLogger(__name__) @@ -20,23 +18,27 @@ class WorkerCallArgs(BaseModel): output_dir: Path done_path: Path error_path: Path - logs_path: Optional[Path] + logs_path: Path | None NodeStep = Literal["-"] | tuple[Literal["N", "L", "M"], NodeIndex] +MIN_LENGTH = 2 + class Loc(str): + __slots__ = [] + def __new__(cls, k: str = "-") -> "Loc": - return super(Loc, cls).__new__(cls, k) + return super().__new__(cls, k) - def N(self, idx: int) -> "Loc": + def N(self, idx: int) -> "Loc": # noqa: N802 return Loc(str(self) + f".N{idx}") - def L(self, idx: int) -> "Loc": + def L(self, idx: int) -> "Loc": # noqa: N802 return Loc(str(self) + f".L{idx}") - def M(self, idx: int) -> "Loc": + def M(self, idx: int) -> "Loc": # noqa: N802 return Loc(str(self) + f".M{idx}") @staticmethod @@ -84,13 +86,16 @@ def steps(self) -> list[NodeStep]: case ("M", idx_str): steps.append(("M", int(idx_str))) case _: - raise TierkreisError(f"Invalid Loc: {self}") + msg = f"Invalid Loc: {self}" + raise TierkreisError(msg) return steps @classmethod def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler + cls, + source_type: Any, # noqa: ANN401 inherited 
from pydantic + handler: GetCoreSchemaHandler, ) -> CoreSchema: return core_schema.no_info_after_validator_function(cls, handler(str)) @@ -98,22 +103,26 @@ def pop_first(self) -> tuple[NodeStep, "Loc"]: if self == "-": return "-", Loc("") steps = self.steps() - if len(steps) < 2: - raise TierkreisError("Malformed Loc") + if len(steps) < MIN_LENGTH: + msg = "Malformed Loc" + raise TierkreisError(msg) first = steps.pop(1) if first == "-": - raise TierkreisError("Malformed Loc") + msg = "Malformed Loc" + raise TierkreisError(msg) return first, Loc.from_steps(steps) def pop_last(self) -> tuple[NodeStep, "Loc"]: if self == "-": return "-", Loc("") steps = self.steps() - if len(steps) < 2: - raise TierkreisError("Malformed Loc") + if len(steps) < MIN_LENGTH: + msg = "Malformed Loc" + raise TierkreisError(msg) last = steps.pop(-1) if last == "-": - raise TierkreisError("Malformed Loc") + msg = "Malformed Loc" + raise TierkreisError(msg) return last, Loc.from_steps(steps) def peek(self) -> NodeStep: diff --git a/tierkreis/tierkreis/controller/data/models.py b/tierkreis/tierkreis/controller/data/models.py index 50877a13a..e5487da03 100644 --- a/tierkreis/tierkreis/controller/data/models.py +++ b/tierkreis/tierkreis/controller/data/models.py @@ -13,7 +13,9 @@ overload, runtime_checkable, ) + from typing_extensions import TypeIs + from tierkreis.controller.data.core import ( NodeIndex, PortID, @@ -56,7 +58,8 @@ def value_ref(self) -> ValueRef: class TNamedModel(RestrictedNamedTuple[TKR[PType] | None], Protocol): """A struct whose members are restricted to being references to PTypes. - E.g. in graph builder code these are outputs of tasks.""" + E.g. in graph builder code these are outputs of tasks. 
+ """ TModel = TNamedModel | TKR @@ -84,7 +87,7 @@ def is_portmapping( return hasattr(o, TKR_PORTMAPPING_FLAG) -def is_tnamedmodel(o) -> TypeIs[type[TNamedModel]]: +def is_tnamedmodel(o) -> TypeIs[type[TNamedModel]]: # noqa: ANN001 inherited from get_origin origin = get_origin(o) if origin is not None: return is_tnamedmodel(origin) @@ -115,10 +118,10 @@ def dict_from_tmodel(tmodel: TModel) -> dict[PortID, ValueRef]: def model_fields(model: type[PModel] | type[TModel]) -> list[str]: if is_portmapping(model): - return getattr(model, "_fields") + return model._fields if is_tnamedmodel(model): - return getattr(model, "_fields") + return model._fields return ["value"] @@ -134,7 +137,7 @@ def init_tmodel[T: TModel](tmodel: type[T], refs: list[ValueRef]) -> T: if get_origin(param) == Union: param = next(x for x in get_args(param) if x) args.append(param(ref[0], ref[1])) - return cast(T, model(*args)) + return cast("T", model(*args)) return tmodel(refs[0][0], refs[0][1]) diff --git a/tierkreis/tierkreis/controller/data/types.py b/tierkreis/tierkreis/controller/data/types.py index 6ca416bc9..aa110e2af 100644 --- a/tierkreis/tierkreis/controller/data/types.py +++ b/tierkreis/tierkreis/controller/data/types.py @@ -1,19 +1,19 @@ -from collections import defaultdict +# ruff: noqa: ANN001 ANN003 ANN401 due to serialization and inheritance from json +import collections.abc +import json import logging +import pickle from base64 import b64decode, b64encode -import collections.abc +from collections import defaultdict +from collections.abc import Mapping, Sequence from inspect import Parameter, _empty, isclass from itertools import chain -import json -import pickle from types import NoneType, UnionType from typing import ( Annotated, Any, - Mapping, Protocol, Self, - Sequence, TypeVar, Union, assert_never, @@ -25,6 +25,8 @@ from pydantic import BaseModel, ValidationError from pydantic._internal._generics import get_args as pydantic_get_args +from typing_extensions import TypeIs + 
from tierkreis.controller.data.core import ( RestrictedNamedTuple, SerializationFormat, @@ -32,7 +34,6 @@ get_serializer, ) from tierkreis.exceptions import TierkreisError -from typing_extensions import TypeIs @runtime_checkable @@ -40,7 +41,8 @@ class NdarraySurrogate(Protocol): """A protocol to enable use of numpy.ndarray. By default the serialisation will be done using dumps - and the deserialisation using `pickle.loads`.""" + and the deserialisation using `pickle.loads`. + """ def dumps(self) -> bytes: ... def tobytes(self) -> bytes: ... @@ -99,7 +101,7 @@ class Struct(RestrictedNamedTuple[JsonType], Protocol): ... class TierkreisEncoder(json.JSONEncoder): """Encode bytes also.""" - def default(self, o): + def default(self, o) -> dict[str, Any] | dict[str, list[float]] | Any: if isinstance(o, bytes): return {"__tkr_bytes__": True, "bytes": b64encode(o).decode()} @@ -112,11 +114,11 @@ def default(self, o): class TierkreisDecoder(json.JSONDecoder): """Decode bytes also.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: kwargs.setdefault("object_hook", self._object_hook) super().__init__(**kwargs) - def _object_hook(self, d): + def _object_hook(self, d) -> bytes | complex | Any: """Try to decode an object containing bytes.""" if "__tkr_bytes__" in d and "bytes" in d: return b64decode(d["bytes"]) @@ -129,10 +131,7 @@ def _object_hook(self, d): def _is_union(o: object) -> bool: return ( - get_origin(o) == UnionType - or get_origin(o) == Union - or o == Union - or o == UnionType + o in (Union, UnionType) or get_origin(o) == UnionType or get_origin(o) == Union ) @@ -167,24 +166,23 @@ def is_ptype(annotation: Any) -> TypeIs[type[PType]]: ): return all(is_ptype(x) for x in get_args(annotation)) - elif isclass(annotation) and issubclass( - annotation, - (DictConvertible, ListConvertible, NdarraySurrogate, BaseModel, Struct), - ): - return True - - elif annotation in get_args(ElementaryType.__value__): + if ( + isclass(annotation) + and 
issubclass( + annotation, + (DictConvertible, ListConvertible, NdarraySurrogate, BaseModel, Struct), + ) + ) or annotation in get_args(ElementaryType.__value__): return True origin = get_origin(annotation) if origin is not None: return is_ptype(origin) and all(is_ptype(x) for x in get_args(annotation)) - else: - return False + return False -def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> Any: +def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> JsonType: if sr := get_serializer(annotation): return sr.serializer(ptype) @@ -241,8 +239,9 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: try: return coerce_from_annotation(ser, t) except (AssertionError, ValidationError): - logger.debug(f"Tried deserialising as {t}") - raise TierkreisError(f"Could not deserialise {ser} as {annotation}") + logger.debug("Tried deserialising as %s", t) + msg = f"Could not deserialise {ser} as {annotation}" + raise TierkreisError(msg) origin = get_origin(annotation) if origin is None: @@ -262,18 +261,24 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: return ser if issubclass(origin, DictConvertible): - assert issubclass(annotation, origin) + if not issubclass(annotation, origin): + msg = "Invalid subclass relation encountered." + raise TypeError(msg) return annotation.from_dict(ser) if issubclass(origin, ListConvertible): - assert issubclass(annotation, origin) + if not issubclass(annotation, origin): + msg = "Invalid subclass relation encountered." + raise TypeError(msg) return annotation.from_list(ser) if issubclass(origin, NdarraySurrogate): return pickle.loads(ser) if issubclass(origin, BaseModel): - assert issubclass(annotation, origin) + if not issubclass(annotation, origin): + msg = "Invalid subclass relation encountered." 
+ raise TypeError(msg) return annotation(**ser) if issubclass(origin, Struct): @@ -281,21 +286,24 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: k: coerce_from_annotation(ser[k], v) for k, v in origin.__annotations__.items() } - return cast(T, origin(**d)) + return cast("T", origin(**d)) if issubclass(origin, collections.abc.Sequence): args = get_args(annotation) if len(args) == 0: return ser - return cast(T, [coerce_from_annotation(x, args[0]) for x in ser]) + return cast("T", [coerce_from_annotation(x, args[0]) for x in ser]) if issubclass(origin, collections.abc.Mapping): args = get_args(annotation) if len(args) == 0: return ser - return cast(T, {k: coerce_from_annotation(v, args[1]) for k, v in ser.items()}) + return cast( + "T", + {k: coerce_from_annotation(v, args[1]) for k, v in ser.items()}, + ) assert_never(ser) @@ -329,7 +337,7 @@ def ptype_from_bytes[T: PType](bs: bytes, annotation: type[T] | None = None) -> j = json.loads(bs, cls=TierkreisDecoder) return coerce_from_annotation(j, annotation) except (json.JSONDecodeError, UnicodeDecodeError): - return cast(T, bs) + return cast("T", bs) case _: assert_never(method) @@ -352,7 +360,7 @@ def generics_in_ptype(ptype: type[PType]) -> set[str]: return set() if issubclass(ptype, BaseModel): - return set((str(x) for x in pydantic_get_args(ptype))) + return {str(x) for x in pydantic_get_args(ptype)} assert_never(ptype) diff --git a/tierkreis/tierkreis/controller/executor/__init__.py b/tierkreis/tierkreis/controller/executor/__init__.py index e69de29bb..a8511eb76 100644 --- a/tierkreis/tierkreis/controller/executor/__init__.py +++ b/tierkreis/tierkreis/controller/executor/__init__.py @@ -0,0 +1 @@ +"""Tierkreis executors to launch worker tasks.""" diff --git a/tierkreis/tierkreis/controller/executor/hpc/__init__.py b/tierkreis/tierkreis/controller/executor/hpc/__init__.py new file mode 100644 index 000000000..79c77e347 --- /dev/null +++ 
b/tierkreis/tierkreis/controller/executor/hpc/__init__.py @@ -0,0 +1 @@ +"""Collection of HPC executors.""" diff --git a/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py b/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py index 53b8edc65..9591a3b28 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py +++ b/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py @@ -1,18 +1,34 @@ +"""Interface implementation for HPCExecutors.""" + import logging import subprocess +from collections.abc import Callable from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Callable, Protocol +from typing import Protocol from tierkreis.consts import TKR_DIR_KEY from tierkreis.controller.executor.hpc.job_spec import JobSpec from tierkreis.exceptions import TierkreisError - logger = logging.getLogger(__name__) class HPCExecutor(Protocol): + """Generic protocol for an HPC executor. + + :fields: + launchers_path (Path | None): The locations to search for workers. + This will change the location from where the command is invoked + by appending "cd launchers_path && " + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + spec (JobSpec): A definition of the job specification. + script_fn (Callable[[JobSpec], str]): A template function to generate the + submission script from. + command (str): The base command to use. + """ + launchers_path: Path | None logs_path: Path errors_path: Path @@ -22,15 +38,42 @@ class HPCExecutor(Protocol): def generate_script( - template_fn: Callable[[JobSpec], str], spec: JobSpec, path: Path + template_fn: Callable[[JobSpec], str], + spec: JobSpec, + path: Path, ) -> None: - with open(path, "w+", encoding="utf-8") as fh: + """Generate a scheduler script by calling a template function. + + :param template_fn: The template function to call. 
+ :type template_fn: Callable[[JobSpec], str] + :param spec: The job definition to generate the script for. + :type spec: JobSpec + :param path: The path to save the script to. + :type path: Path + """ + with Path.open(path, "w+", encoding="utf-8") as fh: fh.write(template_fn(spec)) def run_hpc_executor( - executor: HPCExecutor, launcher_name: str, worker_call_args_path: Path + executor: HPCExecutor, + launcher_name: str, + worker_call_args_path: Path, ) -> None: + """Run a worker function on an HPC executor. + + This is a generic function to run with with an HPC executor. + Similar to the :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + run function. + + :param executor: The executor to use for running + :type executor: HPCExecutor + :param launcher_name: Module description fo the worker to run + :type launcher_name: str + :param worker_call_args_path: Location of the worker call args. + :type worker_call_args_path: Path + :raises TierkreisError: When job submission fails. 
+ """ logger.info("START %s %s", launcher_name, worker_call_args_path) spec = executor.spec.model_copy() @@ -64,13 +107,15 @@ def run_hpc_executor( process = subprocess.run( submission_cmd, + check=False, start_new_session=True, capture_output=True, - universal_newlines=True, + text=True, ) if process.returncode != 0: - with open(executor.errors_path, "a") as efh: + with Path.open(executor.errors_path, "a") as efh: efh.write("Error from script") efh.write(process.stderr) - raise TierkreisError(f"Executor failed with return code {process.returncode}") + msg = f"Executor failed with return code {process.returncode}" + raise TierkreisError(msg) logger.info("Submitted job with return code %s", process.stdout.rstrip()) diff --git a/tierkreis/tierkreis/controller/executor/hpc/job_spec.py b/tierkreis/tierkreis/controller/executor/hpc/job_spec.py index fd4335b46..5073a3ce9 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/job_spec.py +++ b/tierkreis/tierkreis/controller/executor/hpc/job_spec.py @@ -1,14 +1,41 @@ -from pathlib import Path +"""Definition of HPC resource classes. + +These are used to map resource requirements to the respective +settings in resource management systems (schedulers). +Using the value `None` typically will unset the flag. +It is not guaranteed that all schedulers can realize all the configurations. +""" + import platform +from pathlib import Path + from pydantic import BaseModel, Field class MpiSpec(BaseModel): + """MPI configuration. + + :fields: + max_proc_per_node (int | None): Number of MPI processes per compute node, + defaults to 1. + proc (int | None): Number of MPI processes (ranks), defaults to None (unset). + + """ + max_proc_per_node: int | None = 1 proc: int | None = None class ResourceSpec(BaseModel): + """General resource definitions. + + :fields: + nodes (int): Number of compute nodes, defaults to 1. + cores_per_node (int | None): Number of cores to ues per node, defaults to 1. 
+ memory_gb (int | None): Memory per node in GB, defaults to 4. + gpus_per_nod (int | None): Physical GPUs to reserve on the node, defaults to 0. + """ + nodes: int = 1 cores_per_node: int | None = 1 memory_gb: int | None = 4 @@ -16,10 +43,31 @@ class ResourceSpec(BaseModel): class UserSpec(BaseModel): + """User specific configuration. + + :fields: + mail (str | None): User email to send job updates. + """ + mail: str | None = None # some clusters require this class ContainerSpec(BaseModel): + """Configuration for the use of container images in HPC. + + :warning: + Not fully supported yet. + + :fields: + images (str): URL to the container image. + engine (str): which engine to use. + name (str | None): Explicit image name, defaults to None. + extra_args (dict[str, str | None]): Environment args to pass to the container, + defaults to {}. + env_file (str | None): Path to a file with variable export definitions, + defaults to None. + """ + image: str engine: str # e.g. singularity, docker, enroot? name: str | None = None @@ -28,11 +76,44 @@ class ContainerSpec(BaseModel): class JobSpec(BaseModel): + """Resource definition for an HPC job. + + This is used to generate the job script for the scheduler. + + :fields: + job_name (str): Reference name for the job. + command (str): The command to execute on hpc. + E.g. "mpi run ..." + resource (ResourceSpec): Resource specification for the job. + account: (str | None): Account or group used to submit this job, + defaults to None. + mpi: (MpiSpec | None): The MPI specification. If this is set, will prepend + "mpirun" to the command string, defaults to None. + container: (ContainerSpec | None): The container specification for the job, + defaults to None. + walltime (str): Maximum walltime of the job in HH:MM:SS format, + defaults to "01:00:00". + queue: (str | None): Named queue to submit to, HPC center specific, + defaults to None. 
+ output_path: (Path | None): Explicit job output, if not tkr output will be used, + defaults to None. + error_path: (Path | None): Explicit error output, if not tkr output will be used + defaults to None. + extra_scheduler_args (dict[str, str | None]): Configure additional flags and + options that are not provided in the spec. Flags are set as + extra_scheduler_args["flag_name"] = None, options set as + extra_scheduler_args["option_name"] = "option_value". + Defaults to {} + environment: (dict[str, str]): Provide additional environment variables to the + job, defaults to {}. + include_no_check_directory_flag: (bool): Set "--no-check-directory", + defaults to false. + """ + job_name: str command: str # used instead of popen.input resource: ResourceSpec account: str | None = None - """Account or group used to submit this job.""" mpi: MpiSpec | None = None user: UserSpec | None = None container: ContainerSpec | None = None @@ -46,12 +127,17 @@ class JobSpec(BaseModel): def pjsub_large_spec() -> JobSpec: + """Generate an example large job specification for FUGAKU. + + :return: A job spec running uv on FUGAKU. + :rtype: JobSpec + """ arch = platform.machine() uv_path = Path.home() / ".local" / f"bin_{arch}" / "uv" return JobSpec( job_name="pjsub_large", account="hp240496", - command=f"{str(uv_path)} run main.py", + command=f"{uv_path!s} run main.py", queue="q-QTM-M", resource=ResourceSpec(nodes=32), environment={ @@ -69,12 +155,17 @@ def pjsub_large_spec() -> JobSpec: def pjsub_small_spec() -> JobSpec: + """Generate an example small job specification for FUGAKU. + + :return: A job spec running uv on FUGAKU. 
+ :rtype: JobSpec + """ arch = platform.machine() uv_path = Path.home() / ".local" / f"bin_{arch}" / "uv" return JobSpec( job_name="pjsub_small", account="hp240496", - command=f"{str(uv_path)} run main.py", + command=f"{uv_path!s} run main.py", resource=ResourceSpec(nodes=1), environment={ "VIRTUAL_ENVIRONMENT": "", diff --git a/tierkreis/tierkreis/controller/executor/hpc/pbs.py b/tierkreis/tierkreis/controller/executor/hpc/pbs.py index bf77b5728..05c861808 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/pbs.py +++ b/tierkreis/tierkreis/controller/executor/hpc/pbs.py @@ -1,21 +1,33 @@ +"""Template and Executor for PBS.""" + +# ruff: noqa: ERA001 from pathlib import Path # from typing import Callable # from tierkreis.controller.executor.hpc.hpc_executor import run_hpc_executor from tierkreis.controller.executor.hpc.job_spec import JobSpec - _COMMAND_PREFIX = "#PBS" -def generate_pbs_script(spec: JobSpec) -> str: +def generate_pbs_script(spec: JobSpec) -> str: # noqa: C901, PLR0912 complexity to cover options + """Generate a job submission script according to PBS. + + This uses the "PBS"/qsub syntax and represents a mapping from JobSpec + to the native flags. + + :param spec: The job to generate a script for. + :type spec: JobSpec + :return: A job script for the PBS scheduler. + :rtype: str + """ # 1. Shebang and file header lines = [ """#!/bin/bash # # PBS Job Script generated by TIERKREIS -# --- Core Job Specifications ---""" +# --- Core Job Specifications ---""", ] # 2. 
Name lines.append(f"{_COMMAND_PREFIX} -N {spec.job_name}") @@ -40,10 +52,9 @@ def generate_pbs_script(spec: JobSpec) -> str: lines.append("\n# --- User Details ---") if spec.account is not None: lines.append(f"{_COMMAND_PREFIX} -A {spec.account}") - if spec.user is not None: - if spec.user.mail is not None: - lines.append(f"{_COMMAND_PREFIX} -m e") # end only - lines.append(f"{_COMMAND_PREFIX} -M {spec.user.mail}") + if spec.user is not None and spec.user.mail is not None: + lines.append(f"{_COMMAND_PREFIX} -m e") # end only + lines.append(f"{_COMMAND_PREFIX} -M {spec.user.mail}") # 5. Output and Error handling lines.append("\n# --- Output and Error Handling ---") @@ -82,7 +93,8 @@ def generate_pbs_script(spec: JobSpec) -> str: lines.append(f"{_COMMAND_PREFIX} -l {key}={value}") if spec.container.env_file is not None: lines.append( - f"{_COMMAND_PREFIX} -l {spec.container.engine}_env_file={spec.container.env_file}" + f"{_COMMAND_PREFIX} -l {spec.container.engine}_env_file" + f"={spec.container.env_file}", ) # check if this makes sense for others beside enroot # 10. 
User Command, (prologue), command, (epilogue) @@ -93,7 +105,7 @@ def generate_pbs_script(spec: JobSpec) -> str: # Disabled for now, needs testing with a PBS system, will be re-enabled later -# See: Issue #182 +# See: TODO@philipp-seitz: Issue #182 # class PBSExecutor: # def __init__( # self, diff --git a/tierkreis/tierkreis/controller/executor/hpc/pjsub.py b/tierkreis/tierkreis/controller/executor/hpc/pjsub.py index f7abad31f..7ca6815ec 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/pjsub.py +++ b/tierkreis/tierkreis/controller/executor/hpc/pjsub.py @@ -1,7 +1,8 @@ -# from functools import partial +"""Template and Executor for PJSUB(FUGAKU).""" + from functools import partial from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING from tierkreis.controller.executor.hpc.hpc_executor import run_hpc_executor from tierkreis.controller.executor.hpc.job_spec import ( @@ -10,18 +11,30 @@ pjsub_small_spec, ) +if TYPE_CHECKING: + from collections.abc import Callable _COMMAND_PREFIX = "#PJM" -def generate_pjsub_script(spec: JobSpec) -> str: +def generate_pjsub_script(spec: JobSpec) -> str: # noqa: C901, PLR0912 complexity to cover options + """Generate a job submission script according to PJSUB. + + This uses the "PJM"/pjsub syntax and represents a mapping from JobSpec + to the native flags. + + :param spec: The job to generate a script for. + :type spec: JobSpec + :return: A job script for the PJSUB scheduler. + :rtype: str + """ # 1. Shebang and file header lines = [ """#!/bin/bash # # PJSUB Job Script generated by TIERKREIS -# --- Core Job Specifications ---""" +# --- Core Job Specifications ---""", ] # 2. 
Name lines.append(f"{_COMMAND_PREFIX} -N {spec.job_name}") @@ -39,10 +52,9 @@ def generate_pjsub_script(spec: JobSpec) -> str: lines.append("\n# --- User Details ---") if spec.account is not None: lines.append(f"{_COMMAND_PREFIX} -g {spec.account}") - if spec.user is not None: - if spec.user.mail is not None: - lines.append(f"{_COMMAND_PREFIX} -m e") # end only - lines.append(f"{_COMMAND_PREFIX} --mail-list {spec.user.mail}") + if spec.user is not None and spec.user.mail is not None: + lines.append(f"{_COMMAND_PREFIX} -m e") # end only + lines.append(f"{_COMMAND_PREFIX} --mail-list {spec.user.mail}") # 5. Output and Error handling lines.append("\n# --- Output and Error Handling ---") @@ -61,7 +73,8 @@ def generate_pjsub_script(spec: JobSpec) -> str: lines.append(f'{_COMMAND_PREFIX} --mpi "proc={spec.mpi.proc}"') if spec.mpi.max_proc_per_node is not None: lines.append( - f'{_COMMAND_PREFIX} --mpi "max-proc-per-node={spec.mpi.max_proc_per_node}"' + f'{_COMMAND_PREFIX} --mpi "max-proc-per-node' + f'={spec.mpi.max_proc_per_node}"', ) # 7. User specific @@ -80,13 +93,19 @@ def generate_pjsub_script(spec: JobSpec) -> str: lines.append("\n# --- User Command ---") lines.append(spec.command) - with open("./script", "w+") as fh: + with Path.open(Path("./script"), "w+") as fh: fh.write("\n".join(lines)) return "\n".join(lines) class PJSUBExecutor: + """An executor for the PJSUB submission system. + + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + Implements: :py:class:`tierkreis.controller.executor.hpc.hpc_executor.HPCExecutor` + """ + def __init__( self, registry_path: Path | None, @@ -106,6 +125,13 @@ def run( launcher_name: str, worker_call_args_path: Path, ) -> None: + """Run the node according to ControllerExecutor protocol. + + :param launcher_name: module description of worker to run. + :type launcher_name: str + :param worker_call_args_path: Location of the worker call args. 
+ :type worker_call_args_path: Path + """ self.errors_path = ( self.logs_path.parent.parent / worker_call_args_path.parent / "errors" ) diff --git a/tierkreis/tierkreis/controller/executor/hpc/slurm.py b/tierkreis/tierkreis/controller/executor/hpc/slurm.py index a6d7f5d28..d88ea0559 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/slurm.py +++ b/tierkreis/tierkreis/controller/executor/hpc/slurm.py @@ -1,20 +1,35 @@ +"""Template and executor for SLURM.""" + from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING + from tierkreis.controller.executor.hpc.hpc_executor import run_hpc_executor from tierkreis.controller.executor.hpc.job_spec import JobSpec +if TYPE_CHECKING: + from collections.abc import Callable _COMMAND_PREFIX = "#SBATCH" -def generate_slurm_script(spec: JobSpec) -> str: +def generate_slurm_script(spec: JobSpec) -> str: # noqa: C901, PLR0912 complexity to cover options + """Generate a job submission script according to SLURM. + + This uses the "sbatch" syntax and represents a mapping from JobSpec + to the native flags. + + :param spec: The job to generate a script for. + :type spec: JobSpec + :return: A job script for the SLURM scheduler. + :rtype: str + """ # 1. Shebang and file header lines = [ """#!/bin/bash # # SLURM Job Script generated by TIERKREIS -# --- Core Job Specifications ---""" +# --- Core Job Specifications ---""", ] # 2. 
Name lines.append(f"{_COMMAND_PREFIX} --job-name={spec.job_name}") @@ -23,7 +38,7 @@ def generate_slurm_script(spec: JobSpec) -> str: lines.append(f"{_COMMAND_PREFIX} --nodes={spec.resource.nodes}") if spec.resource.cores_per_node is not None: lines.append( - f"{_COMMAND_PREFIX} --cpus-per-task={spec.resource.cores_per_node}" + f"{_COMMAND_PREFIX} --cpus-per-task={spec.resource.cores_per_node}", ) if spec.resource.memory_gb is not None: lines.append(f"{_COMMAND_PREFIX} --mem={spec.resource.memory_gb}G") @@ -36,10 +51,9 @@ def generate_slurm_script(spec: JobSpec) -> str: lines.append("\n# --- User Details ---") if spec.account is not None: lines.append(f"{_COMMAND_PREFIX} --account={spec.account}") - if spec.user is not None: - if spec.user.mail is not None: - lines.append(f"{_COMMAND_PREFIX} --mail-type=END") # end only - lines.append(f"{_COMMAND_PREFIX} --mail-user={spec.user.mail}") + if spec.user is not None and spec.user.mail is not None: + lines.append(f"{_COMMAND_PREFIX} --mail-type=END") # end only + lines.append(f"{_COMMAND_PREFIX} --mail-user={spec.user.mail}") # 5. Output and Error handling lines.append("\n# --- Output and Error Handling ---") @@ -48,14 +62,14 @@ def generate_slurm_script(spec: JobSpec) -> str: if spec.output_path is not None: lines.append(f"{_COMMAND_PREFIX} --output={spec.output_path}") - # 6. MPI, #TODO check if this makes sense + # 6. MPI, #TODO@philipp-seitz: check if this makes sense if spec.mpi is not None: lines.append("\n# --- MPI ---") if spec.mpi.proc is not None: lines.append(f"{_COMMAND_PREFIX} --ntasks={spec.mpi.proc}") if spec.mpi.max_proc_per_node is not None: lines.append( - f"{_COMMAND_PREFIX} --ntasks-per-node={spec.mpi.max_proc_per_node}" + f"{_COMMAND_PREFIX} --ntasks-per-node={spec.mpi.max_proc_per_node}", ) # 7. 
User specific @@ -79,7 +93,8 @@ def generate_slurm_script(spec: JobSpec) -> str: if spec.mpi.max_proc_per_node is None: spec.mpi.max_proc_per_node = 1 lines.append( - f"mpirun -n {spec.resource.nodes * spec.mpi.max_proc_per_node} {spec.command}" + f"mpirun -n {spec.resource.nodes * spec.mpi.max_proc_per_node}" + f" {spec.command}", ) else: lines.append(spec.command) @@ -88,6 +103,12 @@ def generate_slurm_script(spec: JobSpec) -> str: class SLURMExecutor: + """An executor for the SLURM submission system. + + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + Implements: :py:class:`tierkreis.controller.executor.hpc.hpc_executor.HPCExecutor` + """ + def __init__( self, registry_path: Path | None, @@ -107,6 +128,13 @@ def run( launcher_name: str, worker_call_args_path: Path, ) -> None: + """Run the node according to ControllerExecutor protocol. + + :param launcher_name: module description of worker to run. + :type launcher_name: str + :param worker_call_args_path: Location of the worker call args. 
+ :type worker_call_args_path: Path + """ self.errors_path = ( self.logs_path.parent.parent / worker_call_args_path.parent / "errors" ) diff --git a/tierkreis/tierkreis/controller/executor/in_memory_executor.py b/tierkreis/tierkreis/controller/executor/in_memory_executor.py index 4a5a1d85b..69232edbc 100644 --- a/tierkreis/tierkreis/controller/executor/in_memory_executor.py +++ b/tierkreis/tierkreis/controller/executor/in_memory_executor.py @@ -1,21 +1,30 @@ +"""In memory implementation.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) +import importlib.util import json import logging -import importlib.util from pathlib import Path from tierkreis.controller.data.location import WorkerCallArgs from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage -from tierkreis.worker.storage.in_memory import InMemoryWorkerStorage from tierkreis.exceptions import TierkreisError - +from tierkreis.worker.storage.in_memory import InMemoryWorkerStorage logger = logging.getLogger(__name__) class InMemoryExecutor: - """Executes workers in the same process as the controller. + """Execute workers in the same process as the controller. + Loads the worker as python module if possible. + Cannot only run python workers in conjunction with ControllerInMemoryStorage. Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + registry_path (Path): The locations to search for worker modules. + storage (ControllerInMemoryStorage): + Storage reference to access in memory values. 
""" def __init__(self, registry_path: Path, storage: ControllerInMemoryStorage) -> None: @@ -29,15 +38,17 @@ def run( ) -> None: logger.info("START %s %s", launcher_name, worker_call_args_path) call_args = WorkerCallArgs( - **json.loads(self.storage.read(worker_call_args_path)) + **json.loads(self.storage.read(worker_call_args_path)), ) spec = importlib.util.spec_from_file_location( - "in_memory", self.registry_path / launcher_name / "main.py" + "in_memory", + self.registry_path / launcher_name / "main.py", ) if spec is None or spec.loader is None: + msg = f"Couldn't load main.py in {self.registry_path / launcher_name}" raise TierkreisError( - f"Couldn't load main.py in {self.registry_path / launcher_name}" + msg, ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) diff --git a/tierkreis/tierkreis/controller/executor/multiple.py b/tierkreis/tierkreis/controller/executor/multiple.py index 4dbd7c6ef..f686ea16e 100644 --- a/tierkreis/tierkreis/controller/executor/multiple.py +++ b/tierkreis/tierkreis/controller/executor/multiple.py @@ -1,3 +1,6 @@ +"""A meta executor consisting of multiple single executors.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) from pathlib import Path from tierkreis.controller.executor.protocol import ControllerExecutor @@ -7,7 +10,15 @@ class MultipleExecutor: """Composes multiple executors into a single object. + Will execute all worker tasks on the assigned executor or default. Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + default (ControllerExecutor): The default executor to use for all unspecified + tasks. + executors (dict[str, ControllerExecutor]): A mapping of name -> executor. 
+ assignments (dict[str, string]): A mapping of worker to executor name + """ def __init__( @@ -24,7 +35,6 @@ def run( self, launcher_name: str, worker_call_args_path: Path, - enable_logging: bool = True, ) -> None: executor_name = self.assignments.get(launcher_name, None) # If there is no assignment for the worker, use the default. @@ -32,8 +42,12 @@ def run( return self.default.run(launcher_name, worker_call_args_path) executor = self.executors.get(executor_name) if executor is None: + msg = ( + f"{launcher_name} is assigned to non-existent" + f" executor name: {executor_name}." + ) raise TierkreisError( - f"{launcher_name} is assigned to non-existent executor name: {executor_name}." + msg, ) return executor.run(launcher_name, worker_call_args_path) diff --git a/tierkreis/tierkreis/controller/executor/protocol.py b/tierkreis/tierkreis/controller/executor/protocol.py index 7d3a11cf4..872eccadb 100644 --- a/tierkreis/tierkreis/controller/executor/protocol.py +++ b/tierkreis/tierkreis/controller/executor/protocol.py @@ -1,3 +1,5 @@ +"""The base executor protocol.""" + from pathlib import Path from typing import Protocol @@ -14,16 +16,20 @@ def run( launcher_name: str, worker_call_args_path: Path, ) -> None: - """Run the node defined by the node_definition path. + """Run the node defined by the worker_call_args_path path. Specifies the worker to run by its launcher name. - For example the function "builtins.iadd" will call the builtins worker's iadd function. - The call arguments for the function call are retrieved retrieved from its location. + For example the function "builtins.iadd" will call the builtins worker's + iadd function. The call arguments for the function call are retrieved retrieved + from its location. + + The executor ensures workers are progressed correctly; This includes: + - setting up error and log files and making them available + - checking progress (e.g. 
_done file) + - enabling path resolution between tkr paths and worker inputs - :param launcher_name: module description of launcher to run. + :param launcher_name: module description of worker to run. :type launcher_name: str :param worker_call_args_path: Location of the worker call args. :type worker_call_args_path: Path """ - - ... diff --git a/tierkreis/tierkreis/controller/executor/shell_executor.py b/tierkreis/tierkreis/controller/executor/shell_executor.py index 72653ae94..0a813b24f 100644 --- a/tierkreis/tierkreis/controller/executor/shell_executor.py +++ b/tierkreis/tierkreis/controller/executor/shell_executor.py @@ -1,3 +1,6 @@ +"""Default executor for arbitrary scripts.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import json import os import subprocess @@ -11,7 +14,19 @@ class ShellExecutor: """Executes workers in an unix shell. + Simply runs any shell script as a worker, if certain conditions on input/output + conditions are met, namely the paths/values are provided through the process + environment and the script is responsible for reading/writing them. + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + launchers_path (Path): The locations to search for external workers. + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + workflow_dir (Path): The workflow dir to resolve relative paths. + timeout (int): Timeout for the process communication, defaults to 10 seconds. + env: (dict[str,str]): Additional environments to hand to the spawned subprocess. 
""" def __init__( @@ -32,50 +47,70 @@ def run( self, launcher_name: str, worker_call_args_path: Path, + *, export_values: bool = False, ) -> None: launcher_path = self.launchers_path / launcher_name self.errors_path = worker_call_args_path.parent / "logs" if not launcher_path.exists(): - raise TierkreisError(f"Launcher not found: {launcher_name}.") + msg = f"Launcher not found: {launcher_name}." + raise TierkreisError(msg) if launcher_path.is_dir() and not (launcher_path / "main.sh").exists(): - raise TierkreisError(f"Expected launcher file. Got {launcher_path}.") + msg = f"Expected launcher file. Got {launcher_path}." + raise TierkreisError(msg) if launcher_path.is_dir() and not (launcher_path / "main.sh").is_file(): - raise TierkreisError(f"Expected launcher file. Got {launcher_path}/main.sh") + msg = f"Expected launcher file. Got {launcher_path}/main.sh" + raise TierkreisError(msg) if launcher_path.is_dir() and (launcher_path / "main.sh").is_file(): launcher_path = launcher_path / "main.sh" - with open(self.workflow_dir.parent / worker_call_args_path) as fh: + with Path.open(self.workflow_dir.parent / worker_call_args_path) as fh: call_args = WorkerCallArgs(**json.load(fh)) env = os.environ.copy() | self.env.copy() - env.update(self._create_env(call_args, self.workflow_dir.parent, export_values)) + env.update( + self._create_env( + call_args, + self.workflow_dir.parent, + export_values=export_values, + ), + ) env["worker_call_args_file"] = str( - self.workflow_dir.parent / worker_call_args_path + self.workflow_dir.parent / worker_call_args_path, ) done_path = self.workflow_dir.parent / call_args.done_path _error_path = self.errors_path.parent / "_error" if TKR_DIR_KEY not in env: env[TKR_DIR_KEY] = str(self.logs_path.parent.parent) - tee_str = f">(tee -a {str(self.errors_path)} {str(self.logs_path)} >/dev/null)" + tee_str = f">(tee -a {self.errors_path!s} {self.logs_path!s} >/dev/null)" proc = subprocess.Popen( - ["bash"], + ["/bin/bash"], 
start_new_session=True, stdin=subprocess.PIPE, env=env, ) proc.communicate( - f"({launcher_path} {worker_call_args_path} > {tee_str} 2> {tee_str} && touch {done_path}|| touch {_error_path})&".encode(), + f"({launcher_path} {worker_call_args_path} > {tee_str} 2> {tee_str} " + f"&& touch {done_path}|| touch {_error_path})&".encode(), timeout=self.timeout, ) def _create_env( - self, call_args: WorkerCallArgs, base_dir: Path, export_values: bool + self, + call_args: WorkerCallArgs, + base_dir: Path, + *, + export_values: bool, ) -> dict[str, str]: + """Set up an environment as interface between controller and worker function. + + If export_values is set, will also write the values of ports to the env. + This is useful if you don't want / can't read the files directly. + """ env = { "checkpoints_directory": str(base_dir), "function_name": str(base_dir / call_args.function_name), @@ -97,6 +132,6 @@ def _create_env( return env values = {} for k, v in call_args.inputs.items(): - with open(v) as fh: + with Path.open(v) as fh: values[f"input_{k}_value"] = fh.read() return env diff --git a/tierkreis/tierkreis/controller/executor/stdinout.py b/tierkreis/tierkreis/controller/executor/stdinout.py index 8025115d2..d40bdd190 100644 --- a/tierkreis/tierkreis/controller/executor/stdinout.py +++ b/tierkreis/tierkreis/controller/executor/stdinout.py @@ -1,3 +1,6 @@ +"""Special case implementation for external workers.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import json import subprocess from pathlib import Path @@ -9,7 +12,17 @@ class StdInOut: """Executes workers in an unix shell. + Assumes the worker takes a single input from stdin and will produce a single output + to stdout. + Will pipe other outputs to errors / logs. + Works by creating a subprocess Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + launchers_path (Path): The locations to search for external workers. 
+ logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + workflow_dir (Path): The workflow dir to resolve relative paths. """ def __init__(self, registry_path: Path, workflow_dir: Path) -> None: @@ -26,32 +39,36 @@ def run( launcher_path = self.launchers_path / launcher_name self.errors_path = worker_call_args_path.parent / "errors" if not launcher_path.exists(): - raise TierkreisError(f"Launcher not found: {launcher_name}.") + msg = f"Launcher not found: {launcher_name}." + raise TierkreisError(msg) if launcher_path.is_dir() and not (launcher_path / "main.sh").exists(): - raise TierkreisError(f"Expected launcher file. Got {launcher_path}.") + msg = f"Expected launcher file. Got {launcher_path}." + raise TierkreisError(msg) if launcher_path.is_dir() and not (launcher_path / "main.sh").is_file(): - raise TierkreisError(f"Expected launcher file. Got {launcher_path}/main.sh") + msg = f"Expected launcher file. Got {launcher_path}/main.sh" + raise TierkreisError(msg) if launcher_path.is_dir() and (launcher_path / "main.sh").is_file(): launcher_path = launcher_path / "main.sh" - with open(self.workflow_dir.parent / worker_call_args_path) as fh: + with Path.open(self.workflow_dir.parent / worker_call_args_path) as fh: call_args = WorkerCallArgs(**json.load(fh)) - input_file = self.workflow_dir.parent / list(call_args.inputs.values())[0] - output_file = self.workflow_dir.parent / list(call_args.outputs.values())[0] + input_file = self.workflow_dir.parent / next(iter(call_args.inputs.values())) + output_file = self.workflow_dir.parent / next(iter(call_args.outputs.values())) done_path = self.workflow_dir.parent / call_args.done_path - tee_str = f">(tee -a {str(self.errors_path)} {str(self.logs_path)} >/dev/null)" + tee_str = f">(tee -a {self.errors_path!s} {self.logs_path!s} >/dev/null)" _error_path = self.errors_path.parent / "_error" proc = subprocess.Popen( - ["bash"], + ["/bin/bash"], start_new_session=True, 
stdin=subprocess.PIPE, ) proc.communicate( - f"({launcher_path} <{input_file} > {output_file} 2> {tee_str} && touch {done_path}|| touch {_error_path})&".encode(), + f"({launcher_path} <{input_file} > {output_file} 2> {tee_str}" + f" && touch {done_path}|| touch {_error_path})&".encode(), timeout=10, ) diff --git a/tierkreis/tierkreis/controller/executor/uv_executor.py b/tierkreis/tierkreis/controller/executor/uv_executor.py index a3e3b488f..5835fd3f2 100644 --- a/tierkreis/tierkreis/controller/executor/uv_executor.py +++ b/tierkreis/tierkreis/controller/executor/uv_executor.py @@ -1,3 +1,6 @@ +"""Default python executor based on uv.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import logging import os import shutil @@ -13,11 +16,26 @@ class UvExecutor: """Executes workers in an UV python environment. + Depends on uv to run, hence the worker needs a pyproject.toml / a respective script. + Works out of the box with the cli worker definitions. + The env field can be used to provide additional variables; for example + controlling the python / uv version through $VIRTUAL_ENVIRONMENT. + Also to resolve paths, the $TKR_DIR will be set to the workflow directory. + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + launchers_path (Path): The locations to search for python workers. + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + env: (dict[str,str]): Additional environments to hand to the spawned subprocess. 
""" def __init__( - self, registry_path: Path, logs_path: Path, env: dict[str, str] | None = None + self, + registry_path: Path, + logs_path: Path, + env: dict[str, str] | None = None, ) -> None: self.launchers_path = registry_path self.logs_path = logs_path @@ -33,14 +51,15 @@ def run( self.errors_path = ( self.logs_path.parent.parent / worker_call_args_path.parent - / "logs" # made we should change this + / "logs" # maybe we should change this ) logger.info("START %s %s", launcher_name, worker_call_args_path) if uv_path is None: uv_path = shutil.which("uv") if uv_path is None: - raise TierkreisError("uv is required to use the uv_executor") + msg = "uv is required to use the uv_executor" + raise TierkreisError(msg) worker_path = self.launchers_path / launcher_name @@ -50,15 +69,16 @@ def run( if TKR_DIR_KEY not in env: env[TKR_DIR_KEY] = str(self.logs_path.parent.parent) _error_path = self.errors_path.parent / "_error" - tee_str = f">(tee -a {str(self.errors_path)} {str(self.logs_path)} >/dev/null)" + tee_str = f">(tee -a {self.errors_path!s} {self.logs_path!s} >/dev/null)" proc = subprocess.Popen( - ["bash"], + ["/bin/bash"], start_new_session=True, stdin=subprocess.PIPE, cwd=worker_path, env=env, ) proc.communicate( - f"({uv_path} run main.py {worker_call_args_path} > {tee_str} 2> {tee_str} || touch {_error_path}) &".encode(), + f"({uv_path} run main.py {worker_call_args_path} > {tee_str} 2> {tee_str}" + f" || touch {_error_path}) &".encode(), timeout=10, ) diff --git a/tierkreis/tierkreis/controller/start.py b/tierkreis/tierkreis/controller/start.py index b2b44e8b9..68771467e 100644 --- a/tierkreis/tierkreis/controller/start.py +++ b/tierkreis/tierkreis/controller/start.py @@ -1,23 +1,22 @@ -from dataclasses import dataclass import logging -from pathlib import Path import subprocess import sys - -from tierkreis.controller.data.core import PortID -from tierkreis.controller.data.types import bytes_from_ptype, ptype_from_bytes -from 
tierkreis.controller.executor.in_memory_executor import InMemoryExecutor -from tierkreis.controller.storage.adjacency import outputs_iter -from typing_extensions import assert_never +from dataclasses import dataclass +from pathlib import Path +from typing import assert_never from tierkreis.consts import PACKAGE_PATH +from tierkreis.controller.data.core import PortID from tierkreis.controller.data.graph import Eval, GraphData, NodeDef from tierkreis.controller.data.location import Loc, OutputLoc +from tierkreis.controller.data.types import bytes_from_ptype, ptype_from_bytes +from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor from tierkreis.controller.executor.protocol import ControllerExecutor -from tierkreis.controller.storage.protocol import ControllerStorage +from tierkreis.controller.storage.adjacency import outputs_iter from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage -from tierkreis.labels import Labels +from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.exceptions import TierkreisError +from tierkreis.labels import Labels logger = logging.getLogger(__name__) @@ -44,7 +43,7 @@ def start_nodes( def run_builtin(def_path: Path, logs_path: Path) -> None: logger.info("START builtin %s", def_path) - with open(logs_path, "a") as fh: + with Path.open(logs_path, "a") as fh: subprocess.Popen( [sys.executable, "main.py", def_path], start_new_session=True, @@ -58,7 +57,6 @@ def start( storage: ControllerStorage, executor: ControllerExecutor, node_run_data: NodeRunData, - enable_logging: bool = True, ) -> None: node_location = node_run_data.node_location node = node_run_data.node @@ -68,22 +66,27 @@ def start( parent = node_location.parent() if parent is None: - raise TierkreisError(f"{node.type} node must have parent Loc.") + msg = f"{node.type} node must have parent Loc." 
+ raise TierkreisError(msg) ins = {k: (parent.N(idx), p) for k, (idx, p) in node.inputs.items()} - logger.debug(f"start {node_location} {node} {ins} {output_list}") + logger.debug("start %s %s %s %s", node_location, node, ins, output_list) if node.type == "function": name = node.function_name launcher_name = ".".join(name.split(".")[:-1]) name = name.split(".")[-1] call_args_path = storage.write_worker_call_args( - node_location, name, ins, output_list + node_location, + name, + ins, + output_list, ) - logger.debug(f"Executing {(str(node_location), name, ins, output_list)}") + logger.debug("Executing %s", (str(node_location), name, ins, output_list)) if isinstance(storage, ControllerInMemoryStorage) and isinstance( - executor, InMemoryExecutor + executor, + InMemoryExecutor, ): executor.run(launcher_name, call_args_path) elif launcher_name == "builtins": @@ -118,9 +121,7 @@ def start( elif node.type == "loop": ins["body"] = (parent.N(node.body[0]), node.body[1]) pipe_inputs_to_output_location(storage, node_location.N(-1), ins) - if ( - node.name is not None - ): # should we do this only in debug mode? 
-> need to think through how this would work + if node.name is not None: storage.write_debug_data(node.name, node_location) start( storage, @@ -146,17 +147,17 @@ def start( else: eval_inputs[k] = (i, port) pipe_inputs_to_output_location( - storage, node_location.M(idx).N(-1), eval_inputs + storage, + node_location.M(idx).N(-1), + eval_inputs, ) # Necessary in the node visualization storage.write_node_def( - node_location.M(idx), Eval((-1, "body"), node.inputs, node.outputs) + node_location.M(idx), + Eval((-1, "body"), node.inputs, node.outputs), ) - elif node.type == "ifelse": - pass - - elif node.type == "eifelse": + elif node.type in {"ifelse", "eifelse"}: pass else: assert_never(node) diff --git a/tierkreis/tierkreis/controller/storage/__init__.py b/tierkreis/tierkreis/controller/storage/__init__.py index e69de29bb..085e534b4 100644 --- a/tierkreis/tierkreis/controller/storage/__init__.py +++ b/tierkreis/tierkreis/controller/storage/__init__.py @@ -0,0 +1 @@ +"""Storage definitions for the controller.""" diff --git a/tierkreis/tierkreis/controller/storage/adjacency.py b/tierkreis/tierkreis/controller/storage/adjacency.py index 2c5e1e3a7..1bc8a1866 100644 --- a/tierkreis/tierkreis/controller/storage/adjacency.py +++ b/tierkreis/tierkreis/controller/storage/adjacency.py @@ -1,3 +1,5 @@ +"""Graph information based on adjacency.""" + import logging from typing import assert_never @@ -12,7 +14,19 @@ def in_edges(node: NodeDef) -> dict[PortID, ValueRef]: - parents = {k: v for k, v in node.inputs.items()} + """Find the incoming edges of a node. + + Finds all the defined inputs and adds the special constructions: + - Graph body for map, loop, eval + - Prediction for ifelse + - All nodes for eager if else + + :param node: The node to evaluate. + :type node: NodeDef + :return: MApping of port names to value references. 
+ :rtype: dict[PortID, ValueRef] + """ + parents = dict(node.inputs.items()) match node.type: case "eval": @@ -36,13 +50,40 @@ def in_edges(node: NodeDef) -> dict[PortID, ValueRef]: def unfinished_inputs( - storage: ControllerStorage, loc: Loc, node: NodeDef + storage: ControllerStorage, + loc: Loc, + node: NodeDef, ) -> list[ValueRef]: + """Find the unfinished inputs of a node. + + :param storage: The storage to write from. + :type storage: ControllerStorage + :param loc: The node location to check for. + :type loc: Loc + :param node: The node definition containing the output names. + :type node: NodeDef + :return: A list of references to node inputs. + :rtype: list[ValueRef] + """ ins = in_edges(node).values() - ins = [x for x in ins if x[0] >= 0] # inputs at -1 already finished + ins = [x for x in ins if x[0] >= 0] # inputs at -1 already finished they're linked return [x for x in ins if not storage.is_node_finished(loc.N(x[0]))] def outputs_iter(storage: ControllerStorage, loc: Loc) -> list[tuple[int, PortID]]: + """Find all the outputs of a node and provide them with their index as map elements. + + This is only used in map nodes to go from the * port to the values + of actual map elements. + This can be from an unfold where we get (index, index) + or map (index, "eval_output_name-index") + + :param storage: The storage to read from. + :type storage: ControllerStorage + :param loc: The location to get the outputs from. 
+ :type loc: Loc + :return: A tuple of (index, portname) of + :rtype: list[tuple[int, PortID]] + """ eles = storage.read_output_ports(loc) return [(int(x.split("-")[-1]), x) for x in eles] diff --git a/tierkreis/tierkreis/controller/storage/filestorage.py b/tierkreis/tierkreis/controller/storage/filestorage.py index 41885cbc3..82e14e435 100644 --- a/tierkreis/tierkreis/controller/storage/filestorage.py +++ b/tierkreis/tierkreis/controller/storage/filestorage.py @@ -1,21 +1,33 @@ +"""Default file storage implementation.""" + import os import shutil from pathlib import Path from time import time_ns +from typing import override from uuid import UUID from tierkreis.controller.storage.protocol import ( - StorageEntryMetadata, ControllerStorage, + StorageEntryMetadata, ) +DEFAULT_DIRECTORY = Path.home() / ".tierkreis" / "checkpoints" + class ControllerFileStorage(ControllerStorage): + """Storage backend using the filesystem. + + This storage implementation operates by relegating calls to the os filesystem. + Calling with `do_cleanup` will ensure that previous runs are deleted. 
+ """ + def __init__( self, workflow_id: UUID, name: str | None = None, - tierkreis_directory: Path = Path.home() / ".tierkreis" / "checkpoints", + tierkreis_directory: Path = DEFAULT_DIRECTORY, + *, do_cleanup: bool = False, ) -> None: self.tkr_dir = tierkreis_directory @@ -24,6 +36,7 @@ def __init__( if do_cleanup: self.delete(self.workflow_dir) + @override def delete(self, path: Path) -> None: uid = os.getuid() tmp_dir = Path(f"/tmp/{uid}/tierkreis/archive/{self.workflow_id}/{time_ns()}") @@ -31,26 +44,32 @@ def delete(self, path: Path) -> None: if self.exists(path): shutil.move(path, tmp_dir) + @override def exists(self, path: Path) -> bool: return path.exists() + @override def list_subpaths(self, path: Path) -> list[Path]: - return [sub_path for sub_path in path.iterdir()] + return list(path.iterdir()) + @override def link(self, src: Path, dst: Path) -> None: dst.parent.mkdir(parents=True, exist_ok=True) if dst.exists() and dst.resolve() == src: return # We have already linked correctly os.link(src, dst) + @override def mkdir(self, path: Path) -> None: return path.mkdir(parents=True, exist_ok=True) + @override def read(self, path: Path) -> bytes: - with open(path, "rb") as fh: + with Path.open(path, "rb") as fh: return fh.read() - def touch(self, path: Path, is_dir: bool = False) -> None: + @override + def touch(self, path: Path, *, is_dir: bool = False) -> None: if is_dir: path.mkdir(parents=True, exist_ok=True) return @@ -58,10 +77,12 @@ def touch(self, path: Path, is_dir: bool = False) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.touch() + @override def stat(self, path: Path) -> StorageEntryMetadata: return StorageEntryMetadata(path.stat().st_mtime) + @override def write(self, path: Path, value: bytes) -> None: path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "wb+") as fh: + with Path.open(path, "wb+") as fh: fh.write(value) diff --git a/tierkreis/tierkreis/controller/storage/graphdata.py 
b/tierkreis/tierkreis/controller/storage/graphdata.py index 168b1484c..2bcc8c69d 100644 --- a/tierkreis/tierkreis/controller/storage/graphdata.py +++ b/tierkreis/tierkreis/controller/storage/graphdata.py @@ -1,9 +1,11 @@ +"""Virtual GraphStorage for visualization.""" + from pathlib import Path +from typing import Any, override from uuid import UUID -from typing import Any - from pydantic import BaseModel, Field + from tierkreis.controller.data.core import PortID from tierkreis.controller.data.graph import ( Eval, @@ -13,8 +15,8 @@ ) from tierkreis.controller.data.location import Loc, OutputLoc, WorkerCallArgs from tierkreis.controller.storage.protocol import ( - StorageEntryMetadata, ControllerStorage, + StorageEntryMetadata, ) from tierkreis.exceptions import TierkreisError @@ -34,6 +36,13 @@ class NodeData(BaseModel): class GraphDataStorage(ControllerStorage): + """Storage backend using in-memory GraphData for workflow execution. + + This storage implementation operates read-only on a GraphData object without + writing to disk. + Used for visualization without running the workflow. + """ + def __init__( self, workflow_id: UUID, @@ -46,36 +55,57 @@ def __init__( self.graph = graph self.tkr_dir = Path.home() / ".tierkreis" + @override def delete(self, path: Path) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def exists(self, path: Path) -> bool: - raise NotImplementedError("GraphDataStorage is only for graph construction.") + msg = "GraphDataStorage is only for graph construction." + raise NotImplementedError(msg) + @override def list_subpaths(self, path: Path) -> list[Path]: - raise NotImplementedError("GraphDataStorage uses GraphData not paths.") + msg = "GraphDataStorage uses GraphData not paths." 
+ raise NotImplementedError(msg) + @override def link(self, src: Path, dst: Path) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def mkdir(self, path: Path) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def read(self, path: Path) -> bytes: - raise NotImplementedError("GraphDataStorage uses GraphData not paths.") + msg = "GraphDataStorage uses GraphData not paths." + raise NotImplementedError(msg) + @override def stat(self, path: Path) -> StorageEntryMetadata: - raise NotImplementedError("GraphDataStorage is only for graph construction.") + msg = "GraphDataStorage is only for graph construction." + raise NotImplementedError(msg) - def touch(self, path: Path, is_dir: bool = False) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + @override + def touch(self, path: Path) -> None: + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def write(self, path: Path, value: bytes) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def write_node_def(self, node_location: Loc, node: NodeDef) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." 
+ raise NotImplementedError(msg) + @override def read_node_def(self, node_location: Loc) -> NodeDef: try: if node_location.pop_last()[0][0] in ["M", "L"]: @@ -85,6 +115,7 @@ def read_node_def(self, node_location: Loc) -> NodeDef: node, _ = graph_node_from_loc(node_location, self.graph) return node + @override def write_worker_call_args( self, node_location: Loc, @@ -92,28 +123,39 @@ def write_worker_call_args( inputs: dict[PortID, OutputLoc], output_list: list[PortID], ) -> Path: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def read_worker_call_args(self, node_location: Loc) -> WorkerCallArgs: + msg = f"Node location {node_location} doesn't have a associate call args." raise TierkreisError( - f"Node location {node_location} doesn't have a associate call args." + msg, ) + @override def read_errors(self, node_location: Loc) -> str: return "" + @override def node_has_error(self, node_location: Loc) -> bool: return False + @override def write_node_errors(self, node_location: Loc, error_logs: str) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def mark_node_finished(self, node_location: Loc) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def is_node_finished(self, node_location: Loc) -> bool: return False + @override def link_outputs( self, new_location: Loc, @@ -121,16 +163,23 @@ def link_outputs( old_location: Loc, old_port: PortID, ) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." 
+ raise NotImplementedError(msg) + @override def write_output( - self, node_location: Loc, output_name: PortID, value: bytes + self, + node_location: Loc, + output_name: PortID, + value: bytes, ) -> Path: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def read_output(self, node_location: Loc, output_name: PortID) -> bytes: node, graph = graph_node_from_loc(node_location, self.graph) - if -1 == node_location.peek_index() and output_name == "body": + if node_location.peek_index() == -1 and output_name == "body": return graph.model_dump_json().encode() outputs = _build_node_outputs(node) @@ -138,31 +187,39 @@ def read_output(self, node_location: Loc, output_name: PortID) -> bytes: if output := outputs[output_name]: return output return b"null" - raise TierkreisError(f"No output named {output_name} in node {node_location}") + msg = f"No output named {output_name} in node {node_location}" + raise TierkreisError(msg) + @override def read_output_ports(self, node_location: Loc) -> list[PortID]: node, _ = graph_node_from_loc(node_location, self.graph) outputs = _build_node_outputs(node) return list(filter(lambda k: k != "*", outputs.keys())) + @override def is_node_started(self, node_location: Loc) -> bool: return False + @override def read_metadata(self, node_location: Loc) -> dict[str, Any]: return self.nodes[node_location].metadata + @override def write_metadata(self, node_location: Loc) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." 
+ raise NotImplementedError(msg) + @override def read_started_time(self, node_location: Loc) -> str | None: return None + @override def read_finished_time(self, node_location: Loc) -> str | None: return None def _build_node_outputs(node: NodeDef) -> dict[PortID, None | bytes]: - outputs: dict[PortID, None | bytes] = {val: None for val in node.outputs} + outputs: dict[PortID, None | bytes] = dict.fromkeys(node.outputs) if "*" in outputs: outputs["0"] = None return outputs diff --git a/tierkreis/tierkreis/controller/storage/in_memory.py b/tierkreis/tierkreis/controller/storage/in_memory.py index fc5fe3e0d..ffd2cc925 100644 --- a/tierkreis/tierkreis/controller/storage/in_memory.py +++ b/tierkreis/tierkreis/controller/storage/in_memory.py @@ -1,14 +1,24 @@ +"""In memory implementation of a storage layer.""" + from pathlib import Path -from uuid import UUID from time import time +from typing import override +from uuid import UUID from tierkreis.controller.storage.protocol import ( - StorageEntryMetadata, ControllerStorage, + StorageEntryMetadata, ) class InMemoryFileData: + """Class to emulate the file system behaviour in memory. + + :fields: + value (bytes): The content of a file, typically used for outputs or empty. + stats (StorageEntryMetadata): A metadata entry. + """ + value: bytes stats: StorageEntryMetadata @@ -18,6 +28,13 @@ def __init__(self, value: bytes) -> None: class ControllerInMemoryStorage(ControllerStorage): + """In-memory implementation of ControllerStorage. + + Stores workflow files in memory using a dictionary instead of the filesystem. + Uses a mapping Path -> FileData to emulate the required filesystem structure. + Useful when debugging applications in conjunction with the InMemoryExecutor. 
+ """ + def __init__( self, workflow_id: UUID, @@ -30,38 +47,45 @@ def __init__( self.files: dict[Path, InMemoryFileData] = {} + @override def delete(self, path: Path) -> None: self.files = {} + @override def exists(self, path: Path) -> bool: return path in list(self.files.keys()) + @override def list_subpaths(self, path: Path) -> list[Path]: if path == self.workflow_dir: - nodes = set( - [ - Path("/".join(str(x).split("/")[:2])) - for x in self.files.keys() - if str(x).startswith(str(path) + "/") - ] - ) + nodes = { + Path("/".join(str(x).split("/")[:2])) + for x in self.files + if str(x).startswith(str(path) + "/") + } return list(nodes) - return [x for x in self.files.keys() if str(x).startswith(str(path) + "/")] + return [x for x in self.files if str(x).startswith(str(path) + "/")] + @override def link(self, src: Path, dst: Path) -> None: self.files[dst] = self.files[src] + @override def mkdir(self, path: Path) -> None: return + @override def read(self, path: Path) -> bytes: return self.files[path].value - def touch(self, path: Path, is_dir: bool = False) -> None: + @override + def touch(self, path: Path) -> None: self.files[path] = InMemoryFileData(b"") + @override def stat(self, path: Path) -> StorageEntryMetadata: return self.files[path].stats + @override def write(self, path: Path, value: bytes) -> None: self.files[path] = InMemoryFileData(value) diff --git a/tierkreis/tierkreis/controller/storage/protocol.py b/tierkreis/tierkreis/controller/storage/protocol.py index 54f0f9830..c06db3d64 100644 --- a/tierkreis/tierkreis/controller/storage/protocol.py +++ b/tierkreis/tierkreis/controller/storage/protocol.py @@ -1,14 +1,17 @@ -from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass -from datetime import datetime +"""The storage interface.""" + import json import logging +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from datetime import UTC, datetime from pathlib import Path from typing import Any, 
assert_never from uuid import UUID + +from tierkreis.controller.data.core import PortID from tierkreis.controller.data.graph import NodeDef, NodeDefModel from tierkreis.controller.data.location import Loc, OutputLoc, WorkerCallArgs -from tierkreis.controller.data.core import PortID from tierkreis.exceptions import TierkreisError logger = logging.getLogger(__name__) @@ -18,79 +21,162 @@ class StorageEntryMetadata: """Collection of commonly found metadata. - Storage implementations should decide which are applicable.""" + Storage implementations should decide which are applicable. + :fields: + st_mtime (float | None): The start time of a node, defaults to None. + """ st_mtime: float | None = None @dataclass class StorageDebugData: - """Collection of commonly found debugdata. + """Collection of commonly found debug data. Currently only used for loop_nodes - Storage implementations should decide which are applicable.""" + Storage implementations should decide which are applicable. + :fields: + loop_loc (str | None): The Loc of loop node, defaults to None. + Can only be known after the node has run. + """ loop_loc: str | None = None class ControllerStorage(ABC): + """Storage interface for the tierkreis controller. + + :abstract: + Conceptually, the storage layer represents the current state of the operation. + It includes information such as: + - Checkpoints + - Node definitions + - Outputs and their values + - Metadata + - Debug Data + + This interface primarily targets filesystems as underlying storage method. + For storages based on other methods, you manually need to map paths + representing node locations to an internal address. + This interface already handles translations from Locs to Paths + + :fields: + tkr_dir (Path): The base of the storage. + workflow_id (UUID): The unique workflow id. + name (str | None): Optional name for a workflow, defaults to None. + + :properties: + workflow_dir (Path): The workflow storage location. 
+ logs_path (Path): Location of the workflow logs. + debug_path (Path): Location of the workflow debug information. + """ + tkr_dir: Path workflow_id: UUID name: str | None @abstractmethod def delete(self, path: Path) -> None: - """Delete the storage entry at the specified path. + r"""Delete the storage entry at the specified path. - Also delete any related data of the form \"{path}/**/*\".""" + Also delete any related data of the form \"{path}/**/*\". + Only necessary for persistent storage types. + + :param path: The storage location to delete. + :type path: Path + """ @abstractmethod def exists(self, path: Path) -> bool: - """Is there an entry in the storage at the specified path?""" + """Check whether there is an entry in the storage at the specified path. + + :param path: The storage location to check. + :type path: Path + :return: True if it exists. + :rtype: bool + """ @abstractmethod def link(self, src: Path, dst: Path) -> None: - """The storage entry at `dst` should have the same value as the entry at `src`.""" + """Link storage entry at `dst` to have the same value as the entry at `src`. + + :param src: The source entry. + :type src: Path + :param dst: The destination entry. + :type dst: Path + """ @abstractmethod def list_subpaths(self, path: Path) -> list[Path]: - """List all the paths starting with the specified path. + """List all the paths starting with the specified path in the storage. This is used when the number of entries can only be determined at runtime. - For example in a map node.""" + For example in a map node. + + :param path: The location to list children in. + :type path: Path + :return: A list of child entries. + :rtype: list[Path] + """ @abstractmethod def mkdir(self, path: Path) -> None: - """Create an empty directory (and parents) at this path. + """Create an empty directory (and parents) at this storage location. + + Probably only required for file-based storage. 
- Probably only required for file-based storage.""" + :param path: The location where to create the directory. + :type path: Path + """ @abstractmethod def read(self, path: Path) -> bytes: - """Read the storage entry at the specified path.""" + """Read the storage entry at the specified path. + + :param path: The location to read from. + :type path: Path + :return: The value at the storage location. + :rtype: bytes + """ @abstractmethod def stat(self, path: Path) -> StorageEntryMetadata: - """Get applicable stats for storage entry.""" + """Get applicable stats for storage entry. + + :param path: The location to get stats for. + :type path: Path + :return: The stats of this entry. + :rtype: StorageEntryMetadata + """ @abstractmethod def touch(self, path: Path) -> None: - """Create empty storage entry at the specified path.""" + """Create empty storage entry at the specified location. + + :param path: The location where to generate the entry. + :type path: Path + """ @abstractmethod def write(self, path: Path, value: bytes) -> None: - """Write the given bytes to the storage entry at the specified path.""" + """Write the given bytes to the storage entry at the specified path. + + :param path: The location to write to. + :type path: Path + :param value: The value to write. 
+ :type value: bytes + """ @property - def workflow_dir(self) -> Path: + def workflow_dir(self) -> Path: # noqa: D102 documented in class return self.tkr_dir / str(self.workflow_id) @property - def logs_path(self) -> Path: + def logs_path(self) -> Path: # noqa: D102 documented in class return self.workflow_dir / "logs" @property - def debug_path(self) -> Path: + def debug_path(self) -> Path: # noqa: D102 documented in class return self.workflow_dir / "debug" def _nodedef_path(self, node_location: Loc) -> Path: @@ -118,13 +204,28 @@ def _worker_logs_path(self, node_location: Loc) -> Path: return self.workflow_dir / str(node_location) / "logs" def clean_graph_files(self) -> None: + """Delete the workflow directory of a graph.""" self.delete(self.workflow_dir) - def write_node_def(self, node_location: Loc, node: NodeDef): + def write_node_def(self, node_location: Loc, node: NodeDef) -> None: + """Write a node definition to storage. + + :param node_location: The location to write to. + :type node_location: Loc + :param node: The node definition to write. + :type node: NodeDef + """ bs = NodeDefModel(root=node).model_dump_json().encode() self.write(self._nodedef_path(node_location), bs) def read_node_def(self, node_location: Loc) -> NodeDef: + """Read the definition of a node in storage. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The retrieved node definition. + :rtype: NodeDef + """ bs = self.read(self._nodedef_path(node_location)) return NodeDefModel(**json.loads(bs)).root @@ -135,6 +236,19 @@ def write_worker_call_args( inputs: dict[PortID, OutputLoc], output_list: list[PortID], ) -> Path: + """Write the call arguments for a worker task to storage. + + :param node_location: The location to write to. + :type node_location: Loc + :param function_name: The task name. + :type function_name: str + :param inputs: The inputs to the task (outputs from previous nodes). 
+        :type inputs: dict[PortID, OutputLoc]
+        :param output_list: The list of output ports of the task.
+        :type output_list: list[PortID]
+        :return: The respective path in storage.
+        :rtype: Path
+        """
         call_args_path = self._worker_call_args_path(node_location)
         node_definition = WorkerCallArgs(
             function_name=function_name,
@@ -160,6 +274,13 @@ def write_worker_call_args(
         return call_args_path.relative_to(self.tkr_dir)
 
     def read_worker_call_args(self, node_location: Loc) -> WorkerCallArgs:
+        """Read the worker call arguments from storage.
+
+        :param node_location: The location to read from.
+        :type node_location: Loc
+        :return: The call arguments.
+        :rtype: WorkerCallArgs
+        """
         node_definition_path = self._worker_call_args_path(node_location)
         return WorkerCallArgs(**json.loads(self.read(node_definition_path)))
 
@@ -170,30 +291,79 @@ def link_outputs(
         old_location: Loc,
         old_port: PortID,
     ) -> None:
+        """Link an output from one node to another.
+
+        Linking ensures that the values at both locations are the same.
+
+        :param new_location: The new location to write to.
+        :type new_location: Loc
+        :param new_port: The port to link to.
+        :type new_port: PortID
+        :param old_location: The old location to read from.
+        :type old_location: Loc
+        :param old_port: The old port to link from.
+        :type old_port: PortID
+        :raises TierkreisError: If linking is not possible.
+        """
         new_dir = self._output_path(new_location, new_port)
         try:
             self.link(self._output_path(old_location, old_port), new_dir)
         except FileNotFoundError as e:
             logger.warning(
-                f"Could not link {e.filename} to {e.filename2}."
-                " Possibly a mislabelled variable?"
+                "Could not link %s to %s. Possibly a mislabelled variable?",
+                e.filename,
+                e.filename2,
             )
         except OSError as e:
+            msg = ("Workflow already exists."
+                   " Try running with a different ID or do_cleanup.")
             raise TierkreisError(
-                "Workflow already exists. Try running with a different ID or do_cleanup."
+ msg, ) from e def write_output( - self, node_location: Loc, output_name: PortID, value: bytes + self, + node_location: Loc, + output_name: PortID, + value: bytes, ) -> Path: + """Write the value of an output to storage. + + :param node_location: The location to write to. + :type node_location: Loc + :param output_name: The port for which to write the value. + :type output_name: PortID + :param value: The value to write. + :type value: bytes + :return: The respective path in storage. + :rtype: Path + """ output_path = self._output_path(node_location, output_name) self.write(output_path, bytes(value)) return output_path def read_output(self, node_location: Loc, output_name: PortID) -> bytes: + """Read the named output at the node location. + + :param node_location: The location to read from. + :type node_location: Loc + :param output_name: The port identifying the output. + :type output_name: PortID + :return: The value at the port. + :rtype: bytes + """ return self.read(self._output_path(node_location, output_name)) def read_errors(self, node_location: Loc) -> str: + """Read the errors that occurred at the node location. + + Only valid for function nodes (tasks) and the top level node ("-"). + + :param node_location: The location to read from. + :type node_location: Loc + :return: The error message that occurred. "" if nothing was logged. + :rtype: str + """ if self.exists(self._worker_logs_path(node_location)): return self.read(self._worker_logs_path(node_location)).decode() if self.exists(self._error_path(node_location)): @@ -201,63 +371,163 @@ def read_errors(self, node_location: Loc) -> str: return "" def write_node_errors(self, node_location: Loc, error_logs: str) -> None: + """Write the errors of a node to storage. + + Only valid for function nodes (tasks) and the top level node ("-"). + + :param node_location: The location to write to. + :type node_location: Loc + :param error_logs: The error message to write. 
+ :type error_logs: str + """ self.write(self._worker_logs_path(node_location), error_logs.encode()) def read_output_ports(self, node_location: Loc) -> list[PortID]: + """Read the list of named outputs of a node in storage. + + :param node_location: The location to read from. + :type node_location: Loc + :return: A list of output names. + :rtype: list[PortID] + """ dir_list = self.list_subpaths(self._outputs_dir(node_location)) dir_list.sort() return [x.name for x in dir_list] def is_node_started(self, node_location: Loc) -> bool: + """Check whether a node is started. + + A node is started <==> the controller has written its definition. + + :param node_location: The location to check. + :type node_location: Loc + :return: True if the node is started. + :rtype: bool + """ return self.exists(Path(self._nodedef_path(node_location))) def is_node_finished(self, node_location: Loc) -> bool: + """Check whether a node is finished. + + A node is finished <==> A _done file/marker is set. + + :param node_location: The location to check. + :type node_location: Loc + :return: True if the node is finished. + :rtype: bool + """ return self.exists(self._done_path(node_location)) def latest_loop_iteration(self, loc: Loc) -> Loc: + """Find the latest iteration location of a loop node. + + :param loc: The location to check. + :type loc: Loc + :return: A location representing the latest iteration. + :rtype: Loc + """ i = 0 while self.is_node_started(loc.L(i + 1)): i += 1 return loc.L(i) def node_has_error(self, node_location: Loc) -> bool: + """Check whether a node has encountered an error. + + Only valid for function nodes (tasks) and the top level node ("-"). + A node is errored <==> A _error file/marker is set. + + :param node_location: The location to check. + :type node_location: Loc + :return: True if the node has an error. 
+ :rtype: bool + """ return self.exists(self._error_path(node_location)) def mark_node_finished(self, node_location: Loc) -> None: + """Mark a node as successfully finished. + + :param node_location: The location to mark. + :type node_location: Loc + """ self.touch(self._done_path(node_location)) if (parent := node_location.parent()) is not None: self.touch(self._metadata_path(parent)) def write_metadata(self, node_location: Loc) -> None: - j = json.dumps({"name": self.name, "start_time": datetime.now().isoformat()}) - self.write(self._metadata_path(node_location), j.encode()) + """Write the metadata for a node. + + Currently metadata contains name and start time. + + :param node_location: The location to write to. + :type node_location: Loc + """ + json_string = json.dumps( + {"name": self.name, "start_time": datetime.now(UTC).isoformat()}, + ) + self.write(self._metadata_path(node_location), json_string.encode()) def read_metadata(self, node_location: Loc) -> dict[str, Any]: + """Read the metadata of a node. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The metadata stored at the location. + :rtype: dict[str, Any] + """ return json.loads(self.read(self._metadata_path(node_location))) def read_started_time(self, node_location: Loc) -> str | None: + """Read the start time of a node. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The time as is isoformat if the node has started. + :rtype: str | None + """ node_def = Path(self._nodedef_path(node_location)) if not self.exists(node_def): return None since_epoch = self.stat(node_def).st_mtime if since_epoch is None: return None - return datetime.fromtimestamp(since_epoch).isoformat() + return datetime.fromtimestamp(since_epoch, UTC).isoformat() def read_finished_time(self, node_location: Loc) -> str | None: + """Read the finish time of a node. + + :param node_location: The location to read from. 
+ :type node_location: Loc + :return: The time as is isoformat if the node has finished. + :rtype: str | None + """ done = Path(self._done_path(node_location)) if not self.exists(done): return None since_epoch = self.stat(done).st_mtime if since_epoch is None: return None - return datetime.fromtimestamp(since_epoch).isoformat() + return datetime.fromtimestamp(since_epoch, UTC).isoformat() def read_loop_trace(self, node_location: Loc, output_name: PortID) -> list[bytes]: + """Read the trace of a loop node for a given output. + + A trace is a list of values v, where v[i] represents the value of the output at + iteration i. + + :param node_location: The location to read from. + :type node_location: Loc + :param output_name: The output to trace. + :type output_name: PortID + :raises TierkreisError: If the node at the location is not a loop node. + :return: A list of values considered the trace. + :rtype: list[bytes] + """ definition = self.read_node_def(node_location) if definition.type != "loop": - raise TierkreisError("Can only read traces from loop nodes.") + msg = "Can only read traces from loop nodes." + raise TierkreisError(msg) result = [] i = 0 @@ -267,57 +537,104 @@ def read_loop_trace(self, node_location: Loc, output_name: PortID) -> list[bytes return result def loc_from_node_name(self, node_name: str) -> Loc | None: + """Find the location of a node for a given name. + + Currently only loop nodes can be named. + Loop names are stored as debug data. + This can only be invoked after running the workflow. + + :param node_name: The name to search for. + :type node_name: str + :return: Returns the location if found in storage. + :rtype: Loc | None + """ debug_data = StorageDebugData(**self.read_debug_data(node_name)) if debug_data.loop_loc is not None: return Loc(debug_data.loop_loc) + return None + + def write_debug_data(self, name: str, node_location: Loc) -> None: + """Write the debug data of a node. 
-    def write_debug_data(self, name: str, loc: Loc) -> None:
+        Currently name is derived from a named loop.
+
+        :param name: The name to write to.
+        :type name: str
+        :param node_location: The location to write as debug information.
+        :type node_location: Loc
+        """
         self.mkdir(self.debug_path)
-        data = StorageDebugData(loop_loc=loc)
+        data = StorageDebugData(loop_loc=node_location)
         self.write(self.debug_path / name, json.dumps(asdict(data)).encode())
 
     def read_debug_data(self, name: str) -> dict[str, Any]:
+        """Read the debug data for a given name.
+
+        Currently name is derived from a named loop.
+
+        :param name: The name to read from.
+        :type name: str
+        :return: The data available for the name.
+        :rtype: dict[str, Any]
+        """
         return json.loads(self.read(self.debug_path / name))
 
-    def dependents(self, loc: Loc) -> set[Loc]:
-        """Nodes that are fully invalidated if the node at the given loc is invalidated.
+    def dependents(self, node_location: Loc) -> set[Loc]:
+        """Get the dependents (successors) of a node.
 
-        This does not include the direct parent Loc, which is only partially invalidated.
+        Dependents are nodes that are fully invalidated if
+        the node at the given loc is invalidated.
+        This does not include the direct parent Loc,
+        which is only partially invalidated.
+
+        :param node_location: The location to get the dependents for.
+        :type node_location: Loc
+        :return: A set of dependent nodes.
+        :rtype: set[Loc]
         """
-        descs: set[Loc] = set()
-        step, parent = loc.pop_last()
+        descendants: set[Loc] = set()
+        step, parent = node_location.pop_last()
         match step:
             case "-":
                 pass
             case ("N", _):
-                nodedef = self.read_node_def(loc)
+                nodedef = self.read_node_def(node_location)
                 if nodedef.type == "output":
-                    descs.update(self.dependents(parent))
+                    descendants.update(self.dependents(parent))
                 for output in nodedef.outputs.values():
-                    descs.add(parent.N(output))
-                    descs.update(self.dependents(parent.N(output)))
+                    descendants.add(parent.N(output))
+                    descendants.update(self.dependents(parent.N(output)))
             case ("M", _):
-                descs.update(self.dependents(parent))
+                descendants.update(self.dependents(parent))
             case ("L", idx):
                 latest_idx = self.latest_loop_iteration(parent).peek_index()
-                [descs.add(parent.L(i)) for i in range(idx + 1, latest_idx + 1)]
-                descs.update(self.dependents(parent))
+                [descendants.add(parent.L(i)) for i in range(idx + 1, latest_idx + 1)]
+                descendants.update(self.dependents(parent))
             case _:
                 assert_never(step)
-        return descs
+        return descendants
 
     def restart_task(self, loc: Loc) -> list[Loc]:
         """Restart the task/function node at the given loc.
 
-        Fully dependent nodes will be removed from the storage.
+        A node is restarted by removing its (and its dependents) started flag.
+        The controller will then pick it up as not started.
+        Fully dependent nodes will be completely removed from the storage.
         The parent locs will be partially invalidated.
-        Returns the invalidated nodes."""
+        Returns the invalidated nodes.
 
+        :param loc: The location of the task node to restart.
+        :type loc: Loc
+        :raises TierkreisError: If the node at the location is not a task.
+        :return: The list of invalidated node locations.
+        :rtype: list[Loc]
+        """
         nodedef = self.read_node_def(loc)
         if nodedef.type != "function":
-            raise TierkreisError("Can only restart task/function nodes.")
+            msg = "Can only restart task/function nodes."
+            raise TierkreisError(msg)
 
         # Remove fully invalidated nodes.
deps = self.dependents(loc) @@ -328,7 +645,8 @@ def restart_task(self, loc: Loc) -> list[Loc]: [self.delete(self._done_path(x)) for x in partials] [self.delete(self.workflow_dir / a / "outputs") for a in partials] - # Mark given Loc as not started, so that the controller picks it up on the next tick. + # Mark given Loc as not started + # so that the controller picks it up on the next tick. self.delete(self._nodedef_path(loc)) return list(deps) diff --git a/tierkreis/tierkreis/controller/storage/walk.py b/tierkreis/tierkreis/controller/storage/walk.py index b8b244868..b79684e1e 100644 --- a/tierkreis/tierkreis/controller/storage/walk.py +++ b/tierkreis/tierkreis/controller/storage/walk.py @@ -1,3 +1,11 @@ +"""Functions to walk a (computational) workflow graph. + +In conjunction with `start()` this is one of the primary functions of the +tierkreis controller. +By walking the graph we update nodes with new inputs from finished nodes until +they can be started. +""" + from dataclasses import dataclass, field from logging import getLogger from typing import assert_never @@ -24,11 +32,27 @@ @dataclass class WalkResult: + """Dataclass to keep track of the nodes we encounter during the walk. + + :fields: + inputs_ready (list[NodeRunData]): A list of nodes that now have all inputs ready + and therefore can be started. + started (list[Loc]): A list of locations that have been started (on this walk). + errored (list[Loc]): A list of locations that have encountered an error. + """ + inputs_ready: list[NodeRunData] started: list[Loc] errored: list[Loc] = field(default_factory=list[Loc]) def extend(self, walk_result: "WalkResult") -> None: + """Extend a current walk result with an existing one. + + Simply extends all three list fields accordingly. + + :param walk_result: The walk_result to update self with. 
+ :type walk_result: WalkResult + """ self.inputs_ready.extend(walk_result.inputs_ready) self.started.extend(walk_result.started) self.errored.extend(walk_result.errored) @@ -41,19 +65,66 @@ def unfinished_results( node: NodeDef, graph: GraphData, ) -> int: + """Find and walk all the unfinished results. + + Finds all "blocking" nodes of the current nodes and marks them to be started. + Blocking nodes are inputs that are not done yet. + We walk recursively into the nodes that are not finished yet to progress them, + by marking them ready for starting or done. + + :param result: The walk result, where we add all unfinished nodes. + Used to bubble up the nodes from the recursive walk call. + :type result: WalkResult + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + e.g, the eval node containing the current node. + :type parent: Loc + :param node: The current node for which we check inputs. + :type node: NodeDef + :param graph: The graph to walk. + :type graph: GraphData + :return: The number of nodes that have unfinished inputs. + :rtype: int + """ unfinished = unfinished_inputs(storage, parent, node) [result.extend(walk_node(storage, parent, x[0], graph)) for x in unfinished] return len(unfinished) def walk_node( - storage: ControllerStorage, parent: Loc, idx: NodeIndex, graph: GraphData + storage: ControllerStorage, + parent: Loc, + idx: NodeIndex, + graph: GraphData, ) -> WalkResult: - """Should only be called when a node has not finished.""" + """Walk a graph node. + + Should only be called when a node has not finished. + This is one of the core functions of the controller. + It checks for a current node how to proceed: + - Continue if its already done + - Mark it for starting if its inputs are now ready and its not started. + - Mark the respective next nodes to start, e.g. in case of an eval start + the first nodes inside (that now have their inputs ready). 
+ + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + e.g, the eval node containing the current node. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param graph: The graph to walk. + :type graph: GraphData + :return: A summary list of finished, errored, and ready nodes. + :rtype: WalkResult + """ loc = parent.N(idx) if storage.node_has_error(loc): - logger.error(f"Node {loc} has encountered an error.") - logger.debug(f"\n\n{storage.read_errors(loc)}\n\n") + # we immediately stop if a node has an error and bubble the error up + logger.error("Node %s has encountered an error.", loc) + logger.debug("\n\n%s\n\n", storage.read_errors(loc)) return WalkResult([], [], [loc]) node = graph.nodes[idx] @@ -61,13 +132,20 @@ def walk_node( result = WalkResult([], []) if unfinished_results(result, storage, parent, node, graph): + # cannot start, don't have all inputs yet return result if not storage.is_node_started(loc): + # have all inputs, start current node return WalkResult([node_run_data], []) + # Handle cases where we have nested graphs. + # Basically we have to forward the now available outputs from outer scope + # into the scope of the nested graph, so we check inside the nested + # for new candidates. 
match node.type: case "eval": + # step inside the nested graph for walking message = storage.read_output(parent.N(node.graph[0]), node.graph[1]) g = ptype_from_bytes(message, GraphData) return walk_node(storage, loc, g.output_idx(), g) @@ -85,6 +163,7 @@ def walk_node( return walk_map(storage, parent, idx, node) case "ifelse": + # walk the next node only after we have the value of "pred" pred = storage.read_output(parent.N(node.pred[0]), node.pred[1]) next_node = node.if_true if pred == b"true" else node.if_false next_loc = parent.N(next_node[0]) @@ -92,13 +171,13 @@ def walk_node( storage.link_outputs(loc, Labels.VALUE, next_loc, next_node[1]) storage.mark_node_finished(loc) return WalkResult([], []) - else: - return walk_node(storage, parent, next_node[0], graph) + return walk_node(storage, parent, next_node[0], graph) case "eifelse": return walk_eifelse(storage, parent, idx, node) case "function": + # Current node can start, done will be marked by executor. return WalkResult([], [loc]) case "input": @@ -108,33 +187,66 @@ def walk_node( def walk_loop( - storage: ControllerStorage, parent: Loc, idx: NodeIndex, loop: Loop + storage: ControllerStorage, + parent: Loc, + idx: NodeIndex, + loop: Loop, ) -> WalkResult: + """Walk a loop node. + + The controller walks a loop by: + - checking the current iteration + - checking the `should_continue` port + - mapping outputs to inputs between iterations + - and finally producing the outputs. + + Each iteration is evaluated by inserting a virtual eval node + at the location PARENT_LOC.L.N-1 that then gets picked up by walk + and the next start step. + + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + E.g. the "eval" containing this statement. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param loop: The loop node being walked. 
+ :type loop: Loop + :return: A summary list of finished, errored, and ready nodes: + Either empty (loop is done), or next loop iteration. + :rtype: WalkResult + """ loc = parent.N(idx) if storage.is_node_finished(loc): - return WalkResult([], [], []) + return WalkResult([], []) + # find the last iteration new_location = storage.latest_loop_iteration(loc) - + # and read the graph definition message = storage.read_output(loc.N(-1), BODY_PORT) g = ptype_from_bytes(message, GraphData) loop_outputs = g.nodes[g.output_idx()].inputs - + # if the iteration is not done finished, walk its nested graph (body) if not storage.is_node_finished(new_location): return walk_node(storage, new_location, g.output_idx(), g) # Latest iteration is finished. Do we BREAK or CONTINUE? should_continue = ptype_from_bytes( - storage.read_output(new_location, loop.continue_port), bool + storage.read_output(new_location, loop.continue_port), + bool, ) if should_continue is False: + # were done here, set outputs to parent scope for k in loop_outputs: storage.link_outputs(loc, k, new_location, k) storage.mark_node_finished(loc) return WalkResult([], []) - ins = {k: (-1, k) for k in loop.inputs.keys()} + # continue looping, provide the inputs for the next iter from the current + ins = {k: (-1, k) for k in loop.inputs} ins.update(loop_outputs) + # Mark the next node as ready node_run_data = NodeRunData( loc.L(new_location.peek_index() + 1), Eval((-1, BODY_PORT), ins, loop.outputs), @@ -144,26 +256,56 @@ def walk_loop( def walk_map( - storage: ControllerStorage, parent: Loc, idx: NodeIndex, map: Map + storage: ControllerStorage, + parent: Loc, + idx: NodeIndex, + map_node: Map, ) -> WalkResult: + """Walk a map node. + + We refer to the evaluation of the graph with a set of inputs as map elements. + I. e. one map consist map fun [a,b,c] will have the element fun a, fun b, fun c. + Each of these elements are treated as an virtual eval + at the location PARENT_LOC.M.N-1 that then gets picked up by walk. 
+ In contrast to loop, as all map elements can be immediately started this is set + in the start function, which makes the virtual evals optional; + They are treated differently than evals/loops. + + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + E.g. the "eval" containing this statement. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param map_node: The map node being walked. + :type map_node: Map + :return: A summary list of finished, errored, and ready nodes: + Either empty (map is done), or all intermediate nodes in the map elements. + :rtype: WalkResult + """ loc = parent.N(idx) result = WalkResult([], []) if storage.is_node_finished(loc): return result - first_ref = next(x for x in map.inputs.values() if x[1] == "*") + # find all values to map over + first_ref = next(x for x in map_node.inputs.values() if x[1] == "*") map_eles = outputs_iter(storage, parent.N(first_ref[0])) + # find all map elements that are not done unfinished = [i for i, _ in map_eles if not storage.is_node_finished(loc.M(i))] + # Read the graph def message = storage.read_output(loc.M(0).N(-1), BODY_PORT) g = ptype_from_bytes(message, GraphData) + # Walk all map elements simultaneously [result.extend(walk_node(storage, loc.M(p), g.output_idx(), g)) for p in unfinished] if len(unfinished) > 0: return result - + # All map elements are done, mark the entire map done map_outputs = g.nodes[g.output_idx()].inputs for i, j in map_eles: - for output in map_outputs.keys(): + for output in map_outputs: storage.link_outputs(loc, f"{output}-{j}", loc.M(i), output) storage.mark_node_finished(loc) @@ -176,6 +318,24 @@ def walk_eifelse( idx: NodeIndex, node: EagerIfElse, ) -> WalkResult: + """Walk an eager if else node. + + In an eager if else node we have already evaluated all its + inputs (pred, if_true, if_false). 
+ Therefore we just need to move the correct inputs to its outputs. + + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + E.g. the "eval" containing this statement. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param node: The eager if else node being walked. + :type node: EagerIfElse + :return: An empty walk result since we here nothing else is to do. + :rtype: WalkResult + """ loc = parent.N(idx) pred = storage.read_output(parent.N(node.pred[0]), node.pred[1]) next_node = node.if_true if pred == b"true" else node.if_false diff --git a/tierkreis/tierkreis/exceptions.py b/tierkreis/tierkreis/exceptions.py index 2e1c8357d..f98859773 100644 --- a/tierkreis/tierkreis/exceptions.py +++ b/tierkreis/tierkreis/exceptions.py @@ -1,2 +1,5 @@ +"""Tierkreis exception definitions.""" + + class TierkreisError(Exception): """An error thrown in the Tierkreis library.""" diff --git a/tierkreis/tierkreis/executor.py b/tierkreis/tierkreis/executor.py index 38f6ab6e7..7504cc239 100644 --- a/tierkreis/tierkreis/executor.py +++ b/tierkreis/tierkreis/executor.py @@ -1,7 +1,8 @@ +"""Tierkreis executors definitions.""" + +from tierkreis.controller.executor.hpc.pjsub import PJSUBExecutor +from tierkreis.controller.executor.multiple import MultipleExecutor from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.executor.uv_executor import UvExecutor -from tierkreis.controller.executor.multiple import MultipleExecutor -from tierkreis.controller.executor.hpc.pjsub import PJSUBExecutor - -__all__ = ["ShellExecutor", "UvExecutor", "MultipleExecutor", "PJSUBExecutor"] +__all__ = ["MultipleExecutor", "PJSUBExecutor", "ShellExecutor", "UvExecutor"] diff --git a/tierkreis/tierkreis/graphs/__init__.py b/tierkreis/tierkreis/graphs/__init__.py new file mode 100644 index 000000000..0b23ab3ca --- /dev/null +++ 
b/tierkreis/tierkreis/graphs/__init__.py @@ -0,0 +1 @@ +"""Preconstructed graphs for reuse.""" diff --git a/tierkreis/tierkreis/graphs/fold.py b/tierkreis/tierkreis/graphs/fold.py index 274cabb98..fd33bc4f2 100644 --- a/tierkreis/tierkreis/graphs/fold.py +++ b/tierkreis/tierkreis/graphs/fold.py @@ -1,4 +1,7 @@ -from typing import Generic, NamedTuple, TypeVar +"""Preconstructed graph for folding operations.""" + +from typing import NamedTuple, TypeVar + from tierkreis.builder import GraphBuilder, TypedGraphRef from tierkreis.builtins.stubs import head, igt, tkr_len from tierkreis.controller.data.graph import GraphData @@ -6,25 +9,28 @@ from tierkreis.controller.data.types import PType -class FoldGraphOuterInputs[A: PType, B: PType](NamedTuple): +class _FoldGraphOuterInputs[A: PType, B: PType](NamedTuple): func: TKR[GraphData] accum: TKR[B] values: TKR[list[A]] -class FoldGraphOuterOutputs[A: PType, B: PType](NamedTuple): +class _FoldGraphOuterOutputs[A: PType, B: PType](NamedTuple): accum: TKR[B] values: TKR[list[A]] should_continue: TKR[bool] -class InnerFuncInput[A: PType, B: PType](NamedTuple): +class _InnerFuncInput[A: PType, B: PType](NamedTuple): accum: TKR[B] value: TKR[A] -def _fold_graph_outer[A: PType, B: PType](): - g = GraphBuilder(FoldGraphOuterInputs[A, B], FoldGraphOuterOutputs[A, B]) +def _fold_graph_outer[A: PType, B: PType]() -> GraphBuilder[ + _FoldGraphOuterInputs[A, B], + _FoldGraphOuterOutputs[A, B], +]: + g = GraphBuilder(_FoldGraphOuterInputs[A, B], _FoldGraphOuterOutputs[A, B]) func = g.inputs.func accum = g.inputs.accum @@ -38,41 +44,67 @@ def _fold_graph_outer[A: PType, B: PType](): headed = g.task(head(values)) # Apply the function if we were able to pop off a value. 
- tgd = TypedGraphRef[InnerFuncInput, TKR[B]]( - func.value_ref(), TKR[B], InnerFuncInput + tgd = TypedGraphRef[_InnerFuncInput, TKR[B]]( + func.value_ref(), + TKR[B], + _InnerFuncInput, ) - applied_next = g.eval(tgd, InnerFuncInput(accum, headed.head)) + applied_next = g.eval(tgd, _InnerFuncInput(accum, headed.head)) next_accum = g.ifelse(non_empty, applied_next, accum) next_values = g.ifelse(non_empty, headed.rest, values) - g.outputs(FoldGraphOuterOutputs(next_accum, next_values, non_empty)) + g.outputs(_FoldGraphOuterOutputs(next_accum, next_values, non_empty)) return g -A = TypeVar("A", bound=PType, covariant=True) -B = TypeVar("B", bound=PType, covariant=True) +A_co = TypeVar("A_co", bound=PType, covariant=True) +B_co = TypeVar("B_co", bound=PType, covariant=True) + +class FoldGraphInputs[A: PType, B: PType](NamedTuple): + """Inputs to a fold graph. + + :fields: + initial (B): The initial value. + values (list[A]): The list of values to fold over. + """ -class FoldGraphInputs(NamedTuple, Generic[A, B]): initial: TKR[B] values: TKR[list[A]] -class FoldFunctionInput(NamedTuple, Generic[A, B]): +class FoldFunctionInput[A: PType, B: PType](NamedTuple): + """Input type of a fold function. + + :fields: + accum (B): The accumulator. + value (A): The current value. + """ + accum: TKR[B] value: TKR[A] -# fold : {func: (b -> a -> b)} -> {initial: b} -> {values: list[a]} -> {value: b} -# fold : { A x B -> B } -> { list[A] x B -> B } def fold_graph( - func: GraphBuilder[FoldFunctionInput[A, B], TKR[B]], -) -> GraphBuilder[FoldGraphInputs[A, B], TKR[B]]: - g = GraphBuilder(FoldGraphInputs[A, B], TKR[B]) - foldfunc = g._graph_const(func) - # TODO: include the computation inside the fold - ins = FoldGraphOuterInputs( - TKR(*foldfunc.graph_ref), g.inputs.initial, g.inputs.values + func: GraphBuilder[FoldFunctionInput[A_co, B_co], TKR[B_co]], +) -> GraphBuilder[FoldGraphInputs[A_co, B_co], TKR[B_co]]: + """Construct a fold graph. 
+ + fold : {func: (b -> a -> b)} -> {initial: b} -> {values: list[a]} -> {value: b} + fold : { A x B -> B } -> { list[A] x B -> B } + + :param func: The function to fold over. + :type func: GraphBuilder[FoldFunctionInput[A_co, B_co], TKR[B_co]] + :return: A graph implementing the fold function. + :rtype: GraphBuilder[FoldGraphInputs[A_co, B_co], TKR[B_co]] + """ + g = GraphBuilder(FoldGraphInputs[A_co, B_co], TKR[B_co]) + foldfunc = g._graph_const(func) # noqa: SLF001 + # TODO @mwpb: include the computation inside the fold + ins = _FoldGraphOuterInputs( + TKR(*foldfunc.graph_ref), + g.inputs.initial, + g.inputs.values, ) loop = g.loop(_fold_graph_outer(), ins) g.outputs(loop.accum) diff --git a/tierkreis/tierkreis/graphs/nexus/__init__.py b/tierkreis/tierkreis/graphs/nexus/__init__.py new file mode 100644 index 000000000..0a94d4f45 --- /dev/null +++ b/tierkreis/tierkreis/graphs/nexus/__init__.py @@ -0,0 +1 @@ +"""Nexus graphs.""" diff --git a/tierkreis/tierkreis/graphs/nexus/submit_poll.py b/tierkreis/tierkreis/graphs/nexus/submit_poll.py index 622b228b6..0d779bed0 100644 --- a/tierkreis/tierkreis/graphs/nexus/submit_poll.py +++ b/tierkreis/tierkreis/graphs/nexus/submit_poll.py @@ -1,27 +1,46 @@ +"""Sample graphs to interact with nexus using the Nexus Worker.""" + # ruff: noqa: F821 from typing import NamedTuple + from tierkreis.builder import GraphBuilder from tierkreis.builtins.stubs import tkr_sleep from tierkreis.controller.data.models import TKR, OpaqueType from tierkreis.nexus_worker import ( - upload_circuit, - start_execute_job, - is_running, get_results, + is_running, + start_execute_job, + upload_circuit, ) -type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] +type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] # noqa: SLF001 type BackendResult = OpaqueType["pytket.backends.backendresult.BackendResult"] type ExecuteJobRef = OpaqueType["qnexus.models.references.ExecuteJobRef"] -type ExecutionProgram = 
OpaqueType["qnexus.models.references.ExecuteJobRef"] class UploadCircuitInputs(NamedTuple): + """The inputs to upload a circuit. + + :fields: + project_name (str): The name of the project to upload to. + circuit (Circuit): The tket circuit to upload. + """ + project_name: TKR[str] circuit: TKR[Circuit] class JobInputs(NamedTuple): + """The inputs to a nexus job. + + :fields: + project_name (str): The name of the project to upload to. + job_name (str): The name of the job. + circuit (list[Circuit]): The list of circuits part of the job. + n_shots (int): The number of shots (repetitions) of each circuit. + backend_config (BackendConfig): The qnexus configuration of the backend. + """ + project_name: TKR[str] job_name: TKR[str] circuits: TKR[list[Circuit]] @@ -29,33 +48,54 @@ class JobInputs(NamedTuple): backend_config: TKR[OpaqueType["qnexus.BackendConfig"]] -class LoopOutputs(NamedTuple): +class _LoopOutputs(NamedTuple): results: TKR[list[BackendResult]] should_continue: TKR[bool] -def upload_circuit_graph(): - g = GraphBuilder(UploadCircuitInputs, TKR[ExecutionProgram]) +def upload_circuit_graph() -> GraphBuilder[UploadCircuitInputs, TKR[ExecuteJobRef]]: + """Construct a graph to upload a circuit to nexus. + + :return: A uploading graph. 
+ :rtype: GraphBuilder[UploadCircuitInputs, TKR[ExecuteJobRef]] + """ + g = GraphBuilder(UploadCircuitInputs, TKR[ExecuteJobRef]) programme = g.task(upload_circuit(g.inputs.project_name, g.inputs.circuit)) - g.outputs(programme) # type: ignore + g.outputs(programme) # type: ignore[arg-type] return g -def polling_loop_body(polling_interval: float): - g = GraphBuilder(TKR[ExecuteJobRef], LoopOutputs) +def _polling_loop_body( + polling_interval: float, +) -> GraphBuilder[TKR[ExecuteJobRef], _LoopOutputs]: + g = GraphBuilder(TKR[ExecuteJobRef], _LoopOutputs) pred = g.task(is_running(g.inputs)) - wait = g.ifelse(pred, g.task(tkr_sleep(g.const(polling_interval))), g.const(False)) + wait = g.ifelse( + pred, + g.task(tkr_sleep(g.const(polling_interval))), + g.const(value=False), + ) results = g.ifelse(pred, g.const([]), g.task(get_results(g.inputs))) - g.outputs(LoopOutputs(results=results, should_continue=wait)) + g.outputs(_LoopOutputs(results=results, should_continue=wait)) return g -def nexus_submit_and_poll(polling_interval: float = 30.0): +def nexus_submit_and_poll( + polling_interval: float = 30.0, +) -> GraphBuilder[JobInputs, TKR[list[BackendResult]]]: + """Construct a graph submitting and polling a nexus job. + + :param polling_interval: The polling interval in seconds, defaults to 30.0 + :type polling_interval: float, optional + :return: A graph performing submission and polling. 
+ :rtype: GraphBuilder[JobInputs, TKR[list[BackendResult]]] + """ g = GraphBuilder(JobInputs, TKR[list[BackendResult]]) upload_inputs = g.map( - lambda x: UploadCircuitInputs(g.inputs.project_name, x), g.inputs.circuits + lambda x: UploadCircuitInputs(g.inputs.project_name, x), + g.inputs.circuits, ) programmes = g.map(upload_circuit_graph(), upload_inputs) @@ -63,12 +103,12 @@ def nexus_submit_and_poll(polling_interval: float = 30.0): start_execute_job( g.inputs.project_name, g.inputs.job_name, - programmes, # type: ignore + programmes, # type: ignore[arg-type] g.inputs.n_shots, - g.inputs.backend_config, # type: ignore - ) + g.inputs.backend_config, # type: ignore[arg-type] + ), ) - res = g.loop(polling_loop_body(polling_interval), ref) + res = g.loop(_polling_loop_body(polling_interval), ref) g.outputs(res.results) return g diff --git a/tierkreis/tierkreis/graphs/simulate/__init__.py b/tierkreis/tierkreis/graphs/simulate/__init__.py new file mode 100644 index 000000000..9647a1a74 --- /dev/null +++ b/tierkreis/tierkreis/graphs/simulate/__init__.py @@ -0,0 +1 @@ +"""Simulation Graphs.""" diff --git a/tierkreis/tierkreis/graphs/simulate/compile_simulate.py b/tierkreis/tierkreis/graphs/simulate/compile_simulate.py index 77e9d003d..bddf0ff36 100644 --- a/tierkreis/tierkreis/graphs/simulate/compile_simulate.py +++ b/tierkreis/tierkreis/graphs/simulate/compile_simulate.py @@ -1,23 +1,38 @@ +"""Sample graphs to simulate quantum circuits on different backends.""" + # ruff: noqa: F821 from typing import Literal, NamedTuple -from tierkreis.builder import GraphBuilder -from tierkreis.controller.data.models import TKR, OpaqueType -from tierkreis.builtins.stubs import tkr_zip, untuple + from tierkreis.aer_worker import ( get_compiled_circuit as aer_compile, +) +from tierkreis.aer_worker import ( run_circuit as aer_run, ) +from tierkreis.builder import GraphBuilder +from tierkreis.builtins.stubs import str_eq, tkr_zip, untuple +from tierkreis.controller.data.models import TKR, 
OpaqueType from tierkreis.qulacs_worker import ( get_compiled_circuit as qulacs_compile, +) +from tierkreis.qulacs_worker import ( run_circuit as qulacs_run, ) -from tierkreis.builtins.stubs import str_eq type BackendResult = OpaqueType["pytket.backends.backendresult.BackendResult"] -type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] +type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] # noqa: SLF001 class SimulateJobInputs(NamedTuple): + """Input to simulate multiple quantum circuits on a local backend. + + :fields: + simulator_name (Literal): either 'aer' or 'qulacs'. + circuits (list[Ciruit]): The list of circuits to run. + n_shots (int): The number of shots (repetitions) of each circuit. + compilation_optimisation_level (int): tket optimisation level in [0,1,2,3]. + """ + simulator_name: TKR[Literal["aer", "qulacs"]] circuits: TKR[list[Circuit]] n_shots: TKR[list[int]] @@ -25,12 +40,27 @@ class SimulateJobInputs(NamedTuple): class SimulateJobInputsSingle(NamedTuple): + """Input to simulate multiple quantum circuits on a local backend. + + :fields: + simulator_name (Literal): either 'aer' or 'qulacs'. + circuit_shots (tuple[Ciruit, int]): The circuits to run and the number of shots. + compilation_optimisation_level (int): tket optimisation level in [0,1,2,3]. + """ + simulator_name: TKR[Literal["aer", "qulacs"]] circuit_shots: TKR[tuple[Circuit, int]] compilation_optimisation_level: TKR[int] -def aer_simulate_single(): +def aer_simulate_single() -> GraphBuilder[SimulateJobInputsSingle, TKR[BackendResult]]: + """Construct a graph to simulate a single circuit using qiskit aer. + + This ignores the simulator_name field. + + :return: The graph for the simulation. 
+ :rtype: GraphBuilder[SimulateJobInputsSingle, TKR[BackendResult]] + """ g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult]) circuit_shots = g.task(untuple(g.inputs.circuit_shots)) @@ -38,14 +68,24 @@ def aer_simulate_single(): aer_compile( circuit=circuit_shots.a, optimisation_level=g.inputs.compilation_optimisation_level, - ) + ), ) res = g.task(aer_run(compiled_circuit, circuit_shots.b)) g.outputs(res) return g -def qulacs_simulate_single(): +def qulacs_simulate_single() -> GraphBuilder[ + SimulateJobInputsSingle, + TKR[BackendResult], +]: + """Construct a graph to simulate a single circuit using qulacs. + + This ignores the simulator_name field. + + :return: The graph for the simulation. + :rtype: GraphBuilder[SimulateJobInputsSingle, TKR[BackendResult]] + """ g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult]) circuit_shots = g.task(untuple(g.inputs.circuit_shots)) @@ -53,27 +93,42 @@ def qulacs_simulate_single(): qulacs_compile( circuit=circuit_shots.a, optimisation_level=g.inputs.compilation_optimisation_level, - ) + ), ) res = g.task(qulacs_run(compiled_circuit, circuit_shots.b)) g.outputs(res) return g -def compile_simulate_single(): +def compile_simulate_single() -> GraphBuilder[ + SimulateJobInputsSingle, + TKR[BackendResult], +]: + """CConstruct a graph to simulate a single job on either aer or qulacs. + + :return: The graph for the simulation. 
+ :rtype: GraphBuilder[ SimulateJobInputsSingle, TKR[BackendResult], ] + """ g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult]) aer_res = g.eval(aer_simulate_single(), g.inputs) qulacs_res = g.eval(qulacs_simulate_single(), g.inputs) res = g.ifelse( - g.task(str_eq(g.inputs.simulator_name, g.const("aer"))), aer_res, qulacs_res + g.task(str_eq(g.inputs.simulator_name, g.const("aer"))), + aer_res, + qulacs_res, ) g.outputs(res) return g -def compile_simulate(): +def compile_simulate() -> GraphBuilder[SimulateJobInputs, TKR[list[BackendResult]]]: + """Construct a graph to simulate multiple jobs on either aer or qulacs. + + :return: The graph for the simulation. + :rtype: GraphBuilder[SimulateJobInputs, TKR[list[BackendResult]]] + """ g = GraphBuilder(SimulateJobInputs, TKR[list[BackendResult]]) circuits_shots = g.task(tkr_zip(g.inputs.circuits, g.inputs.n_shots)) diff --git a/tierkreis/tierkreis/hpc.py b/tierkreis/tierkreis/hpc.py index dbb60488d..4839a951b 100644 --- a/tierkreis/tierkreis/hpc.py +++ b/tierkreis/tierkreis/hpc.py @@ -1,3 +1,5 @@ -from tierkreis.controller.executor.hpc.job_spec import JobSpec, ResourceSpec, MpiSpec +"""Tierkreis HPC utilities.""" -__all__ = ["JobSpec", "ResourceSpec", "MpiSpec"] +from tierkreis.controller.executor.hpc.job_spec import JobSpec, MpiSpec, ResourceSpec + +__all__ = ["JobSpec", "MpiSpec", "ResourceSpec"] diff --git a/tierkreis/tierkreis/idl/__init__.py b/tierkreis/tierkreis/idl/__init__.py new file mode 100644 index 000000000..728c8171a --- /dev/null +++ b/tierkreis/tierkreis/idl/__init__.py @@ -0,0 +1 @@ +"""Parsing for the Tierkreis typespec for external workers.""" diff --git a/tierkreis/tierkreis/idl/models.py b/tierkreis/tierkreis/idl/models.py index 8bce83804..66e167e7a 100644 --- a/tierkreis/tierkreis/idl/models.py +++ b/tierkreis/tierkreis/idl/models.py @@ -1,31 +1,43 @@ +"""Tierkreis IDL models representation used for TSP parsing.""" + +from collections.abc import Mapping, Sequence from dataclasses import 
dataclass from types import NoneType -from typing import Annotated, Mapping, Self, Sequence, get_args, get_origin +from typing import Annotated, Self, get_args, get_origin from tierkreis.controller.data.core import RestrictedNamedTuple from tierkreis.controller.data.types import _is_generic - type ElementaryType = ( - type[int] - | type[float] - | type[bytes] - | type[str] - | type[bool] - | type[NoneType] - | type[Mapping] - | type[Sequence] + type[int | float | bytes | str | bool | NoneType | Mapping | Sequence] | str # Custom type e.g. forward reference ) @dataclass class GenericType: + """A Tierkreis worker generic type. + + Represents a single (composed) type in worker definitions. + + :fields: + origin (ElementaryType): The base type, e.g., str in list[str]. + args: (Sequence[GenericType | str]) The nested types. + e.g., list[str] in set[list[str]] + """ + origin: ElementaryType args: "Sequence[GenericType | str]" @classmethod def from_type(cls, t: type) -> "Self": + """Construct a generic type from a python type. + + :param t: The python type. + :type t: type + :return: The Tierkreis type. + :rtype: Self + """ if get_origin(t) is Annotated: return cls.from_type(get_args(t)[0]) @@ -44,19 +56,49 @@ def _included_structs(cls, t: "GenericType") -> "set[GenericType]": return outs def included_structs(self) -> "set[GenericType]": + """Find the included structs of this type. + + A struct is an instance of RestrictedNamedTuple or opaque strings. + :return: The list of structs + :rtype: set[GenericType] + """ return GenericType._included_structs(self) def __hash__(self) -> int: + """Produce a hash of the generic type. + + :return: The hash. + :rtype: int + """ return hash(self.origin) def __eq__(self, value: object) -> bool: + """Check the equality of self with an object. + + self == value <==> self.origin == value.origin + + + :param other: The object to compare to. + :type other: object + :return: If bothe object have the same origin. 
+ :rtype: bool + """ if not hasattr(value, "origin"): return False - return self.origin == getattr(value, "origin") + return self.origin == value.origin @dataclass class TypedArg: + """A Tierkreis worker method argument. + + Represents a single argument to a tasks in a worker. + :fields: + name (str): The argument name. + t (GenericType): The argument type. + has_default(bool): Whether the argument has a default value. + """ + name: str t: GenericType has_default: bool = False @@ -64,6 +106,17 @@ class TypedArg: @dataclass class Method: + """A Tierkreis worker method. + + Represents a tasks in a worker. + + :fields: + name (str): The method name. + args (list[TypedArg]): The list of method arguments. + return_type (GenericType): The method return type. + return_type_is_portmapping (bool): Whether the return_type is a portmapping. + """ + name: GenericType args: list[TypedArg] return_type: GenericType @@ -72,18 +125,51 @@ class Method: @dataclass class Interface: + """A Tierkreis worker interface. + + Represents a list of tasks contained in the worker. + + :fields: + name (str): The worker name. + methods (list[Method]): The available tasks in the worker. + """ + name: str methods: list[Method] @dataclass class Model: + """A Tierkreis worker model. + + Represents a type in a worker. + + :fields: + is_portmapping (bool): Whether the model is a portmapping. + t (GenericType): The type of the model. + decl (list[TypedArg]) The list of its typed arguments. + """ + is_portmapping: bool t: GenericType decls: list[TypedArg] def __hash__(self) -> int: + """Produce a hash of the model. + + :return: The hash. + :rtype: int + """ return hash(self.t.origin) - def __lt__(self, other: "Model"): + def __lt__(self, other: "Model") -> bool: + """Check order of two models. + + Uses lexicographical ordering of the origin of the models (generic) types. + + :param other: The model to compare to. + :type other: Model + :return: If self comes before other. 
+ :rtype: bool + """ return str(self.t.origin) < str(other.t.origin) diff --git a/tierkreis/tierkreis/idl/parser.py b/tierkreis/tierkreis/idl/parser.py index f6e76ee02..3e5276629 100644 --- a/tierkreis/tierkreis/idl/parser.py +++ b/tierkreis/tierkreis/idl/parser.py @@ -5,31 +5,51 @@ doesn't type check things correctly. """ +from __future__ import annotations + +import contextlib import re -from typing import Callable, Never, overload +from typing import TYPE_CHECKING, Any, Never, overload +if TYPE_CHECKING: + from collections.abc import Callable from tierkreis.exceptions import TierkreisError -class ParserError(TierkreisError): ... +class ParserError(TierkreisError): + """An error raised when parsing fails in Tierkreis.""" class Parser[T]: + """A parser for an arbitrary type in tierkreis. + + :fields: + fn: The parsing function. + """ + fn: Callable[[str], tuple[T, str]] - def __init__(self, fn: Callable[[str], tuple[T, str]]): + def __init__(self, fn: Callable[[str], tuple[T, str]]) -> None: self.fn = fn def __call__(self, ins: str) -> tuple[T, str]: + """Call the parses on a string. + + :param ins: The string to parse. + :type ins: str + :return: The parsed string and its type. 
+ :rtype: tuple[T, str] + """ ins = ins.strip() return self.fn(ins) def __or__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[T|S]": + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[T | S]: """Try the left parser and only if it fails try the right parser.""" - def f(ins: str): + def f(ins: str) -> tuple[T, str] | tuple[S, str]: try: return self(ins) except ParserError: @@ -38,11 +58,12 @@ def f(ins: str): return Parser(f) def __and__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[tuple[T,S]]": + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[tuple[T, S]]: """Use the left parser and then use the right parser on the remaining input.""" - def f(ins: str): + def f(ins: str) -> tuple[tuple[T, S], str]: s, remaining = self(ins) t, remaining = other(remaining) return (s, t), remaining @@ -50,11 +71,16 @@ def f(ins: str): return Parser(f) def __lshift__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[T]": - """Use the left parser and then the right parser but discard the result of the right parser.""" + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[T]: + """Leftshift parsers. + + Use the left parser and then the right parser + but discard the result of the right parser. + """ - def f(ins: str): + def f(ins: str) -> tuple[T, str]: t, remaining = self(ins) _, remaining = other(remaining) return t, remaining @@ -62,21 +88,26 @@ def f(ins: str): return Parser(f) def __rshift__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[S]": - """Use the left parser and then the right parser but discard the result of the left parser.""" + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[S]: + """Rightshift parsers. - def f(ins: str): + Use the left parser and then the right parser + but discard the result of the left parser. 
+ """ + + def f(ins: str) -> tuple[S, str]: _, remaining = self(ins) s, remaining = other(remaining) return s, remaining return Parser(f) - def opt(self) -> "Parser[T|None]": + def opt(self) -> Parser[T | None]: """Make the parser optional; if it fails then return None and carry on.""" - def f(ins: str): + def f(ins: str) -> tuple[T, str] | tuple[None, str]: try: return self(ins) except ParserError: @@ -84,41 +115,41 @@ def f(ins: str): return Parser(f) - def map[A](self, fn: Callable[[T], A]) -> "Parser[A]": + def map[A](self, fn: Callable[[T], A]) -> Parser[A]: """Apply `fn` to transform the output of the parser.""" - def f(ins: str): + def f(ins: str) -> tuple[A, str]: t, remaining = self(ins) return fn(t), remaining return Parser(f) - def coerce[A](self, a: A) -> "Parser[A]": + def coerce[A](self, a: A) -> Parser[A]: """Shorthand for maps that don't need an argument. - Not strictly speaking required.""" + Not strictly speaking required. + """ - def f(ins: str): - t, remaining = self(ins) + def f(ins: str) -> tuple[A, str]: + _t, remaining = self(ins) return a, remaining return Parser(f) - def rep(self, sep: "Parser[str] | None" = None) -> "Parser[list[T]]": + def rep(self, sep: Parser[str] | None = None) -> Parser[list[T]]: """Repeatedly apply a parser with an optional separator. - The results of the separator parser are discarded.""" + The results of the separator parser are discarded. + """ - def f(ins: str): + def f(ins: str) -> tuple[list[T], str]: outs: list[T] = [] while True: try: t, ins = self(ins) if sep: - try: + with contextlib.suppress(ParserError): _, ins = sep(ins) - except ParserError: - pass outs.append(t) except ParserError: break @@ -126,14 +157,16 @@ def f(ins: str): return Parser(f) - def fail(self, entity: str) -> "Parser[Never]": + def fail(self, entity: str) -> Parser[Never]: """Fail early if we find something we don't support. - Not strictly speaking required.""" + Not strictly speaking required. 
+ """ - def f(ins: str): + def f(ins: str) -> Never: self(ins) - raise TierkreisError(f"{entity} not supported.") + msg = f"{entity} not supported." + raise TierkreisError(msg) return Parser(f) @@ -153,10 +186,12 @@ def seq[A, B, C, D, E]( *args: *tuple[Parser[A], Parser[B], Parser[C], Parser[D], Parser[E]], ) -> Parser[tuple[A, B, C, D, E]]: ... def seq(*args: Parser) -> Parser[tuple]: - """Run a sequence of parsers one after the other - and collect their outputs in a tuple.""" + """Run a sequence of parsers. - def f(ins: str): + Runs parsers one after the other and collect their outputs in a tuple. + """ + + def f(ins: str) -> tuple[tuple[Any, ...], str]: outs = [] for arg in args: s, ins = arg(ins) @@ -167,31 +202,38 @@ def f(ins: str): def lit(*args: str) -> Parser[str]: - """If the input starts with one of the strings in `args` - then take the string off the stream and return it.""" + """Match literal strings at the start of stream and remove them. + + If the input starts with one of the strings in `args` + then take the string off the stream and return it. + """ - def f(ins: str): + def f(ins: str) -> tuple[str, str]: for a in args: if ins.startswith(a): return a, ins[len(a) :] - raise ParserError(f"lit: expected {args} found '{ins[:20]}'") + msg = f"lit: expected {args} found '{ins[:20]}'" + raise ParserError(msg) return Parser(f) def reg(regex: str) -> Parser[str]: - """If start of the input matches the `regex` - then take the matching text off the stream and return it. + """Match a regex against the start of stream and remove it. - Please don't pass match groups within the regex; they will be taken care of.""" + If start of the input matches the `regex` then take the matching text off + the stream and return it. + Please don't pass match groups within the regex; they will be taken care of. 
+ """ - def f(ins: str): + def f(ins: str) -> tuple[str, str]: r = re.compile("^(" + regex + ")") if a := r.match(ins): return a.group(0), ins[a.end() :] - raise ParserError(f"reg: expected regex {regex} found '{ins[:20]}'") + msg = f"reg: expected regex {regex} found '{ins[:20]}'" + raise ParserError(msg) return Parser(f) diff --git a/tierkreis/tierkreis/idl/spec.py b/tierkreis/tierkreis/idl/spec.py index 9faee5b9c..c9f7c5f22 100644 --- a/tierkreis/tierkreis/idl/spec.py +++ b/tierkreis/tierkreis/idl/spec.py @@ -7,11 +7,9 @@ """ from tierkreis.idl.models import Interface, Method, Model, TypedArg - from tierkreis.idl.parser import lit, seq from tierkreis.idl.type_symbols import generic_t, ident, type_symbol - type_decl = ((ident << lit(":")) & type_symbol).map(lambda x: TypedArg(*x)) model = seq( lit("@portmapping").opt().map(lambda x: x is not None) << lit("model"), diff --git a/tierkreis/tierkreis/idl/type_symbols.py b/tierkreis/tierkreis/idl/type_symbols.py index ace5f3669..ca1c4a477 100644 --- a/tierkreis/tierkreis/idl/type_symbols.py +++ b/tierkreis/tierkreis/idl/type_symbols.py @@ -5,12 +5,10 @@ from types import NoneType from typing import ForwardRef + from tierkreis.idl.models import GenericType from tierkreis.idl.parser import Parser, lit, reg, seq -type _TypeT = type | ForwardRef - - signed_int = lit("integer", "int64", "int32", "int16", "int8", "safeint") unsigned_int = lit("uint64", "uint32", "uint16", "uint8") integer_t = (signed_int | unsigned_int).coerce(GenericType(int, [])) @@ -30,19 +28,40 @@ def array_t(ins: str) -> tuple[GenericType, str]: + """Parse a array generic type. + + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return (lit("Array<") >> type_symbol << lit(">")).map( - lambda x: GenericType(list, [x]) + lambda x: GenericType(list, [x]), )(ins) def record_t(ins: str) -> tuple[GenericType, str]: + """Parse a record generic type. 
+ + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return (lit("Record<") >> type_symbol << lit(">")).map( - lambda x: GenericType(dict, [GenericType(str, []), x]) + lambda x: GenericType(dict, [GenericType(str, []), x]), )(ins) @Parser def generic_t(ins: str) -> tuple[GenericType, str]: + """Parse a generic type. + + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return seq( ident, (lit("<") >> ident.rep(lit(",")) << lit(">")).opt().map(lambda x: x or []), @@ -51,6 +70,15 @@ def generic_t(ins: str) -> tuple[GenericType, str]: @Parser def type_symbol(ins: str) -> tuple[GenericType, str]: + """Parse a regular type symbol. + + E.g. int, float, ... + + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return ( integer_t | float_t diff --git a/tierkreis/tierkreis/labels.py b/tierkreis/tierkreis/labels.py index 0a6cf0474..c93513a35 100644 --- a/tierkreis/tierkreis/labels.py +++ b/tierkreis/tierkreis/labels.py @@ -4,8 +4,9 @@ class Labels: """Special port labels used by builtin functions.""" - def __init__(self): - raise RuntimeError("Do not instantiate") + def __init__(self) -> None: + msg = "Do not instantiate" + raise RuntimeError(msg) THUNK = "thunk" VALUE = "value" diff --git a/tierkreis/tierkreis/logger_setup.py b/tierkreis/tierkreis/logger_setup.py index 67f495b1f..392beee9f 100644 --- a/tierkreis/tierkreis/logger_setup.py +++ b/tierkreis/tierkreis/logger_setup.py @@ -1,17 +1,28 @@ +"""Sets up the Tierkreis logger.""" + import logging +import sys from os import getenv from pathlib import Path -import sys from tierkreis.consts import TKR_DATE_FMT_KEY, TKR_LOG_FMT_KEY, TKR_LOG_LEVEL_KEY -LOGGER_NAME = "tierkeis" +LOGGER_NAME = "tierkreis" def set_tkr_logger( file_name: 
Path, level: int | str = logging.INFO, ) -> None: + """Set up the 'tierkreis' logger. + + Adds a filehandler for use in the controller. + + :param file_name: The file to use for the logging. + :type file_name: Path + :param level: The log level, defaults to logging.INFO + :type level: int | str, optional + """ logger = logging.getLogger(LOGGER_NAME) if logger.hasHandlers(): [logger.removeHandler(h) for h in logger.handlers] @@ -23,10 +34,23 @@ def set_tkr_logger( logger.addHandler(handler) except FileNotFoundError: - logging.warning("Could not log to file, logging to std out instead.") + root_logger = logging.getLogger() + root_logger.warning("Could not log to file, logging to std out instead.") def add_handler_from_environment(logger: logging.Logger) -> logging.Handler: + """Add a handler to a logger from TKR env variables. + + Adds a stream handler on stderr with log level, format and date format + taken from the environment variables $TKR_LOG_LEVEL, $TKR_LOG_FMT and + $TKR_DATE_FORMAT. + Returns the created handler so it can be removed later if needed. + + :param logger: The logger to add the handler to. + :type logger: logging.Logger + :return: The created handler. 
+ :rtype: logging.Handler + """ log_level = getenv(TKR_LOG_LEVEL_KEY, None) if log_level is not None: logger.setLevel(log_level) diff --git a/tierkreis/tierkreis/models.py b/tierkreis/tierkreis/models.py index 1405f4468..326fbb76c 100644 --- a/tierkreis/tierkreis/models.py +++ b/tierkreis/tierkreis/models.py @@ -1,5 +1,7 @@ +"""Tierkreis models for graph builder definitions.""" + from tierkreis.controller.data.core import EmptyModel from tierkreis.controller.data.models import TKR, portmapping from tierkreis.controller.data.types import Struct -__all__ = ["EmptyModel", "TKR", "portmapping", "Struct"] +__all__ = ["TKR", "EmptyModel", "Struct", "portmapping"] diff --git a/tierkreis/tierkreis/namespace.py b/tierkreis/tierkreis/namespace.py index b966572bd..df3cd8cbf 100644 --- a/tierkreis/tierkreis/namespace.py +++ b/tierkreis/tierkreis/namespace.py @@ -1,17 +1,20 @@ +"""Namespace for a tierkreis worker.""" + +import shutil +import subprocess +from collections.abc import Callable from dataclasses import dataclass, field from inspect import Signature, signature from logging import getLogger from pathlib import Path -import shutil -import subprocess -from typing import Callable, Self +from typing import Self + from tierkreis.codegen import format_method, format_model from tierkreis.controller.data.models import PModel, is_portmapping from tierkreis.controller.data.types import Struct, has_default, is_ptype from tierkreis.exceptions import TierkreisError -from tierkreis.idl.spec import spec from tierkreis.idl.models import GenericType, Interface, Method, Model, TypedArg - +from tierkreis.idl.spec import spec logger = getLogger(__name__) WorkerFunction = Callable[..., PModel] @@ -19,21 +22,36 @@ @dataclass class Namespace: + """The namespace of a worker. + + attr name: The name of the namespace. + attr methods: The methods in the namespace. + attr models: The models in the namespace. 
+ """ + name: str - methods: list[Method] = field(default_factory=lambda: []) + methods: list[Method] = field(default_factory=list) models: set[Model] = field(default_factory=lambda: set()) - def add_struct(self, gt: GenericType) -> None: - if not isinstance(gt.origin, Struct) or Model(False, gt, []) in self.models: + def add_struct(self, generic_type: GenericType) -> None: + """Add a struct to the namespace. + + :param generic_type: The generic type to add. + :type generic_type: GenericType + """ + if ( + not isinstance(generic_type.origin, Struct) + or Model(is_portmapping=False, t=generic_type, decls=[]) in self.models + ): return - annotations = gt.origin.__annotations__ + annotations = generic_type.origin.__annotations__ decls = [TypedArg(k, GenericType.from_type(x)) for k, x in annotations.items()] for decl in decls: [self.add_struct(g) for g in decl.t.included_structs()] - portmapping_flag = True if is_portmapping(gt.origin) else False - model = Model(portmapping_flag, gt, decls) + portmapping_flag = bool(is_portmapping(generic_type.origin)) + model = Model(portmapping_flag, generic_type, decls) self.models.add(model) @staticmethod @@ -41,15 +59,22 @@ def _validate_signature(func: WorkerFunction) -> Signature: sig = signature(func) for param in sig.parameters.values(): if not is_ptype(param.annotation): - raise TierkreisError(f"Expected PType got {param.annotation}") + msg = f"Expected PType got {param.annotation}" + raise TierkreisError(msg) out = sig.return_annotation if not is_portmapping(out) and not is_ptype(out) and out is not None: - raise TierkreisError(f"Expected PModel found {out}") + msg = f"Expected PModel found {out}" + raise TierkreisError(msg) return sig def add_function(self, func: WorkerFunction) -> None: + """Add a function to the namespace. + + :param func: The function to add. 
+ :type func: WorkerFunction + """ sig = self._validate_signature(func) method = Method( @@ -63,12 +88,22 @@ def add_function(self, func: WorkerFunction) -> None: ) self.methods.append(method) - for t in func.__annotations__.values(): - [self.add_struct(x) for x in GenericType.from_type(t).included_structs()] + for annotation_type in func.__annotations__.values(): + [ + self.add_struct(struct) + for struct in GenericType.from_type(annotation_type).included_structs() + ] @classmethod def from_spec_file(cls, path: Path) -> "Namespace": - with open(path) as fh: + """Generate a Namespace from a tsp spec file. + + :param path: The path to the spec file. + :type path: Path + :return: The generated namespace. + :rtype: Namespace + """ + with Path.open(path) as fh: namespace_spec = spec(fh.read()) return cls._from_spec(namespace_spec[0]) @@ -77,18 +112,23 @@ def _from_spec(cls, args: tuple[list[Model], Interface]) -> "Self": models = args[0] interface = args[1] namespace = cls(interface.name, models=set(models)) - for f in interface.methods: - model = next((x for x in models if x.t == f.return_type), None) + for method in interface.methods: + model = next((x for x in models if x.t == method.return_type), None) if model is not None: - f.return_type_is_portmapping = model.is_portmapping - namespace.methods.append(f) + method.return_type_is_portmapping = model.is_portmapping + namespace.methods.append(method) return namespace def stubs(self) -> str: - functions = [format_method(self.name, f) for f in self.methods] + """Generate type stubs strings for the namespace. + + :return: The generated stubs as string. + :rtype: str + """ + functions = [format_method(self.name, method) for method in self.methods] functions_str = "\n\n".join(functions) - models_str = "\n\n".join([format_model(x) for x in sorted(list(self.models))]) + models_str = "\n\n".join([format_model(model) for model in sorted(self.models)]) return f'''"""Code generated from {self.name} namespace. 
Please do not edit.""" @@ -103,17 +143,28 @@ def stubs(self) -> str: ''' def write_stubs(self, stubs_path: Path) -> None: - """Writes the type stubs to stubs_path. + """Write the type stubs to stubs_path. :param stubs_path: The location to write to. :type stubs_path: Path """ - with open(stubs_path, "w+") as fh: + with Path.open(stubs_path, "w+") as fh: fh.write(self.stubs()) ruff_binary = shutil.which("ruff") if ruff_binary: - subprocess.run([ruff_binary, "format", stubs_path]) - subprocess.run([ruff_binary, "check", "--fix", stubs_path]) + subprocess.run([ruff_binary, "format", stubs_path], check=False) + subprocess.run( + [ + ruff_binary, + "check", + "--fix", + "--ignore", + "D,N801,UP007", + "--unsafe-fixes", + stubs_path, + ], + check=False, + ) else: logger.warning("No ruff binary found. Stubs will contain raw codegen.") diff --git a/tierkreis/tierkreis/storage.py b/tierkreis/tierkreis/storage.py index 1c70693e4..4d65734d6 100644 --- a/tierkreis/tierkreis/storage.py +++ b/tierkreis/tierkreis/storage.py @@ -1,43 +1,76 @@ +"""Implementation to access node storage data.""" + from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import PType, ptype_from_bytes -from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.controller.storage.filestorage import ( ControllerFileStorage as FileStorage, ) from tierkreis.controller.storage.in_memory import ( ControllerInMemoryStorage as InMemoryStorage, ) +from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.exceptions import TierkreisError __all__ = ["FileStorage", "InMemoryStorage"] def read_outputs( - g: GraphData | GraphBuilder, storage: ControllerStorage + graph: GraphData | GraphBuilder, + storage: ControllerStorage, ) -> dict[str, PType] | PType: - if isinstance(g, GraphBuilder): - g = g.get_data() + """Read the outputs of a workflow 
graph. - out_ports = list(g.nodes[g.output_idx()].inputs.keys()) + :param graph: The graph to read. + :type graph: GraphData | GraphBuilder + :param storage: The storage of the workflow run. + :type storage: ControllerStorage + :return: The output values. If the graph has a single output port named "value" it + is returned directly, otherwise a dictionary mapping output port names to their + values is returned. + :rtype: dict[str, PType] | PType + """ + if isinstance(graph, GraphBuilder): + graph = graph.get_data() + + out_ports = list(graph.nodes[graph.output_idx()].inputs.keys()) if len(out_ports) == 1 and "value" in out_ports: return ptype_from_bytes(storage.read_output(Loc(), "value")) return {k: ptype_from_bytes(storage.read_output(Loc(), k)) for k in out_ports} def read_loop_trace( - g: GraphData | GraphBuilder, + graph: GraphData | GraphBuilder, storage: ControllerStorage, node_name: str, output_name: str | None = None, ) -> list[PType | dict[str, list[PType]]]: - """Reads the trace of a loop from storage.""" - if isinstance(g, GraphBuilder): - g = g.get_data() + """Read the trace of a named loop. + + This is useful to track intermediate values in a loop. + + :param graph: The graph to read. + :type graph: GraphData | GraphBuilder + :param storage: The storage of the workflow run. + :type storage: ControllerStorage + :param node_name: The name of the loop node. + :type node_name: str + :param output_name: The name of the output port to trace, defaults to None + :type output_name: str | None, optional + :raises TierkreisError: If the loop name is not found in debug data. + :raises TierkreisError: If the output name is not found in loop node output. + :return: A list of traced values. If output_name is None, each entry is a dict + mapping output port names to their values at each iteration, otherwise a list + of values for the specified output port is returned. 
+ :rtype: list[PType | dict[str, list[PType]]] + """ + if isinstance(graph, GraphBuilder): + graph = graph.get_data() loc = storage.loc_from_node_name(node_name) if loc is None: - raise TierkreisError(f"Loop name {node_name} not found in debug data.") + msg = f"Loop name {node_name} not found in debug data." + raise TierkreisError(msg) output_names = storage.read_output_ports(loc) if output_name is None: traces = { @@ -45,9 +78,13 @@ def read_loop_trace( for name in output_names if name != "should_continue" } - return [dict(zip(traces.keys(), vals)) for vals in zip(*traces.values())] + return [ + dict(zip(traces.keys(), vals, strict=False)) + for vals in zip(*traces.values(), strict=False) + ] if output_name not in output_names: - raise TierkreisError(f"Output name {output_name} not found in loop node output") + msg = f"Output name {output_name} not found in loop node output" + raise TierkreisError(msg) results = storage.read_loop_trace(loc, output_name) return [ptype_from_bytes(r) for r in results] diff --git a/tierkreis/tierkreis/worker/__init__.py b/tierkreis/tierkreis/worker/__init__.py index e69de29bb..3ce15025b 100644 --- a/tierkreis/tierkreis/worker/__init__.py +++ b/tierkreis/tierkreis/worker/__init__.py @@ -0,0 +1 @@ +"""Tierkreis worker package for user defined tasks.""" diff --git a/tierkreis/tierkreis/worker/storage/__init__.py b/tierkreis/tierkreis/worker/storage/__init__.py index e69de29bb..d0bdf3a16 100644 --- a/tierkreis/tierkreis/worker/storage/__init__.py +++ b/tierkreis/tierkreis/worker/storage/__init__.py @@ -0,0 +1 @@ +"""Worker storage implementations.""" diff --git a/tierkreis/tierkreis/worker/storage/filestorage.py b/tierkreis/tierkreis/worker/storage/filestorage.py index cbb250043..1e35ab687 100644 --- a/tierkreis/tierkreis/worker/storage/filestorage.py +++ b/tierkreis/tierkreis/worker/storage/filestorage.py @@ -1,5 +1,7 @@ +"""Filestorage implementation analog to ControllerFileStorage.""" + +# ruff: noqa: D102 (class methods inherited from 
WorkerStorage) import json -from glob import glob import os from pathlib import Path @@ -8,6 +10,13 @@ class WorkerFileStorage: + """File storage implementation for workers. + + :fields: + tierkreis_dir: The directory to use for storing tierkreis data, + defaults to ~/.tierkreis/checkpoints. + """ + def __init__(self, tierkreis_dir: Path | None = None) -> None: if tierkreis_dir is not None: self.tierkreis_dir = tierkreis_dir @@ -21,23 +30,24 @@ def resolve(self, path: Path | str) -> Path: return path if path.is_absolute() else self.tierkreis_dir / path def read_call_args(self, path: Path) -> WorkerCallArgs: - with open(self.resolve(path), "r") as fh: + with Path.open(self.resolve(path)) as fh: return WorkerCallArgs(**json.loads(fh.read())) def read_input(self, path: Path) -> bytes: - with open(self.resolve(path), "rb") as fh: + with Path.open(self.resolve(path), "rb") as fh: return fh.read() def write_output(self, path: Path, value: bytes) -> None: - with open(self.resolve(path), "wb+") as fh: + with Path.open(self.resolve(path), "wb+") as fh: fh.write(value) def glob(self, path_string: str) -> list[str]: - return glob(str(self.resolve(path_string))) + tmp_path = self.resolve(path_string) + return [str(p) for p in tmp_path.parent.glob(tmp_path.parts[-1])] def mark_done(self, path: Path) -> None: self.resolve(path).touch() def write_error(self, path: Path, error_logs: str) -> None: - with open(self.resolve(path), "w+") as f: + with Path.open(self.resolve(path), "w+") as f: f.write(error_logs) diff --git a/tierkreis/tierkreis/worker/storage/in_memory.py b/tierkreis/tierkreis/worker/storage/in_memory.py index e6a0c3957..1849d348c 100644 --- a/tierkreis/tierkreis/worker/storage/in_memory.py +++ b/tierkreis/tierkreis/worker/storage/in_memory.py @@ -1,3 +1,6 @@ +"""In-memory storage implementation analog to ControllerInMemoryStorage.""" + +# ruff: noqa: D102 (class methods inherited from WorkerStorage) import fnmatch import json import logging @@ -10,11 +13,17 @@ ) from 
tierkreis.exceptions import TierkreisError - logger = logging.getLogger(__name__) class InMemoryWorkerStorage: + """In-memory storage implementation for workers. + + Delegates calls to the ControllerInMemoryStorage used for the workflow. + :fields: + controller_storage: The controller storage. + """ + def __init__(self, controller_storage: ControllerInMemoryStorage) -> None: self.controller_storage = controller_storage @@ -32,13 +41,13 @@ def write_output(self, path: Path, value: bytes) -> None: self.controller_storage.files[path] = InMemoryFileData(value) def glob(self, path_string: str) -> list[str]: - files = [str(x) for x in self.controller_storage.files.keys()] - matching = fnmatch.filter(files, path_string) - return matching + files = [str(x) for x in self.controller_storage.files] + return fnmatch.filter(files, path_string) def mark_done(self, path: Path) -> None: self.controller_storage.touch(path) - def write_error(self, path: Path, error_logs: str) -> None: + def write_error(self, _: Path, error_logs: str) -> None: logger.error(error_logs) - raise TierkreisError("Error occured when running graph in-memory.") + msg = "Error occurred when running graph in-memory." + raise TierkreisError(msg) diff --git a/tierkreis/tierkreis/worker/storage/protocol.py b/tierkreis/tierkreis/worker/storage/protocol.py index 040cede0a..c3bcca5b3 100644 --- a/tierkreis/tierkreis/worker/storage/protocol.py +++ b/tierkreis/tierkreis/worker/storage/protocol.py @@ -1,3 +1,5 @@ +"""Storage protocol for workers.""" + from pathlib import Path from typing import Protocol @@ -5,10 +7,91 @@ class WorkerStorage(Protocol): - def resolve(self, path: Path | str) -> Path: ... - def read_call_args(self, path: Path) -> WorkerCallArgs: ... - def read_input(self, path: Path) -> bytes: ... - def write_output(self, path: Path, value: bytes) -> None: ... - def glob(self, path_string: str) -> list[str]: ... - def mark_done(self, path: Path) -> None: ... 
- def write_error(self, path: Path, error_logs: str) -> None: ... + """Storage protocol for workers. + + :abstract: + """ + + def resolve(self, path: Path | str) -> Path: + """Resolve a path or str to a path in the storage. + + Transforms paths such that the storage can correctly read/write to them. + E.g. for a file storage, relative paths are resolved against a base directory. + + :param path: The path to resolve. + :type path: Path | str + :return: The resolved path according to the storage. + :rtype: Path + """ + ... + + def read_call_args(self, path: Path) -> WorkerCallArgs: + """Read the call args of a worker. + + The function name is part of the call args which is then used to + determine which function to call. + + :param path: The path to read from. + :type path: Path + :return: The call args of the worker function. + :rtype: WorkerCallArgs + """ + ... + + def read_input(self, path: Path) -> bytes: + """Read the input to a worker task. + + Input locations are defined in the call args. + + :param path: The path to read from. + :type path: Path + :return: The bytes read from the input location. + :rtype: bytes + """ + ... + + def write_output(self, path: Path, value: bytes) -> None: + """Write the outputs of a worker task. + + Output locations are defined in the call args. + + :param path: The path to write to. + :type path: Path + :param value: The bytes to write. + :type value: bytes + """ + ... + + def glob(self, path_string: str) -> list[str]: + """Get a list of paths matching the path glob. + + Used in map nodes to find all input values for the individual tasks. + + :param path_string: The glob string to match. + :type path_string: str + :return: A list of matching path strings. + :rtype: list[str] + """ + ... + + def mark_done(self, path: Path) -> None: + """Mark the task node as done. + + Done paths are defined in the call args. + + :param path: The path to mark as done. + :type path: Path + """ + ... 
+ + def write_error(self, path: Path, error_logs: str) -> None: + """Write an error to the logs. + + Logs are stored in a location defined in the call args. + + :param path: The path to write the error logs to. + :type path: Path + :param error_logs: The message to write. + :type error_logs: str + """ + ... diff --git a/tierkreis/tierkreis/worker/worker.py b/tierkreis/tierkreis/worker/worker.py index a94d5d890..00ad52a19 100644 --- a/tierkreis/tierkreis/worker/worker.py +++ b/tierkreis/tierkreis/worker/worker.py @@ -1,7 +1,10 @@ -from inspect import Signature, signature +"""Tierkreis worker implementation.""" + import logging +from collections.abc import Callable +from inspect import Signature, signature from pathlib import Path -from typing import Callable, TypeVar +from typing import NoReturn, TypeVar from tierkreis.controller.data.core import PortID from tierkreis.controller.data.location import WorkerCallArgs @@ -28,7 +31,7 @@ class TierkreisWorkerError(TierkreisError): - pass + """Exception raised when a worker encounters an error.""" F = TypeVar("F", bound=Callable[..., PModel]) @@ -38,7 +41,8 @@ class Worker: """A worker bundles a set of functionality under a common namespace. The main usage of a worker is to convert python functions into atomic tasks, - which can then be executed by the :py:class:`tierkreis.controller.executor.uv_executor.UvExecutor` + which can then be executed by the + :py:class:`tierkreis.controller.executor.uv_executor.UvExecutor` or similar Executors. From the worker type stubs can be generated to statically check the function calls. @@ -52,10 +56,14 @@ class Worker: def exp(x: float, a: float) -> float: return value = a * np.exp(x) - :param name: The name of the worker. - :type name: str - :param storage: Storage layer for the worker to interact with the ControllerStorage. - :type storage: WorkerStorage + :fields: + name (str) The name of the worker. 
+ storage (WorkerStorage) Storage layer for the + worker to interact with the ControllerStorage. + namespace (Namespace) The namespace of the worker. + types (dict[MethodName, Signature]) Mapping function names to their signatures. + functions (dict[str, Callable[[WorkerCallArgs], None]]) + Mapping function names to their implementations. """ functions: dict[str, Callable[[WorkerCallArgs], None]] @@ -73,26 +81,33 @@ def __init__(self, name: str, storage: WorkerStorage | None = None) -> None: self.storage = storage def _load_args( - self, f: WorkerFunction, inputs: dict[str, Path] + self, + f: WorkerFunction, + inputs: dict[str, Path], ) -> dict[str, PType]: bs: dict[str, bytes] = {} for k, p in inputs.items(): try: bs[k] = self.storage.read_input(p) - except FileNotFoundError: + except FileNotFoundError as e: if not has_default(self.types[f.__name__].parameters[k]): - raise TierkreisError(f"Input {k} not found at {p}.") + msg = f"Input {k} not found at {p}." + raise TierkreisError(msg) from e args = {} for k, b in bs.items(): args[k] = ptype_from_bytes( - b, self.types[f.__name__].parameters[k].annotation + b, + self.types[f.__name__].parameters[k].annotation, ) return args def _save_results( - self, f: WorkerFunction, outputs: dict[PortID, Path], results: PModel - ): + self, + f: WorkerFunction, + outputs: dict[PortID, Path], + results: PModel, + ) -> None: d = dict_from_pmodel(results) ret = annotations_from_pmodel(signature(f).return_annotation) for result_name, path in outputs.items(): @@ -100,15 +115,24 @@ def _save_results( self.storage.write_output(path, bs) def add_types(self, func: WorkerFunction) -> None: + """Add the types of a function to the worker. + + :param func: The function to add types for. 
+ :type func: WorkerFunction + """ self.types[func.__name__] = signature(func) def primitive_task( self, ) -> Callable[[PrimitiveTask], None]: - """Registers a python function as a primitive task with the worker.""" + """Register a python function as a primitive task with the worker. + + :return: The wrapped task. + :rtype: Callable[[PrimitiveTask], None] + """ def function_decorator(func: PrimitiveTask) -> None: - def wrapper(args: WorkerCallArgs): + def wrapper(args: WorkerCallArgs) -> None: func(args, self.storage) self.functions[func.__name__] = wrapper @@ -116,13 +140,17 @@ def wrapper(args: WorkerCallArgs): return function_decorator def task(self) -> Callable[[F], F]: - """Registers a python function as a task with the worker.""" + """Register a python function as a task with the worker. + + :return: The wrapped function. + :rtype: Callable[[Callable[..., PModel]], Callable[..., PModel]] + """ def function_decorator(func: F) -> F: self.namespace.add_function(func) self.add_types(func) - def wrapper(node_definition: WorkerCallArgs): + def wrapper(node_definition: WorkerCallArgs) -> None: kwargs = self._load_args(func, node_definition.inputs) results = func(**kwargs) self._save_results(func, node_definition.outputs, results) @@ -142,27 +170,42 @@ def run(self, worker_definition_path: Path) -> None: node_definition = self.storage.read_call_args(worker_definition_path) logger.debug(node_definition.model_dump()) + def _check_function(msg: str) -> NoReturn: + raise TierkreisError(msg) + try: function = self.functions.get(node_definition.function_name, None) if function is None: - raise TierkreisError( - f"{self.name}: function name {node_definition.function_name} not found" + msg = ( + f"{self.name}: function name " + f"{node_definition.function_name} not found" ) - logger.info(f"running: {node_definition.function_name} in {self.name}") + _check_function(msg) + logger.info("running: %s in %s", node_definition.function_name, self.name) function(node_definition)
self.storage.mark_done(node_definition.done_path) except Exception as err: - logger.error("encountered error", exc_info=err) + logger.exception("encountered error", exc_info=err) self.storage.write_error(node_definition.error_path, str(err)) - raise TierkreisWorkerError( - f"Worker {self.name} encountered error when executing {node_definition.function_name}." + msg = ( + f"Worker {self.name} encountered error when executing " + f"{node_definition.function_name}." ) + raise TierkreisWorkerError( + msg, + ) from err def app(self, argv: list[str]) -> None: - """Wrapper for UV execution.""" + """Run the worker as uv app. + + Either generate stubs or run the worker. + + :param argv: The cli args. + :type argv: list[str] + """ handler = add_handler_from_environment(logger) if argv[1] == "--stubs-path": self.namespace.write_stubs(Path(argv[2]))