Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pathlib import Path
from uuid import UUID
from tierkreis.builder import GraphBuilder
from tierkreis.controller import run_graph
from tierkreis.controller.data.models import TKR
from tierkreis.controller.executor.uv_executor import UvExecutor
from tierkreis.storage import FileStorage
from tierkreis.storage import read_outputs


from tierkreis.builtins.stubs import tkr_str


def simple_graph() -> GraphBuilder[TKR[int], TKR[str]]:
g = GraphBuilder(TKR[int], TKR[str])
a = g.task(tkr_str(g.inputs))
g.outputs(a)
return g


def main() -> None:
graph_a = simple_graph().data
graph_b = simple_graph().data
graph = graph_a + graph_b

inputs = {"value": 0}
storage = FileStorage(
UUID(int=222),
name="serial_graph",
)
executor = UvExecutor(
Path(__file__).parent / "tierkreis_workers", storage.logs_path
)
storage.clean_graph_files()
run_graph(
storage,
executor,
graph,
inputs,
)
res = read_outputs(graph, storage)
print(res)
graph = graph_a @ graph_b
storage = FileStorage(
UUID(int=223),
name="parallel_graph",
)
executor = UvExecutor(
Path(__file__).parent / "tierkreis_workers", storage.logs_path
)
storage.clean_graph_files()
run_graph(
storage,
executor,
graph,
inputs,
)
res = read_outputs(graph, storage)
print(res)


if __name__ == "__main__":
main()
print("All Done")
118 changes: 118 additions & 0 deletions tierkreis/tierkreis/controller/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,124 @@ class GraphData(BaseModel):
graph_output_idx: NodeIndex | None = None
named_nodes: dict[str, NodeIndex] = {}

def _find_input_node(self, name: str) -> NodeIndex:
for idx, node in enumerate(self.nodes):
if node.type == "input" and node.name == name:
return idx
raise TierkreisError(f"Input node with name {name} not found.")

def __matmul__(self, other: "GraphData") -> "GraphData":
output = self.nodes.pop(self.output_idx())
other_output = other.nodes.pop(other.output_idx())
new_fixed_inputs = {
port: loc
for port, loc in {**self.fixed_inputs, **other.fixed_inputs}.items()
}
new_graph_inputs = self.graph_inputs.union(other.graph_inputs)
input_mapping = {
other._find_input_node(node.name): self._find_input_node(node.name)
for node in other.nodes
if node.type == "input"
}
offset = len(self.nodes)
new_nodes = self.nodes.copy()
for node in other.nodes:
# Adjust inputs
if node.type == "input" and node.name in self.graph_inputs:
offset -= 1
continue

new_inputs = {
port: (idx + offset, p)
if idx not in input_mapping
else (input_mapping[idx], p)
for port, (idx, p) in node.inputs.items()
}
# Adjust outputs
new_outputs = {port: idx + offset for port, idx in node.outputs.items()}
new_node = node.__class__(
**{
**node.__dict__,
"inputs": new_inputs,
"outputs": new_outputs,
}
)
new_nodes.append(new_node)

new_named_nodes = {
name: idx + offset for name, idx in other.named_nodes.items()
}
new_named_nodes.update(self.named_nodes)
## restore self and other
self.nodes.append(output)
other.nodes.append(other_output)
graph = GraphData(
nodes=new_nodes,
fixed_inputs=new_fixed_inputs,
graph_inputs=new_graph_inputs,
named_nodes=new_named_nodes,
)
new_output_inputs = {
port: (idx, p) for port, (idx, p) in output.inputs.items()
} | {
f"{port}_b": (idx + offset, p)
for port, (idx, p) in other_output.inputs.items()
}
graph.output(new_output_inputs)
return graph

def __add__(self, other: "GraphData") -> "GraphData":
output = self.nodes.pop(self.output_idx())
new_fixed_inputs = {
port: loc
for port, loc in {**self.fixed_inputs, **other.fixed_inputs}.items()
}
new_graph_inputs = self.graph_inputs.copy()
output_mapping = {
other._find_input_node(node.name): output.inputs[node.name]
for node in other.nodes
if node.type == "input" and node.name in output.inputs
}
offset = len(self.nodes)
new_nodes = self.nodes.copy()
for node in other.nodes:
# Adjust inputs
if node.type == "input" and node.name in output.inputs:
offset -= 1
continue

new_inputs = {
port: (idx + offset, p)
if idx not in output_mapping
else output_mapping[idx]
for port, (idx, p) in node.inputs.items()
}
# Adjust outputs
new_outputs = {port: idx + offset for port, idx in node.outputs.items()}
new_node = node.__class__(
**{
**node.__dict__,
"inputs": new_inputs,
"outputs": new_outputs,
}
)
new_nodes.append(new_node)

new_named_nodes = {
name: idx + offset for name, idx in other.named_nodes.items()
}
new_named_nodes.update(self.named_nodes)
## restore self and other
self.nodes.append(output)
graph = GraphData(
nodes=new_nodes,
fixed_inputs=new_fixed_inputs,
graph_inputs=new_graph_inputs,
named_nodes=new_named_nodes,
graph_output_idx=len(new_nodes) - 1,
)
return graph

def input(self, name: str) -> ValueRef:
return self.add(Input(name))(name)

Expand Down