From ba610fc680e377ef9b272311702208b9ae181542 Mon Sep 17 00:00:00 2001 From: Philipp Seitz Date: Fri, 23 Jan 2026 15:03:45 +0000 Subject: [PATCH] Add experimental graph operations --- example.py | 64 ++++++++++ tierkreis/tierkreis/controller/data/graph.py | 118 +++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 example.py diff --git a/example.py b/example.py new file mode 100644 index 000000000..c731468a1 --- /dev/null +++ b/example.py @@ -0,0 +1,64 @@ +from pathlib import Path +from uuid import UUID +from tierkreis.builder import GraphBuilder +from tierkreis.controller import run_graph +from tierkreis.controller.data.models import TKR +from tierkreis.controller.executor.uv_executor import UvExecutor +from tierkreis.storage import FileStorage +from tierkreis.storage import read_outputs + + +from tierkreis.builtins.stubs import tkr_str + + +def simple_graph() -> GraphBuilder[TKR[int], TKR[str]]: + g = GraphBuilder(TKR[int], TKR[str]) + a = g.task(tkr_str(g.inputs)) + g.outputs(a) + return g + + +def main() -> None: + graph_a = simple_graph().data + graph_b = simple_graph().data + graph = graph_a + graph_b + + inputs = {"value": 0} + storage = FileStorage( + UUID(int=222), + name="serial_graph", + ) + executor = UvExecutor( + Path(__file__).parent / "tierkreis_workers", storage.logs_path + ) + storage.clean_graph_files() + run_graph( + storage, + executor, + graph, + inputs, + ) + res = read_outputs(graph, storage) + print(res) + graph = graph_a @ graph_b + storage = FileStorage( + UUID(int=223), + name="parallel_graph", + ) + executor = UvExecutor( + Path(__file__).parent / "tierkreis_workers", storage.logs_path + ) + storage.clean_graph_files() + run_graph( + storage, + executor, + graph, + inputs, + ) + res = read_outputs(graph, storage) + print(res) + + +if __name__ == "__main__": + main() + print("All Done") diff --git a/tierkreis/tierkreis/controller/data/graph.py b/tierkreis/tierkreis/controller/data/graph.py index 624dbd572..d593634e1 100644 --- a/tierkreis/tierkreis/controller/data/graph.py +++ b/tierkreis/tierkreis/controller/data/graph.py @@ -101,6 +101,124 @@ class GraphData(BaseModel): graph_output_idx: NodeIndex | None = None named_nodes: dict[str, NodeIndex] = {} + def _find_input_node(self, name: str) -> NodeIndex: + for idx, node in enumerate(self.nodes): + if node.type == "input" and node.name == name: + return idx + raise TierkreisError(f"Input node with name {name} not found.") + + def __matmul__(self, other: "GraphData") -> "GraphData": + output = self.nodes.pop(self.output_idx()) + other_output = other.nodes.pop(other.output_idx()) + new_fixed_inputs = { + port: loc + for port, loc in {**self.fixed_inputs, **other.fixed_inputs}.items() + } + new_graph_inputs = self.graph_inputs.union(other.graph_inputs) + input_mapping = { + other._find_input_node(node.name): self._find_input_node(node.name) + for node in other.nodes + if node.type == "input" + } + offset = len(self.nodes) + new_nodes = self.nodes.copy() + for node in other.nodes: + # Adjust inputs + if node.type == "input" and node.name in self.graph_inputs: + offset -= 1 + continue + + new_inputs = { + port: (idx + offset, p) + if idx not in input_mapping + else (input_mapping[idx], p) + for port, (idx, p) in node.inputs.items() + } + # Adjust outputs + new_outputs = {port: idx + offset for port, idx in node.outputs.items()} + new_node = node.__class__( + **{ + **node.__dict__, + "inputs": new_inputs, + "outputs": new_outputs, + } + ) + new_nodes.append(new_node) + + new_named_nodes = { + name: idx + offset for name, idx in other.named_nodes.items() + } + new_named_nodes.update(self.named_nodes) + ## restore self and other + self.nodes.append(output) + other.nodes.append(other_output) + graph = GraphData( + nodes=new_nodes, + fixed_inputs=new_fixed_inputs, + graph_inputs=new_graph_inputs, + named_nodes=new_named_nodes, + ) + new_output_inputs = { + port: (idx, p) for port, (idx, p) in output.inputs.items() + } | { + f"{port}_b": (idx + offset, p) + for port, (idx, p) in other_output.inputs.items() + } + graph.output(new_output_inputs) + return graph + + def __add__(self, other: "GraphData") -> "GraphData": + output = self.nodes.pop(self.output_idx()) + new_fixed_inputs = { + port: loc + for port, loc in {**self.fixed_inputs, **other.fixed_inputs}.items() + } + new_graph_inputs = self.graph_inputs.copy() + output_mapping = { + other._find_input_node(node.name): output.inputs[node.name] + for node in other.nodes + if node.type == "input" and node.name in output.inputs + } + offset = len(self.nodes) + new_nodes = self.nodes.copy() + for node in other.nodes: + # Adjust inputs + if node.type == "input" and node.name in output.inputs: + offset -= 1 + continue + + new_inputs = { + port: (idx + offset, p) + if idx not in output_mapping + else output_mapping[idx] + for port, (idx, p) in node.inputs.items() + } + # Adjust outputs + new_outputs = {port: idx + offset for port, idx in node.outputs.items()} + new_node = node.__class__( + **{ + **node.__dict__, + "inputs": new_inputs, + "outputs": new_outputs, + } + ) + new_nodes.append(new_node) + + new_named_nodes = { + name: idx + offset for name, idx in other.named_nodes.items() + } + new_named_nodes.update(self.named_nodes) + ## restore self and other + self.nodes.append(output) + graph = GraphData( + nodes=new_nodes, + fixed_inputs=new_fixed_inputs, + graph_inputs=new_graph_inputs, + named_nodes=new_named_nodes, + graph_output_idx=len(new_nodes) - 1, + ) + return graph + def input(self, name: str) -> ValueRef: return self.add(Input(name))(name)