diff --git a/docs/source/logging_and_errors.md b/docs/source/logging_and_errors.md index 877c347e5..b351f2915 100644 --- a/docs/source/logging_and_errors.md +++ b/docs/source/logging_and_errors.md @@ -117,6 +117,40 @@ The data contains the following information: - `env` the current user environment when launching the task. - `packages` a list of installed python packages. Only used for the [`UvExecutor`](#tierkreis.controller.executor.uv_executor.UvExecutor) +### Breakpoints + +All nodes can be declared ar breakpoints by adding `NodeMetaData` as follows: +```python +from tierkreis.controller.data.graph import NodeMetaData +g.task(..., NodeMetaData(has_breakpoint=True)) +``` +By default the controller will ignore this information unless you set `enable_breakpoints=True` in `run_graph`. +When running with breakpoints the graph execution will stop as soon as it hits a breakpoint node e.g. after running: +```python +run_graph( + storage, executor, graph, None, enable_breakpoints=enable_breakpoints +) +``` +you can examine the storage and current values of all previous nodes. +Afterward you can resume the execution with +```python +resume_graph(storage, executor) +``` + +### Debug Mode + +If you want to debug a graph with a python debugger you can use [](#tierkreis.controller.storgare.debug_graph.debug_graph). +It acts similar to `run_workflow` with some defaults enabled: +- Enables all set breakpoints +- Sets up logging +- Adds a specific storage and executor to enable python debugging +```{important} +This will only work with python based workers. +All workers need to be installed locally. +``` + + + ## Visualizer If you're using the visualize to debug workflow, error information will be immediately visible to you. diff --git a/tierkreis/tests/controller/test_breakpoint.py b/tierkreis/tests/controller/test_breakpoint.py new file mode 100644 index 000000000..a364fd1a6 --- /dev/null +++ b/tierkreis/tests/controller/test_breakpoint.py @@ -0,0 +1,54 @@ +from pathlib import Path +from uuid import UUID + +import pytest +from pytket._tket.circuit import Circuit +from tierkreis.builder import GraphBuilder +from tierkreis.builtins import iadd +from tierkreis.controller import resume_graph, run_graph +from tierkreis.controller.data.graph import NodeMetaData +from tierkreis.controller.data.location import Loc +from tierkreis.controller.data.models import TKR +from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor +from tierkreis.controller.storage.filestorage import ControllerFileStorage +from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage +from tierkreis.executor import ShellExecutor +from pytket_worker import n_qubits +from tierkreis.storage import read_outputs + + +def breakpoint_graph() -> GraphBuilder[TKR[Circuit], TKR[int]]: + g = GraphBuilder(TKR[Circuit], TKR[int]) + test = g.const(5) + nq = g.task(n_qubits(g.inputs), NodeMetaData(has_breakpoint=True)) # type: ignore + out = g.task(iadd(test, nq)) + g.outputs(out) + return g + + +storage_classes = [ControllerFileStorage, ControllerInMemoryStorage] +storage_ids = ["FileStorage", "In-memory"] + + +@pytest.mark.parametrize("storage_class", storage_classes, ids=storage_ids) +@pytest.mark.parametrize("enable_breakpoints", [True, False], ids=["True", "False"]) +def test_breakpoint( + storage_class: type[ControllerFileStorage | ControllerInMemoryStorage], + enable_breakpoints: bool, +) -> None: + graph = breakpoint_graph() + storage = storage_class(UUID(int=400), name="breakpoints") + executor = ShellExecutor(registry_path=None, workflow_dir=storage.workflow_dir) + if isinstance(storage, ControllerInMemoryStorage): + executor = InMemoryExecutor(Path("./tierkreis/tierkreis"), storage=storage) + storage.clean_graph_files() + run_graph( + storage, executor, graph, Circuit(2), enable_breakpoints=enable_breakpoints + ) + if enable_breakpoints: + assert not storage.is_node_finished(Loc()) + assert storage.exists(storage._breakpoint(Loc("-.N2"))) + resume_graph(storage, executor) + assert storage.is_node_finished(Loc()) + out = read_outputs(graph, storage) + assert out == 7 diff --git a/tierkreis/tierkreis/builder.py b/tierkreis/tierkreis/builder.py index 7d3c7398d..74f7cd03f 100644 --- a/tierkreis/tierkreis/builder.py +++ b/tierkreis/tierkreis/builder.py @@ -5,6 +5,7 @@ from collections.abc import Callable from copy import copy from dataclasses import dataclass +from functools import partial from inspect import isclass from typing import ( Any, @@ -15,8 +16,8 @@ runtime_checkable, ) -from tierkreis.controller.data.core import EmptyModel -from tierkreis.controller.data.graph import GraphData, ValueRef, reindex_inputs +from tierkreis.controller.data.core import EmptyModel, ValueRef +from tierkreis.controller.data.graph import GraphData, NodeMetaData, reindex_inputs from tierkreis.controller.data.models import ( TKR, TModel, @@ -126,16 +127,23 @@ class GraphBuilder[Inputs: TModel, Outputs: TModel]: outputs_type: type inputs: Inputs + _breakpoints_on_outputs: bool def __init__( self, inputs_type: type[Inputs] = EmptyModel, outputs_type: type[Outputs] = EmptyModel, + breakpoints_on_inputs: bool = False, + breakpoints_on_outputs: bool = False, ) -> None: self.data = GraphData() self.inputs_type = inputs_type self.outputs_type = outputs_type - self.inputs = init_tmodel(self.inputs_type, self.data.input) + input_fn = partial( + self.data.input, metadata=NodeMetaData(breakpoints_on_inputs) + ) + self.inputs = init_tmodel(self.inputs_type, input_fn) + self._breakpoints_on_outputs = breakpoints_on_outputs def get_data(self) -> GraphData: """Return the underlying graph from the builder. @@ -159,7 +167,10 @@ def outputs(self, outputs: Outputs) -> None: :param outputs: The output nodes. :type outputs: Outputs """ - self.data.output(inputs=dict_from_tmodel(outputs)) + self.data.output( + inputs=dict_from_tmodel(outputs), + metadata=NodeMetaData(self._breakpoints_on_outputs), + ) def embed[A: TModel, B: TModel](self, other: "GraphBuilder[A, B]", inputs: A) -> B: if other.data.graph_output_idx is None: @@ -196,7 +207,7 @@ def const[T: PType](self, value: T) -> TKR[T]: :return: The constant value. :rtype: TKR[T] """ - idx, port = self.data.const(value) + idx, port = self.data.const(value, NodeMetaData()) return TKR[T](idx, port) def ifelse[A: PType, B: PType]( @@ -204,6 +215,7 @@ def ifelse[A: PType, B: PType]( pred: TKR[bool], if_true: TKR[A], if_false: TKR[B], + metadata: NodeMetaData | None = None, ) -> TKR[A] | TKR[B]: """Add an if-else node to the graph. @@ -216,6 +228,8 @@ def ifelse[A: PType, B: PType]( :type if_true: TKR[A] :param if_false: The value if the predicate is false. :type if_false: TKR[B] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The outputs of the if-else expression. :rtype: TKR[A] | TKR[B] """ @@ -223,6 +237,7 @@ def ifelse[A: PType, B: PType]( pred.value_ref(), if_true.value_ref(), if_false.value_ref(), + metadata, )("value") return TKR(idx, port) @@ -231,6 +246,7 @@ def eifelse[A: PType, B: PType]( pred: TKR[bool], if_true: TKR[A], if_false: TKR[B], + metadata: NodeMetaData | None = None, ) -> TKR[A] | TKR[B]: """Add an eager if-else node to the graph. @@ -243,6 +259,8 @@ def eifelse[A: PType, B: PType]( :type if_true: TKR[A] :param if_false: The value if the predicate is false. :type if_false: TKR[B] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The outputs of the if-else expression. :rtype: TKR[A] | TKR[B] """ @@ -250,6 +268,7 @@ def eifelse[A: PType, B: PType]( pred.value_ref(), if_true.value_ref(), if_false.value_ref(), + metadata, )("value") return TKR(idx, port) @@ -258,24 +277,30 @@ def _graph_const[A: TModel, B: TModel]( graph: GraphBuilder[A, B], ) -> TypedGraphRef[A, B]: # TODO @philipp-seitz: Turn this into a public method? - idx, port = self.data.const(graph.data.model_dump()) + idx, port = self.data.const(graph.data.model_dump(), NodeMetaData()) return TypedGraphRef[A, B]( graph_ref=(idx, port), outputs_type=graph.outputs_type, inputs_type=graph.inputs_type, ) - def task[Out: TModel](self, func: Function[Out]) -> Out: + def task[Out: TModel]( + self, + func: Function[Out], + metadata: NodeMetaData | None = None, + ) -> Out: """Add a worker task node to the graph. :param func: The worker function. :type func: Function[Out] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The outputs of the task. :rtype: Out """ name = f"{func.namespace}.{func.__class__.__name__}" inputs = dict_from_tmodel(func) - idx, _ = self.data.func(name, inputs)("dummy") + idx, _ = self.data.func(name, inputs, metadata)("dummy") OutModel = func.out() # noqa: N806 return init_tmodel(OutModel, lambda p: (idx, p)) @@ -283,6 +308,7 @@ def eval[A: TModel, B: TModel]( self, body: GraphBuilder[A, B] | TypedGraphRef[A, B], eval_inputs: A, + metadata: NodeMetaData | None = None, ) -> B: """Add a evaluation node to the graph. @@ -293,13 +319,17 @@ def eval[A: TModel, B: TModel]( where A are the input type and B the output type of the graph. :param eval_inputs: The inputs to the graph. :type eval_inputs: A + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The outputs of the evaluation. :rtype: B """ if isinstance(body, GraphBuilder): body = self._graph_const(body) - idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(eval_inputs))("dummy") + idx, _ = self.data.eval( + body.graph_ref, dict_from_tmodel(eval_inputs), metadata + )("dummy") return init_tmodel(body.outputs_type, lambda p: (idx, p)) def loop[A: TModel, B: LoopOutput]( @@ -307,6 +337,7 @@ def loop[A: TModel, B: LoopOutput]( body: TypedGraphRef[A, B] | GraphBuilder[A, B], loop_inputs: A, name: str | None = None, + metadata: NodeMetaData | None = None, ) -> B: """Add a loop node to the graph. @@ -321,6 +352,8 @@ def loop[A: TModel, B: LoopOutput]( :type loop_inputs: A :param name: An optional name for the loop. :type name: str | None + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The outputs of the loop. :rtype: B """ @@ -333,6 +366,7 @@ def loop[A: TModel, B: LoopOutput]( dict_from_tmodel(loop_inputs), "should_continue", name, + metadata=metadata, )( "dummy", ) @@ -369,9 +403,10 @@ def _map_graph_full[A: TModel, B: TModel]( self, map_inputs: TList[A], body: TypedGraphRef[A, B], + metadata: NodeMetaData | None = None, ) -> TList[B]: ins = dict_from_tmodel(map_inputs._value) # noqa: SLF001 - idx, _ = self.data.map(body.graph_ref, ins)("x") + idx, _ = self.data.map(body.graph_ref, ins, metadata)("x") return TList(init_tmodel(body.outputs_type, lambda s: (idx, s + "-*"))) @@ -382,6 +417,7 @@ def map[A: PType, B: TNamedModel]( Callable[[TKR[A]], B] | TypedGraphRef[TKR[A], B] | GraphBuilder[TKR[A], B] ), map_inputs: TKR[list[A]], + metadata: NodeMetaData | None = None, ) -> TList[B]: ... @overload @@ -391,6 +427,7 @@ def map[A: TNamedModel, B: PType]( Callable[[A], TKR[B]] | TypedGraphRef[A, TKR[B]] | GraphBuilder[A, TKR[B]] ), map_inputs: TList[A], + metadata: NodeMetaData | None = None, ) -> TKR[list[B]]: ... @overload @@ -398,6 +435,7 @@ def map[A: TNamedModel, B: TNamedModel]( self, body: TypedGraphRef[A, B] | GraphBuilder[A, B], map_inputs: TList[A], + metadata: NodeMetaData | None = None, ) -> TList[B]: ... @overload @@ -409,12 +447,14 @@ def map[A: PType, B: PType]( | GraphBuilder[TKR[A], TKR[B]] ), map_inputs: TKR[list[A]], + metadata: NodeMetaData | None = None, ) -> TKR[list[B]]: ... def map( self, body: TypedGraphRef | Callable | GraphBuilder, map_inputs: TKR | TList, + metadata: NodeMetaData | None = None, ) -> Any: """Add a map node to the graph. @@ -422,6 +462,8 @@ def map( :type body: TypedGraphRef | Callable | GraphBuilder :param map_inputs: The values to map over. :type map_inputs: TKR | TList + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The outputs of the map. :rtype: Any """ @@ -437,7 +479,7 @@ def map( if isinstance(map_inputs, TKR): map_inputs = self._unfold_list(map_inputs) - out = self._map_graph_full(map_inputs, body) + out = self._map_graph_full(map_inputs, body, metadata) if not isclass(body.outputs_type) or not issubclass( body.outputs_type, diff --git a/tierkreis/tierkreis/controller/__init__.py b/tierkreis/tierkreis/controller/__init__.py index a989f7607..b774fcae1 100644 --- a/tierkreis/tierkreis/controller/__init__.py +++ b/tierkreis/tierkreis/controller/__init__.py @@ -15,9 +15,10 @@ from tierkreis.controller.data.models import TModel from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes from tierkreis.controller.executor.protocol import ControllerExecutor -from tierkreis.controller.start import NodeRunData, start, start_nodes +from tierkreis.controller.start import start, start_nodes from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.controller.storage.walk import walk_node +from tierkreis.controller.storage.walk_result import NodeRunData from tierkreis.exceptions import TierkreisError from tierkreis.logger_setup import set_tkr_logger @@ -37,6 +38,7 @@ def run_graph[A: TModel, B: TModel]( polling_interval_seconds: float = 0.01, *, enable_logging: bool = True, + enable_breakpoints: bool = False, ) -> None: """Start a graph execution. @@ -59,6 +61,8 @@ def run_graph[A: TModel, B: TModel]( :type polling_interval_seconds: float, optional :param enable_logging: Whether to enable logging, defaults to True :type enable_logging: bool, optional + :param enable_breakpoints: Whether to enable breakpoint nodes, defaults to False + :type enable_breakpoints: bool, optional :raises TierkreisError: If the graph encounters errors during execution. """ if isinstance(g, GraphBuilder): @@ -86,9 +90,17 @@ def run_graph[A: TModel, B: TModel]( inputs: dict[PortID, ValueRef] = { k: (-1, k) for k, _ in graph_inputs.items() if k != "body" } - node_run_data = NodeRunData(Loc(), Eval((-1, "body"), inputs), []) + node_run_data = NodeRunData( + node_location=Loc(), node=Eval((-1, "body"), inputs), output_list=[] + ) start(storage, executor, node_run_data) - resume_graph(storage, executor, n_iterations, polling_interval_seconds) + resume_graph( + storage, + executor, + n_iterations, + polling_interval_seconds, + enable_breakpoints=enable_breakpoints, + ) def resume_graph( @@ -96,6 +108,8 @@ def resume_graph( executor: ControllerExecutor, n_iterations: int = 10000, polling_interval_seconds: float = 0.01, + *, + enable_breakpoints: bool = False, ) -> None: """Resume a graph after initial start. @@ -113,16 +127,21 @@ def resume_graph( :type n_iterations: int, optional :param polling_interval_seconds: The polling interval in seconds, defaults to 0.01 :type polling_interval_seconds: float, optional + :param enable_breakpoints: Whether to enable breakpoint nodes, defaults to False + :type enable_breakpoints: bool, optional :raises TierkreisError: If the graph encounters errors during execution. """ message = storage.read_output(Loc().N(-1), "body") graph = ptype_from_bytes(message, GraphData) for _ in range(n_iterations): - walk_results = walk_node(storage, Loc(), graph.output_idx(), graph) + walk_results = None + if enable_breakpoints: + walk_results = storage.read_breakpoints() + if walk_results is None: + walk_results = walk_node(storage, Loc(), graph.output_idx(), graph) if walk_results.errored != []: - # TODO: add to base class after storage refactor - (storage.logs_path.parent / "-" / "_error").touch() + storage.touch(storage._error_path(Loc())) node_errors = "\n".join(x for x in walk_results.errored) storage.write_node_errors(Loc(), node_errors) @@ -139,6 +158,10 @@ def resume_graph( msg = "Graph encountered errors" raise TierkreisError(msg) + if enable_breakpoints and (walk_results.breaks is not None): + storage.write_breakpoint(walk_results.breaks, walk_results) + break + start_nodes(storage, executor, walk_results.inputs_ready) if storage.is_node_finished(Loc()): storage.write_workflow_completion_time() diff --git a/tierkreis/tierkreis/controller/data/graph.py b/tierkreis/tierkreis/controller/data/graph.py index 690de2735..98be830ab 100644 --- a/tierkreis/tierkreis/controller/data/graph.py +++ b/tierkreis/tierkreis/controller/data/graph.py @@ -24,6 +24,13 @@ logger = logging.getLogger(__name__) +@dataclass +class NodeMetaData: + """Metadata stored for each node""" + + has_breakpoint: bool = False + + @dataclass class NodeDefBase: """Map each out-port to the list of nodes that use it.""" @@ -285,31 +292,37 @@ class GraphData(BaseModel): graph_inputs: set[PortID] = set() graph_output_idx: NodeIndex | None = None named_nodes: dict[str, NodeIndex] = {} + node_metadata: dict[NodeIndex, NodeMetaData] = {} - def input(self, name: str) -> ValueRef: + def input(self, name: str, metadata: NodeMetaData | None = None) -> ValueRef: """Add an input name. :param name: The name of the input. :type name: str + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The reference to that value. :rtype: ValueRef """ - return self.add(Input(name))(name) + return self.add(Input(name), metadata)(name) - def const(self, value: PType) -> ValueRef: + def const(self, value: PType, metadata: NodeMetaData | None = None) -> ValueRef: """Add a constant value. :param value: The value to add. :type value: PType + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: The reference to that value. :rtype: ValueRef """ - return self.add(Const(value))("value") + return self.add(Const(value), metadata)("value") def func( self, function_name: str, inputs: dict[PortID, ValueRef], + metadata: NodeMetaData | None = None, ) -> Callable[[PortID], ValueRef]: """Add a function node (task). @@ -317,15 +330,18 @@ def func( :type function_name: str :param inputs: The mapping of the input values. :type inputs: dict[PortID, ValueRef] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: A function returning index given an output. :rtype: Callable[[PortID], ValueRef] """ - return self.add(Func(function_name, inputs)) + return self.add(Func(function_name, inputs), metadata) def eval( self, graph: ValueRef, inputs: dict[PortID, ValueRef], + metadata: NodeMetaData | None = None, ) -> Callable[[PortID], ValueRef]: """Add an eval node. @@ -333,10 +349,12 @@ def eval( :type graph: ValueRef :param inputs: The mapping of the input values. :type inputs: dict[PortID, ValueRef] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: A function returning index given an output. :rtype: Callable[[PortID], ValueRef] """ - return self.add(Eval(graph, inputs)) + return self.add(Eval(graph, inputs), metadata) def loop( self, @@ -344,6 +362,7 @@ def loop( inputs: dict[PortID, ValueRef], continue_port: PortID, name: str | None = None, + metadata: NodeMetaData | None = None, ) -> Callable[[PortID], ValueRef]: """Add a loop node. @@ -355,15 +374,18 @@ def loop( :type continue_port: PortID :param name: Name of the loop for tracing, defaults to None :type name: str | None, optional + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: A function returning index given an output. :rtype: Callable[[PortID], ValueRef] """ - return self.add(Loop(body, inputs, continue_port, name=name)) + return self.add(Loop(body, inputs, continue_port, name=name), metadata) def map( self, body: ValueRef, inputs: dict[PortID, ValueRef], + metadata: NodeMetaData | None = None, ) -> Callable[[PortID], ValueRef]: """Add a map node. @@ -371,16 +393,19 @@ def map( :type body: ValueRef :param inputs: The mapping of the input values. :type inputs: dict[PortID, ValueRef] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: A function returning index given an output. :rtype: Callable[[PortID], ValueRef] """ - return self.add(Map(body, inputs)) + return self.add(Map(body, inputs), metadata) def if_else( self, pred: ValueRef, if_true: ValueRef, if_false: ValueRef, + metadata: NodeMetaData | None = None, ) -> Callable[[PortID], ValueRef]: """Add an lazy if else node. @@ -390,16 +415,19 @@ def if_else( :type if_true: ValueRef :param if_false: The graph/value for the false branch. :type if_false: ValueRef + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: A function returning index given an output. :rtype: Callable[[PortID], ValueRef] """ - return self.add(IfElse(pred, if_true, if_false)) + return self.add(IfElse(pred, if_true, if_false), metadata) def eager_if_else( self, pred: ValueRef, if_true: ValueRef, if_false: ValueRef, + metadata: NodeMetaData | None = None, ) -> Callable[[PortID], ValueRef]: """Add an eager if else node. @@ -409,26 +437,37 @@ def eager_if_else( :type if_true: ValueRef :param if_false: The graph/value for the false branch. :type if_false: ValueRef + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :return: A function returning index given an output. :rtype: Callable[[PortID], ValueRef] """ - return self.add(EagerIfElse(pred, if_true, if_false)) + return self.add(EagerIfElse(pred, if_true, if_false), metadata) - def output(self, inputs: dict[PortID, ValueRef]) -> None: + def output( + self, inputs: dict[PortID, ValueRef], metadata: NodeMetaData | None = None + ) -> None: """Add an output node. Computation -> output. :param inputs: The inputs of the outup node. :type inputs: dict[PortID, ValueRef] + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional """ - _ = self.add(Output(inputs)) - def add(self, node: NodeDef) -> Callable[[PortID], ValueRef]: + _ = self.add(Output(inputs), metadata) + + def add( + self, node: NodeDef, metadata: NodeMetaData | None = None + ) -> Callable[[PortID], ValueRef]: """Add a node to the graph. :param node: The node to add. :type node: NodeDef + :param metadata: Optional metadata for the node, defaults to None + :type metadata: NodeMetaData | None, optional :raises TierkreisError: If multiple outputs are added. :return: A function given the output name of a node returns the index of the node it corresponds to. @@ -454,6 +493,8 @@ def add(self, node: NodeDef) -> Callable[[PortID], ValueRef]: self.named_nodes[node.name] = idx case _: assert_never(node) + if metadata is not None: + self.node_metadata[idx] = metadata for i, port in in_edges(node).values(): self.nodes[i].outputs.setdefault(port, []).append(idx) diff --git a/tierkreis/tierkreis/controller/executor/check_launcher.py b/tierkreis/tierkreis/controller/executor/check_launcher.py index c99930f1a..78e510510 100644 --- a/tierkreis/tierkreis/controller/executor/check_launcher.py +++ b/tierkreis/tierkreis/controller/executor/check_launcher.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Literal +from tierkreis.consts import PACKAGE_PATH from tierkreis.exceptions import TierkreisError logger = logging.getLogger(__name__) @@ -41,6 +42,8 @@ def check_and_set_launcher( :return: The full path to the worker executable. :rtype: Path """ + if launcher_name == "builtins": + return Path(PACKAGE_PATH / "tierkreis" / "builtins" / "main.py") if check_shell or launcher_path is None: if (path := shutil.which(launcher_name)) is not None: return Path(path) diff --git a/tierkreis/tierkreis/controller/executor/in_memory_executor.py b/tierkreis/tierkreis/controller/executor/in_memory_executor.py index 889114deb..799770b5e 100644 --- a/tierkreis/tierkreis/controller/executor/in_memory_executor.py +++ b/tierkreis/tierkreis/controller/executor/in_memory_executor.py @@ -1,6 +1,7 @@ """In memory implementation.""" # ruff: noqa: D102 (class methods inherited from ControllerExecutor) +from importlib.machinery import ModuleSpec import importlib.util import json import logging @@ -42,8 +43,10 @@ def run( call_args = WorkerCallArgs( **json.loads(self.storage.read(worker_call_args_path)), ) - launcher_path = check_and_set_launcher(self.registry_path, launcher_name, ".py") - spec = importlib.util.spec_from_file_location("in_memory", launcher_path) + launcher_path = check_and_set_launcher( + self.registry_path, launcher_name, ".py", check_shell=True + ) + spec = _find_spec(launcher_path) if spec is None or spec.loader is None: msg = ( f"Couldn't load module main.py in {self.registry_path / launcher_name}" @@ -64,3 +67,13 @@ def _generate_debug_data(self, launcher_path: Path) -> ExecutorDebugData: executor=str(__class__), launch_command=str(launcher_path), ) + + +def _find_spec(launcher_path: Path) -> ModuleSpec | None: + if launcher_path.suffix == ".py": + return importlib.util.spec_from_file_location("in_memory", launcher_path) + launcher = launcher_path.name + if not launcher.startswith("tkr"): + launcher = "tkr-" + launcher + + return importlib.util.find_spec((launcher + "_impl").replace("-", "_")) diff --git a/tierkreis/tierkreis/controller/start.py b/tierkreis/tierkreis/controller/start.py index a1ebbe8e1..9c4b45b80 100644 --- a/tierkreis/tierkreis/controller/start.py +++ b/tierkreis/tierkreis/controller/start.py @@ -3,13 +3,12 @@ import logging import subprocess import sys -from dataclasses import dataclass from pathlib import Path from typing import assert_never from tierkreis.consts import PACKAGE_PATH from tierkreis.controller.data.core import PortID -from tierkreis.controller.data.graph import Eval, GraphData, NodeDef +from tierkreis.controller.data.graph import Eval, GraphData from tierkreis.controller.data.location import Loc, OutputLoc from tierkreis.controller.data.types import bytes_from_ptype, ptype_from_bytes from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor @@ -18,27 +17,13 @@ from tierkreis.controller.storage.data import ExecutorDebugData from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage from tierkreis.controller.storage.protocol import ControllerStorage +from tierkreis.controller.storage.walk_result import NodeRunData from tierkreis.exceptions import TierkreisError from tierkreis.labels import Labels logger = logging.getLogger(__name__) -@dataclass -class NodeRunData: - """Data required to run a node. - - :fields: - node_location (Loc): The location of the node to run. - node (NodeDef): The node definition to run. - output_list (list[PortID]): The list of output port ids for the node. - """ - - node_location: Loc - node: NodeDef - output_list: list[PortID] - - def start_nodes( storage: ControllerStorage, executor: ControllerExecutor, @@ -179,13 +164,13 @@ def start( storage, executor, NodeRunData( - node_location.L(0), - Eval( + node_location=node_location.L(0), + node=Eval( (-1, "body"), {k: (-1, k) for k, _ in ins.items()}, outputs=node.outputs, ), - output_list, + output_list=output_list, ), ) diff --git a/tierkreis/tierkreis/controller/storage/debug_graph.py b/tierkreis/tierkreis/controller/storage/debug_graph.py new file mode 100644 index 000000000..96a4c57d0 --- /dev/null +++ b/tierkreis/tierkreis/controller/storage/debug_graph.py @@ -0,0 +1,49 @@ +"""Debug wrapper for the main controller.""" + +from uuid import UUID + +from tierkreis.builder import GraphBuilder +from tierkreis.consts import WORKERS_DIR +from tierkreis.controller import run_graph +from tierkreis.controller.data.graph import GraphData +from tierkreis.controller.data.models import TModel +from tierkreis.controller.data.types import PType +from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor +from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage + + +def debug_graph[A: TModel, B: TModel]( + g: GraphData | GraphBuilder[A, B], + graph_inputs: dict[str, PType] | PType, + n_iterations: int = 10000, + polling_interval_seconds: float = 0.01, +) -> None: + """Start a graph execution in debugging mode,. + + Uses debugging mode by setting up in memory storage and executor. + This also means that only python code can be run. + All workers must be available locally, too. + + :param g: The graph to run. + :type g: GraphData | GraphBuilder[A, B] + :param graph_inputs: The inputs to the graph. + If a single PType is provided, it will be provided as the input "value". + :type graph_inputs: dict[str, PType] | PType + :param n_iterations: The maximum number of iterations to run the graph, + defaults to 10000 + :type n_iterations: int, optional + :param polling_interval_seconds: The polling interval in seconds, defaults to 0.01 + :type polling_interval_seconds: float, optional + """ + storage = ControllerInMemoryStorage(UUID(int=0)) + executor = InMemoryExecutor(WORKERS_DIR, storage) + run_graph( + storage, + executor, + g, + graph_inputs, + n_iterations, + polling_interval_seconds, + enable_logging=True, + enable_breakpoints=True, + ) diff --git a/tierkreis/tierkreis/controller/storage/protocol.py b/tierkreis/tierkreis/controller/storage/protocol.py index 372f2e3cc..15e470664 100644 --- a/tierkreis/tierkreis/controller/storage/protocol.py +++ b/tierkreis/tierkreis/controller/storage/protocol.py @@ -20,6 +20,7 @@ WorkflowMetaData, ) from tierkreis.controller.storage.exceptions import EntryNotFoundError +from tierkreis.controller.storage.walk_result import WalkResult from tierkreis.exceptions import TierkreisError logger = logging.getLogger(__name__) @@ -187,6 +188,9 @@ def logs_path(self) -> Path: # noqa: D102 documented in class def debug_path(self) -> Path: # noqa: D102 documented in class return self.workflow_dir / "debug" + def _breakpoint(self, node_location: Loc) -> Path: + return self.debug_path / "breakpoints" / str(node_location) + def _exec_data_path(self, node_location: Loc) -> Path: return self.debug_path / "executors" / str(node_location) @@ -737,3 +741,25 @@ def restart_task(self, loc: Loc) -> list[Loc]: self.delete(self._nodedef_path(loc)) return list(deps) + + def write_breakpoint(self, node_location: Loc, walk_result: WalkResult) -> None: + walk_result.breaks = None + self.write( + self._breakpoint(node_location), walk_result.model_dump_json().encode() + ) + + def read_breakpoints(self) -> WalkResult | None: + result = WalkResult(inputs_ready=[], started=[]) + if not self._breakpoint(Loc()).parent.exists(): + return None + for path in self.list_subpaths(self._breakpoint(Loc()).parent): + if path.is_dir(): + msg = f"Breakpoint directory contains dir at {path}" + raise TierkreisError(msg) + data = json.loads(self.read(path)) + result.extend(WalkResult(**data)) + path.unlink() + self._breakpoint(Loc()).parent.rmdir() + if result == WalkResult(inputs_ready=[], started=[]): + return None + return result diff --git a/tierkreis/tierkreis/controller/storage/walk.py b/tierkreis/tierkreis/controller/storage/walk.py index dfe9f4ae1..e44a110ab 100644 --- a/tierkreis/tierkreis/controller/storage/walk.py +++ b/tierkreis/tierkreis/controller/storage/walk.py @@ -6,7 +6,6 @@ they can be started. """ -from dataclasses import dataclass, field from logging import getLogger from typing import assert_never @@ -22,42 +21,14 @@ ) from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import ptype_from_bytes -from tierkreis.controller.start import NodeRunData from tierkreis.controller.storage.adjacency import outputs_iter, unfinished_inputs from tierkreis.controller.storage.protocol import ControllerStorage +from tierkreis.controller.storage.walk_result import WalkResult, NodeRunData from tierkreis.labels import Labels logger = getLogger(__name__) -@dataclass -class WalkResult: - """Dataclass to keep track of the nodes we encounter during the walk. - - :fields: - inputs_ready (list[NodeRunData]): A list of nodes that now have all inputs ready - and therefore can be started. - started (list[Loc]): A list of locations that have been started (on this walk). - errored (list[Loc]): A list of locations that have encountered an error. - """ - - inputs_ready: list[NodeRunData] - started: list[Loc] - errored: list[Loc] = field(default_factory=list[Loc]) - - def extend(self, walk_result: "WalkResult") -> None: - """Extend a current walk result with an existing one. - - Simply extends all three list fields accordingly. - - :param walk_result: The walk_result to update self with. - :type walk_result: WalkResult - """ - self.inputs_ready.extend(walk_result.inputs_ready) - self.started.extend(walk_result.started) - self.errored.extend(walk_result.errored) - - def unfinished_results( result: WalkResult, storage: ControllerStorage, @@ -125,19 +96,23 @@ def walk_node( # we immediately stop if a node has an error and bubble the error up logger.error("Node %s has encountered an error.", loc) logger.debug("\n\n%s\n\n", storage.read_errors(loc)) - return WalkResult([], [], [loc]) + return WalkResult(inputs_ready=[], started=[], errored=[loc]) node = graph.nodes[idx] - node_run_data = NodeRunData(loc, node, list(node.outputs)) + node_run_data = NodeRunData( + node_location=loc, node=node, output_list=list(node.outputs) + ) - result = WalkResult([], []) + result = WalkResult(inputs_ready=[], started=[]) if unfinished_results(result, storage, parent, node, graph): # cannot start, don't have all inputs yet return result if not storage.is_node_started(loc): # have all inputs, start current node - return WalkResult([node_run_data], []) + if idx in graph.node_metadata and graph.node_metadata[idx].has_breakpoint: + return WalkResult(inputs_ready=[node_run_data], started=[], breaks=loc) + return WalkResult(inputs_ready=[node_run_data], started=[]) # Handle cases where we have nested graphs. # Basically we have to forward the now available outputs from outer scope @@ -151,10 +126,10 @@ def walk_node( return walk_node(storage, loc, g.output_idx(), g) case "output": - return WalkResult([node_run_data], []) + return WalkResult(inputs_ready=[node_run_data], started=[]) case "const": - return WalkResult([node_run_data], []) + return WalkResult(inputs_ready=[node_run_data], started=[]) case "loop": return walk_loop(storage, parent, idx, node) @@ -170,7 +145,7 @@ def walk_node( if storage.is_node_finished(next_loc): storage.link_outputs(loc, Labels.VALUE, next_loc, next_node[1]) storage.mark_node_finished(loc) - return WalkResult([], []) + return WalkResult(inputs_ready=[], started=[]) return walk_node(storage, parent, next_node[0], graph) case "eifelse": @@ -178,10 +153,10 @@ def walk_node( case "function": # Current node can start, done will be marked by executor. - return WalkResult([], [loc]) + return WalkResult(inputs_ready=[], started=[loc]) case "input": - return WalkResult([], []) + return WalkResult(inputs_ready=[], started=[]) case _: assert_never(node) @@ -219,7 +194,7 @@ def walk_loop( """ loc = parent.N(idx) if storage.is_node_finished(loc): - return WalkResult([], []) + return WalkResult(inputs_ready=[], started=[]) # find the last iteration new_location = storage.latest_loop_iteration(loc) @@ -241,18 +216,18 @@ def walk_loop( for k in loop_outputs: storage.link_outputs(loc, k, new_location, k) storage.mark_node_finished(loc) - return WalkResult([], []) + return WalkResult(inputs_ready=[], started=[]) # continue looping, provide the inputs for the next iter from the current ins = {k: (-1, k) for k in loop.inputs} ins.update(loop_outputs) # Mark the next node as ready node_run_data = NodeRunData( - loc.L(new_location.peek_index() + 1), - Eval((-1, BODY_PORT), ins, outputs=loop.outputs), - list(loop_outputs.keys()), + node_location=loc.L(new_location.peek_index() + 1), + node=Eval((-1, BODY_PORT), ins, outputs=loop.outputs), + output_list=list(loop_outputs.keys()), ) - return WalkResult([node_run_data], []) + return WalkResult(inputs_ready=[node_run_data], started=[]) def walk_map( @@ -285,7 +260,7 @@ def walk_map( :rtype: WalkResult """ loc = parent.N(idx) - result = WalkResult([], []) + result = WalkResult(inputs_ready=[], started=[]) if storage.is_node_finished(loc): return result @@ -343,4 +318,4 @@ def walk_eifelse( storage.link_outputs(loc, Labels.VALUE, next_loc, next_node[1]) storage.mark_node_finished(loc) - return WalkResult([], []) + return WalkResult(inputs_ready=[], started=[]) diff --git a/tierkreis/tierkreis/controller/storage/walk_result.py b/tierkreis/tierkreis/controller/storage/walk_result.py new file mode 100644 index 000000000..a983b8d74 --- /dev/null +++ b/tierkreis/tierkreis/controller/storage/walk_result.py @@ -0,0 +1,49 @@ +from pydantic import BaseModel, Field + +from tierkreis.controller.data.core import PortID +from tierkreis.controller.data.graph import NodeDef +from tierkreis.controller.data.location import Loc + + +class NodeRunData(BaseModel): + """Data required to run a node. + + :fields: + node_location (Loc): The location of the node to run. + node (NodeDef): The node definition to run. + output_list (list[PortID]): The list of output port ids for the node. + """ + + node_location: Loc + node: NodeDef + output_list: list[PortID] + + +class WalkResult(BaseModel): + """Dataclass to keep track of the nodes we encounter during the walk. + + :fields: + inputs_ready (list[NodeRunData]): A list of nodes that now have all inputs ready + and therefore can be started. + started (list[Loc]): A list of locations that have been started (on this walk). + errored (list[Loc]): A list of locations that have encountered an error. + """ + + inputs_ready: list[NodeRunData] + started: list[Loc] + errored: list[Loc] = Field(default_factory=list[Loc]) + breaks: Loc | None = None + + def extend(self, walk_result: "WalkResult") -> None: + """Extend a current walk result with an existing one. + + Simply extends all three list fields accordingly. + + :param walk_result: The walk_result to update self with. + :type walk_result: WalkResult + """ + self.inputs_ready.extend(walk_result.inputs_ready) + self.started.extend(walk_result.started) + self.errored.extend(walk_result.errored) + if self.breaks is None: + self.breaks = walk_result.breaks diff --git a/tierkreis_visualization/tierkreis_visualization/data/eval.py b/tierkreis_visualization/tierkreis_visualization/data/eval.py index 19bb5dc6f..97ea947e2 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/eval.py +++ b/tierkreis_visualization/tierkreis_visualization/data/eval.py @@ -2,10 +2,9 @@ from typing import assert_never from tierkreis.controller.data.core import NodeIndex -from tierkreis.controller.data.graph import GraphData, IfElse +from tierkreis.controller.data.graph import GraphData, IfElse, in_edges from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import ptype_from_bytes -from tierkreis.controller.storage.adjacency import in_edges from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.exceptions import TierkreisError from tierkreis_visualization.data.models import NodeStatus, PyEdge, PyNode @@ -73,7 +72,6 @@ def get_eval_node( pynodes: list[PyNode] = [] py_edges: list[PyEdge] = [] - for i, node in enumerate(graph.nodes): new_location = node_location.N(i) @@ -120,8 +118,6 @@ def get_eval_node( pynodes.append(pynode) for p0, (idx, p1) in in_edges(node).items(): - value: str | None = None - try: value = outputs_from_loc(storage, node_location.N(idx), p1) except (FileNotFoundError, TierkreisError, UnicodeDecodeError):