diff --git a/docs/source/examples/data/typed_eval.py b/docs/source/examples/data/typed_eval.py index 69e497861..e6a05decb 100644 --- a/docs/source/examples/data/typed_eval.py +++ b/docs/source/examples/data/typed_eval.py @@ -1,7 +1,8 @@ from typing import NamedTuple + +from tierkreis.builder import GraphBuilder from tierkreis.builtins.stubs import iadd, itimes from tierkreis.controller.data.core import EmptyModel -from tierkreis.builder import GraphBuilder from tierkreis.controller.data.models import TKR diff --git a/docs/source/examples/errors_and_debugging.ipynb b/docs/source/examples/errors_and_debugging.ipynb index 5701edcb8..9f0eaac05 100644 --- a/docs/source/examples/errors_and_debugging.ipynb +++ b/docs/source/examples/errors_and_debugging.ipynb @@ -27,12 +27,12 @@ "metadata": {}, "outputs": [], "source": [ + "from example_workers.error_worker.api.stubs import fail\n", + "\n", "from tierkreis.builder import GraphBuilder\n", "from tierkreis.controller.data.core import EmptyModel\n", "from tierkreis.controller.data.models import TKR\n", "\n", - "from example_workers.error_worker.api.stubs import fail\n", - "\n", "\n", "def error_graph() -> GraphBuilder:\n", " g = GraphBuilder(EmptyModel, TKR[str])\n", @@ -60,8 +60,8 @@ "from uuid import UUID\n", "\n", "from tierkreis.controller import run_graph\n", - "from tierkreis.controller.storage.filestorage import ControllerFileStorage\n", "from tierkreis.controller.executor.uv_executor import UvExecutor\n", + "from tierkreis.controller.storage.filestorage import ControllerFileStorage\n", "from tierkreis.exceptions import TierkreisError\n", "\n", "workflow_id = UUID(int=103)\n", @@ -69,7 +69,6 @@ "\n", "registry_path = Path().parent / \"example_workers\"\n", "executor = UvExecutor(registry_path=registry_path, logs_path=storage.logs_path)\n", - "print(\"Starting workflow at location:\", storage.logs_path)\n", "try:\n", " run_graph(\n", " storage,\n", @@ -79,8 +78,7 @@ " polling_interval_seconds=0.1,\n", " )\n", "except 
TierkreisError: # we will catch this here\n", - " output = storage.read_errors()\n", - " print(\"Errors are at:\", output)" + " output = storage.read_errors()" ] }, { @@ -105,6 +103,7 @@ "metadata": {}, "outputs": [], "source": [ + "import contextlib\n", "import logging\n", "\n", "logging.basicConfig(\n", @@ -114,16 +113,14 @@ ")\n", "\n", "storage.clean_graph_files()\n", - "try:\n", + "with contextlib.suppress(TierkreisError):\n", " run_graph(\n", " storage,\n", " executor,\n", " error_graph(),\n", " {\"value\": \"world!\"},\n", " polling_interval_seconds=0.1,\n", - " )\n", - "except TierkreisError:\n", - " pass" + " )" ] }, { @@ -143,8 +140,8 @@ "metadata": {}, "outputs": [], "source": [ - "from tierkreis.storage import InMemoryStorage\n", "from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor\n", + "from tierkreis.storage import InMemoryStorage\n", "\n", "storage = InMemoryStorage(UUID(int=103))\n", "executor = InMemoryExecutor(registry_path, storage)\n", diff --git a/docs/source/examples/example_workers/auth_worker/src/main.py b/docs/source/examples/example_workers/auth_worker/src/main.py index a478c504f..010b960eb 100644 --- a/docs/source/examples/example_workers/auth_worker/src/main.py +++ b/docs/source/examples/example_workers/auth_worker/src/main.py @@ -2,16 +2,22 @@ import secrets from sys import argv from time import time -from typing import NamedTuple, cast +from typing import TYPE_CHECKING, NamedTuple, cast -from tierkreis.controller.data.models import portmapping import pyscrypt # type: ignore from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.hashes import SHA256 from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey +from cryptography.hazmat.primitives.hashes import SHA256 + from tierkreis import Worker +from 
tierkreis.controller.data.models import portmapping + +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, + ) worker = Worker("auth_worker") logger = logging.getLogger(__name__) @@ -34,7 +40,7 @@ def encrypt(plaintext: str, work_factor: int) -> EncryptionResult: start_time = time() salt = secrets.token_bytes(32) ciphertext = pyscrypt.hash( # type:ignore - password=plaintext.encode(), salt=salt, N=work_factor, r=1, p=1, dkLen=32 + password=plaintext.encode(), salt=salt, N=work_factor, r=1, p=1, dkLen=32, ) time_taken = time() - start_time @@ -45,13 +51,13 @@ def encrypt(plaintext: str, work_factor: int) -> EncryptionResult: def sign(private_key: bytes, passphrase: bytes, message: str) -> SigningResult: start_time = time() key = cast( - RSAPrivateKey, + "RSAPrivateKey", serialization.load_pem_private_key(private_key, password=passphrase), ) signature = key.sign( message.encode(), padding=padding.PSS( - mgf=padding.MGF1(SHA256()), salt_length=padding.PSS.MAX_LENGTH + mgf=padding.MGF1(SHA256()), salt_length=padding.PSS.MAX_LENGTH, ), algorithm=SHA256(), ).hex() @@ -62,13 +68,13 @@ def sign(private_key: bytes, passphrase: bytes, message: str) -> SigningResult: @worker.task() def verify(public_key: bytes, signature: str, message: str) -> bool: - key = cast(RSAPublicKey, serialization.load_pem_public_key(public_key)) + key = cast("RSAPublicKey", serialization.load_pem_public_key(public_key)) try: key.verify( bytes.fromhex(signature), message.encode(), padding=padding.PSS( - mgf=padding.MGF1(SHA256()), salt_length=padding.PSS.MAX_LENGTH + mgf=padding.MGF1(SHA256()), salt_length=padding.PSS.MAX_LENGTH, ), algorithm=SHA256(), ) diff --git a/docs/source/examples/example_workers/error_worker/src/main.py b/docs/source/examples/example_workers/error_worker/src/main.py index f918b74c3..3ff671daf 100644 --- a/docs/source/examples/example_workers/error_worker/src/main.py +++ 
b/docs/source/examples/example_workers/error_worker/src/main.py @@ -1,5 +1,6 @@ import logging from sys import argv + from tierkreis import Worker logger = logging.getLogger(__name__) @@ -8,7 +9,8 @@ @worker.task() def fail() -> str: - raise Exception("I refuse!") + msg = "I refuse!" + raise Exception(msg) return "I failed to refuse" diff --git a/docs/source/examples/example_workers/hello_world_worker/src/main.py b/docs/source/examples/example_workers/hello_world_worker/src/main.py index 43aac55c5..28b9dcb24 100644 --- a/docs/source/examples/example_workers/hello_world_worker/src/main.py +++ b/docs/source/examples/example_workers/hello_world_worker/src/main.py @@ -1,5 +1,6 @@ import logging from sys import argv + from tierkreis import Worker logger = logging.getLogger(__name__) diff --git a/docs/source/examples/example_workers/qsci_worker/src/chemistry/active_space.py b/docs/source/examples/example_workers/qsci_worker/src/chemistry/active_space.py index e52d00e17..f4372f6e6 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/chemistry/active_space.py +++ b/docs/source/examples/example_workers/qsci_worker/src/chemistry/active_space.py @@ -7,7 +7,8 @@ def get_frozen( n_orbs = len(mo_occ) n_core = get_n_core(mo_occ, n_elecas) if n_core + n_cas > n_orbs: - raise ValueError("active space is larger than basis set") + msg = "active space is larger than basis set" + raise ValueError(msg) for i in range(n_orbs): # print(i, i < n_core, i >= n_core + n_cas) if i < n_core or i >= n_core + n_cas: @@ -20,8 +21,7 @@ def get_n_core( n_elecas: int, ) -> int: n_elec = int(sum(mo_occ)) - n_core = (n_elec - n_elecas) // 2 - return n_core + return (n_elec - n_elecas) // 2 def get_n_active( @@ -30,8 +30,7 @@ def get_n_active( n_elecas: int, ) -> int: n_frozen = len(get_frozen(mo_occ, n_cas, n_elecas)) - n_active = len(mo_occ) - n_frozen - return n_active + return len(mo_occ) - n_frozen def get_n_virtual( @@ -41,5 +40,4 @@ def get_n_virtual( ) -> int: n_frozen = 
len(get_frozen(mo_occ, n_cas, n_elecas)) n_core = get_n_core(mo_occ, n_elecas) - n_virtual = n_frozen - n_core - return n_virtual + return n_frozen - n_core diff --git a/docs/source/examples/example_workers/qsci_worker/src/chemistry/molecule.py b/docs/source/examples/example_workers/qsci_worker/src/chemistry/molecule.py index 56eb3a677..956e0c0c2 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/chemistry/molecule.py +++ b/docs/source/examples/example_workers/qsci_worker/src/chemistry/molecule.py @@ -1,12 +1,14 @@ -from typing import Optional, cast +from typing import TYPE_CHECKING, cast import numpy as np -from numpy.typing import NDArray from pyscf import ao2mo, gto, scf +if TYPE_CHECKING: + from numpy.typing import NDArray + def _extract_hamiltonian_rhf( - mol: gto.Mole, frozen: Optional[list[int]] = None + mol: gto.Mole, frozen: list[int] | None = None, ) -> tuple[float, np.ndarray, np.ndarray]: """Extract the fermionic Hamiltonian from a mean-field calculation. @@ -23,7 +25,7 @@ def _extract_hamiltonian_rhf( mf.kernel() # Get the MOs - mo = cast(NDArray, mf.mo_coeff) + mo = cast("NDArray", mf.mo_coeff) if frozen: mo = np.delete(mo, frozen, axis=1) nmo = mo.shape[1] @@ -55,7 +57,7 @@ def extract_hamiltonian_rhf( basis: str, charge: int = 0, spin: int = 0, - frozen: Optional[list[int]] = None, + frozen: list[int] | None = None, ) -> tuple[float, np.ndarray, np.ndarray]: """Generate the Hamiltonian in a qubit representation. 
diff --git a/docs/source/examples/example_workers/qsci_worker/src/main.py b/docs/source/examples/example_workers/qsci_worker/src/main.py index c1ac237de..0b469e75a 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/main.py +++ b/docs/source/examples/example_workers/qsci_worker/src/main.py @@ -1,8 +1,11 @@ import logging +from collections import Counter from sys import argv -from typing import Counter, NamedTuple, cast +from typing import NamedTuple, cast import numpy as np +from chemistry.active_space import get_frozen +from chemistry.molecule import extract_hamiltonian_rhf from pytket._tket.circuit import Circuit from pytket.backends.backendresult import BackendResult from pytket.circuit import Qubit @@ -13,9 +16,6 @@ from qsci.postprocess import get_ci_matrix, postprocess_configs from qsci.state_prep import perform_state_preparation from qsci.utils import get_config_from_cas_init, make_time_evolution_circuits, rhf2ghf -from chemistry.active_space import get_frozen -from chemistry.molecule import extract_hamiltonian_rhf - from tierkreis import Worker @@ -53,28 +53,28 @@ def state_prep( ) -> Circuit: ham_init_operator = QubitPauliOperator( cast( - dict[QubitPauliString, CoeffTypeAccepted], + "dict[QubitPauliString, CoeffTypeAccepted]", qubit_mapping_jordan_wigner( *rhf2ghf( ham_init.h0, np.array(ham_init.h1), np.array(ham_init.h2), - ) + ), ), - ) + ), ) # time-evolve CASCI ground state. 
n_core_init = get_n_core(mo_occ, cas_init.n_ele) n_core_hsim = get_n_core(mo_occ, cas_hsim.n_ele) n_core = n_core_init - n_core_hsim logging.info( - f"mo_occ={mo_occ} n_cas_hsim={cas_hsim.n} n_elecas_hsim={cas_hsim.n_ele}" + f"mo_occ={mo_occ} n_cas_hsim={cas_hsim.n} n_elecas_hsim={cas_hsim.n_ele}", ) n_active_hsim = get_n_active(mo_occ, cas_hsim.n, cas_hsim.n_ele) prepared_circ = Circuit(n_active_hsim * 2) for i in range(n_core * 2): prepared_circ.X(i) - adapt_circ = perform_state_preparation( + return perform_state_preparation( reference_state=reference_state, ham_init=ham_init_operator, n_cas_init=cas_init.n, @@ -82,7 +82,6 @@ def state_prep( atol=atol, ) - return adapt_circ @worker.task() @@ -98,27 +97,27 @@ def circuits_from_hamiltonians( ) -> list[Circuit]: ham_init_operator = QubitPauliOperator( cast( - dict[QubitPauliString, CoeffTypeAccepted], + "dict[QubitPauliString, CoeffTypeAccepted]", qubit_mapping_jordan_wigner( *rhf2ghf( ham_init.h0, np.array(ham_init.h1), np.array(ham_init.h2), - ) + ), ), - ) + ), ) ham_hsim_operator = QubitPauliOperator( cast( - dict[QubitPauliString, CoeffTypeAccepted], + "dict[QubitPauliString, CoeffTypeAccepted]", qubit_mapping_jordan_wigner( *rhf2ghf( ham_hsim.h0, np.array(ham_hsim.h1), np.array(ham_hsim.h2), - ) + ), ), - ) + ), ) # Load the input data. 
n_core_init = get_n_core(mo_occ, cas_init.n_ele) @@ -138,19 +137,18 @@ def circuits_from_hamiltonians( { Qubit(qubit.index[0] + 2 * n_core): pauli for qubit, pauli in qps.map.items() - } + }, ): coeff for qps, coeff in ham_init_operator._dict.items() - } + }, ) - circuits = make_time_evolution_circuits( + return make_time_evolution_circuits( t_step_list, prepared_circ, h_hsim=ham_hsim_operator, h_init=ham_init_shifted, max_cx_gates=max_cx_gates, ) - return circuits @worker.task() @@ -167,7 +165,7 @@ def energy_from_results( counts[k] += v phis = list(counts.keys()) phis_init_orig = get_config_from_cas_init( - mo_occ, cas_init.n, cas_init.n_ele, cas_hsim.n, cas_hsim.n_ele + mo_occ, cas_init.n, cas_init.n_ele, cas_hsim.n, cas_hsim.n_ele, ) for p in phis_init_orig: if p not in phis: diff --git a/docs/source/examples/example_workers/qsci_worker/src/qsci/active_space.py b/docs/source/examples/example_workers/qsci_worker/src/qsci/active_space.py index e52d00e17..f4372f6e6 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/qsci/active_space.py +++ b/docs/source/examples/example_workers/qsci_worker/src/qsci/active_space.py @@ -7,7 +7,8 @@ def get_frozen( n_orbs = len(mo_occ) n_core = get_n_core(mo_occ, n_elecas) if n_core + n_cas > n_orbs: - raise ValueError("active space is larger than basis set") + msg = "active space is larger than basis set" + raise ValueError(msg) for i in range(n_orbs): # print(i, i < n_core, i >= n_core + n_cas) if i < n_core or i >= n_core + n_cas: @@ -20,8 +21,7 @@ def get_n_core( n_elecas: int, ) -> int: n_elec = int(sum(mo_occ)) - n_core = (n_elec - n_elecas) // 2 - return n_core + return (n_elec - n_elecas) // 2 def get_n_active( @@ -30,8 +30,7 @@ def get_n_active( n_elecas: int, ) -> int: n_frozen = len(get_frozen(mo_occ, n_cas, n_elecas)) - n_active = len(mo_occ) - n_frozen - return n_active + return len(mo_occ) - n_frozen def get_n_virtual( @@ -41,5 +40,4 @@ def get_n_virtual( ) -> int: n_frozen = len(get_frozen(mo_occ, n_cas, 
n_elecas)) n_core = get_n_core(mo_occ, n_elecas) - n_virtual = n_frozen - n_core - return n_virtual + return n_frozen - n_core diff --git a/docs/source/examples/example_workers/qsci_worker/src/qsci/jordan_wigner.py b/docs/source/examples/example_workers/qsci_worker/src/qsci/jordan_wigner.py index c0a67f457..a9964a3b6 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/qsci/jordan_wigner.py +++ b/docs/source/examples/example_workers/qsci_worker/src/qsci/jordan_wigner.py @@ -74,7 +74,7 @@ def jordan_wigner_two_body( if (i == j) or (k == l): return terms - elif len({i, j, k, l}) == 4: + if len({i, j, k, l}) == 4: if (i > j) ^ (k > l): coeff *= -1 @@ -98,7 +98,7 @@ def jordan_wigner_two_body( if c: (ip, op_i), (jp, op_j), (kp, op_k), (lp, op_l) = sorted( - zip((i, j, k, l), ops) + zip((i, j, k, l), ops, strict=False), ) parity_string_ij = [(Qubit(p), Pauli.Z) for p in range(ip + 1, jp)] parity_string_kl = [(Qubit(p), Pauli.Z) for p in range(kp + 1, lp)] @@ -170,10 +170,7 @@ def jordan_wigner_two_body( terms[QubitPauliString(strings)] += c * -0.25 elif len({i, j, k, l}) == 2: - if i == l: - c = coeff * -0.25 - else: - c = coeff * 0.25 + c = coeff * -0.25 if i == l else coeff * 0.25 ip, jp = sorted([i, j]) terms[QubitPauliString({})] += -c @@ -200,7 +197,7 @@ def _apply_threshold(hamiltonian: QubitHamiltonian, tol: float) -> QubitHamilton def qubit_mapping_jordan_wigner( - h0: float, h1: NDArray[np.inexact], h2: NDArray[np.inexact], tol: float = 1e-12 + h0: float, h1: NDArray[np.inexact], h2: NDArray[np.inexact], tol: float = 1e-12, ) -> QubitHamiltonian: """Map the Hamiltonian to qubits using Jordan--Wigner mapping. 
@@ -228,7 +225,7 @@ def _update_hamiltonian(terms: QubitHamiltonian) -> None: # Two-body terms for (i, j), (k, l) in itertools.combinations_with_replacement( # noqa: E741 - itertools.combinations_with_replacement(range(norb), r=2), r=2 + itertools.combinations_with_replacement(range(norb), r=2), r=2, ): _update_hamiltonian(jordan_wigner_two_body(i, j, l, k, h2[i, j, k, l])) diff --git a/docs/source/examples/example_workers/qsci_worker/src/qsci/postprocess.py b/docs/source/examples/example_workers/qsci_worker/src/qsci/postprocess.py index 65440abbb..5b47676e8 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/qsci/postprocess.py +++ b/docs/source/examples/example_workers/qsci_worker/src/qsci/postprocess.py @@ -5,12 +5,12 @@ def get_phase(phi_i, phi_j) -> int: - """phase factor. + """Phase factor. Note: See section 18.8 of the pink book. """ - diff = [(j - i) for i, j in zip(phi_i, phi_j)] + diff = [(j - i) for i, j in zip(phi_i, phi_j, strict=False)] phase = 1 for i, p in enumerate(diff): if p == -1: @@ -48,10 +48,9 @@ def eval_hele( """Get the matrix element based on the Slater-Condon rule.""" # Return if the particle number is not conserving. if sum(phi_i) != sum(phi_j): - val = 0.0 - return val + return 0.0 # Identify the excitation type. - config_diff = np.array([(j - i) for i, j in zip(phi_i, phi_j)]) + config_diff = np.array([(j - i) for i, j in zip(phi_i, phi_j, strict=False)]) n_excitation = np.sum(np.abs(config_diff)) // 2 # Triple or higher excitation returns zero. 
if n_excitation > 2: @@ -89,7 +88,7 @@ def eval_hele( val += get_h2(h2, j_index) - get_h2(h2, k_index) val += enuc else: - raise RuntimeError() + raise RuntimeError return val @@ -116,8 +115,7 @@ def get_ci_matrix( raw.append(j) col.append(i) data.append(val) - hij = csr_array((data, (raw, col))) - return hij + return csr_array((data, (raw, col))) def postprocess_configs( @@ -128,9 +126,9 @@ def postprocess_configs( nea = sum(reference[0::2]) neb = sum(reference[1::2]) # new_configs: list[tuple[int, ...]] = [] - new_configs: set[tuple[int, ...]] = set([]) + new_configs: set[tuple[int, ...]] = set() for config in configs: - occ = [i + j for i, j in zip(config[0::2], config[1::2])] + occ = [i + j for i, j in zip(config[0::2], config[1::2], strict=False)] ls: list[list[int]] = [[]] # print("occ", occ) for on in occ: diff --git a/docs/source/examples/example_workers/qsci_worker/src/qsci/state_prep.py b/docs/source/examples/example_workers/qsci_worker/src/qsci/state_prep.py index a6d092449..5c1f62bc1 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/qsci/state_prep.py +++ b/docs/source/examples/example_workers/qsci_worker/src/qsci/state_prep.py @@ -47,15 +47,15 @@ def make_pool(qubit_number: int) -> list[QubitPauliOperator]: ("YXYY", -0.125j), ]: qps = QubitPauliString( - {Qubit(x): getattr(Pauli, p) for x, p in zip([i, j, k, l], paulis)} + {Qubit(x): getattr(Pauli, p) for x, p in zip([i, j, k, l], paulis, strict=False)}, ) terms1[qps] = factor qps = QubitPauliString( - {Qubit(x): getattr(Pauli, p) for x, p in zip([i, k, j, l], paulis)} + {Qubit(x): getattr(Pauli, p) for x, p in zip([i, k, j, l], paulis, strict=False)}, ) terms2[qps] = factor qps = QubitPauliString( - {Qubit(x): getattr(Pauli, p) for x, p in zip([i, l, j, k], paulis)} + {Qubit(x): getattr(Pauli, p) for x, p in zip([i, l, j, k], paulis, strict=False)}, ) terms3[qps] = factor pool.append(QubitPauliOperator(terms1)) @@ -69,7 +69,7 @@ def costfunc( ansatz: Circuit, target: np.ndarray, ) -> 
float: - """Cost function for ADAPTive state preparation + """Cost function for ADAPTive state preparation. Args: params: Circuit parameters @@ -82,7 +82,7 @@ def costfunc( circ_copy = ansatz.copy() symbols = circ_copy.free_symbols() ls = list(symbols) - mapping = dict(zip(ls, params)) + mapping = dict(zip(ls, params, strict=False)) circ_copy.symbol_substitution(mapping) backend = AerStateBackend() compiled_circ = backend.get_compiled_circuit(circ_copy, optimisation_level=0) @@ -136,31 +136,29 @@ def state_preparation( ls = [Pauli.I for _ in range(reference.n_qubits)] for q, p in pauli_string.map.items(): ls[q.index[0]] = p - if any([p != Pauli.I for p in ls]): + if any(p != Pauli.I for p in ls): adapt_circ.add_pauliexpbox( PauliExpBox( ls, complex(coeff).imag * np.real(Symbol(s)) * 2.0 / np.pi, # type: ignore ), - [j for j in range(reference.n_qubits)], + list(range(reference.n_qubits)), ) x0_array = 1 - 2 * np.random.random(size=iteration + 1) opt_res = minimize(costfunc, x0_array, args=(adapt_circ, target)) # Update the reference state-vector and repeat the ADAPT procedure. 
x0 = opt_res.x.tolist() - mapping = dict(zip(adapt_circ.free_symbols(), opt_res.x)) + mapping = dict(zip(adapt_circ.free_symbols(), opt_res.x, strict=False)) circ = adapt_circ.copy() circ.symbol_substitution(mapping) ref_statevector = circ.get_statevector() cost = opt_res.fun - print("error after iteration=" + str(iteration) + ":", opt_res.fun, opt_res.x) else: if strict: - raise RuntimeError("Not converge") - else: - print("Not converge") + msg = "Not converge" + raise RuntimeError(msg) symbols = adapt_circ.free_symbols() - mapping = dict(zip(symbols, x0)) + mapping = dict(zip(symbols, x0, strict=False)) adapt_circ.symbol_substitution(mapping) return adapt_circ @@ -175,7 +173,7 @@ def perform_state_preparation( """State preparation or load in the saved one.""" adapt_circuit = Circuit(len(reference_state)) target_vector: NDArray[np.complex128] = np.linalg.eigh( - ham_init.to_sparse_matrix().todense() + ham_init.to_sparse_matrix().todense(), )[1][:, 0] target_vector_reshaped = np.array(target_vector).reshape(-1).real diff --git a/docs/source/examples/example_workers/qsci_worker/src/qsci/utils.py b/docs/source/examples/example_workers/qsci_worker/src/qsci/utils.py index 1d181f39b..fc22d491e 100644 --- a/docs/source/examples/example_workers/qsci_worker/src/qsci/utils.py +++ b/docs/source/examples/example_workers/qsci_worker/src/qsci/utils.py @@ -7,6 +7,7 @@ from pytket.circuit import PauliExpBox from pytket.pauli import Pauli from pytket.utils.operators import QubitPauliOperator + from .active_space import get_n_active, get_n_core, get_n_virtual @@ -32,7 +33,7 @@ def get_configs( for pj in lsp: for pi in lsp: p = [] - for i, j in zip(pi, pj): + for i, j in zip(pi, pj, strict=False): p += [i, j] p = tuple(p) if p not in ls: @@ -61,8 +62,7 @@ def get_config_from_cas_init( n_virt -= get_n_virtual(mo_occ, n_cas_hsim, n_elecas_hsim) lsdoc = [1 for _ in range(2 * n_core)] lsvir = [0 for _ in range(2 * n_virt)] - phis_init_orig = [tuple(lsdoc + list(i) + lsvir) for i in 
phis_init] - return phis_init_orig + return [tuple(lsdoc + list(i) + lsvir) for i in phis_init] def make_time_evolution_circuits( @@ -88,7 +88,7 @@ def make_time_evolution_circuits( n_trotter = 1 H_for_time_evolution = QubitPauliOperator( - {qps: h_hsim[qps] - h_init.get(qps, 0.0) for qps in h_hsim._dict.keys()} + {qps: h_hsim[qps] - h_init.get(qps, 0.0) for qps in h_hsim._dict}, ) H_for_time_evolution.compress() items = sorted( @@ -111,10 +111,10 @@ def make_time_evolution_circuits( ls = [Pauli.I for _ in range(len(circ.qubits))] for q, p in pauli_string.map.items(): ls[q.index[0]] = p - if any([p != Pauli.I for p in ls]): + if any(p != Pauli.I for p in ls): circ.add_pauliexpbox( PauliExpBox(ls, coeff * 2 * time_step / np.pi), - [j for j in range(len(circ.qubits))], + list(range(len(circ.qubits))), ) circ.measure_all() list_circ.append(circ) @@ -137,9 +137,9 @@ def rhf2ghf( Integrals in the GHF basis. """ nmo = h1e0.shape[0] - h1e = cast(NDArray[np.float64], np.kron(np.eye(2), h1e0)) + h1e = cast("NDArray[np.float64]", np.kron(np.eye(2), h1e0)) h2e = np.kron(np.eye(2), np.kron(np.eye(2), h2e0).T) - mask = list(itertools.chain(*zip(range(nmo), range(nmo, nmo * 2)))) + mask = list(itertools.chain(*zip(range(nmo), range(nmo, nmo * 2), strict=False))) h1e = h1e[mask][:, mask] h2e = h2e[mask][:, mask][:, :, mask][:, :, :, mask] h2e = h2e.transpose(0, 2, 1, 3) - h2e.transpose(0, 2, 3, 1) diff --git a/docs/source/examples/example_workers/scipy_worker/src/main.py b/docs/source/examples/example_workers/scipy_worker/src/main.py index 2965d9015..7e5e18fbe 100644 --- a/docs/source/examples/example_workers/scipy_worker/src/main.py +++ b/docs/source/examples/example_workers/scipy_worker/src/main.py @@ -3,6 +3,7 @@ import pickle from sys import argv from typing import Annotated, NamedTuple + import numpy as np from tierkreis.controller.data.core import Deserializer, Serializer diff --git a/docs/source/examples/example_workers/substitution_worker/src/main.py 
b/docs/source/examples/example_workers/substitution_worker/src/main.py index 01cff5d1a..ed049bcf4 100644 --- a/docs/source/examples/example_workers/substitution_worker/src/main.py +++ b/docs/source/examples/example_workers/substitution_worker/src/main.py @@ -1,11 +1,11 @@ import logging from sys import argv - -from tierkreis import Worker from pytket._tket.circuit import Circuit from sympy import Symbol +from tierkreis import Worker + logger = logging.getLogger(__name__) worker = Worker("substitution_worker") diff --git a/docs/source/examples/hamiltonian.ipynb b/docs/source/examples/hamiltonian.ipynb index 1ce8acdca..53ce0b019 100644 --- a/docs/source/examples/hamiltonian.ipynb +++ b/docs/source/examples/hamiltonian.ipynb @@ -69,12 +69,13 @@ "metadata": {}, "outputs": [], "source": [ - "from tierkreis.builder import GraphBuilder\n", - "from tierkreis.controller.data.models import TKR\n", - "from typing import NamedTuple, Literal\n", + "from typing import Literal, NamedTuple\n", "\n", "from example_workers.substitution_worker.api.stubs import substitute\n", "\n", + "from tierkreis.builder import GraphBuilder\n", + "from tierkreis.controller.data.models import TKR\n", + "\n", "\n", "class SymbolicExecutionInputs(NamedTuple):\n", " a: TKR[float]\n", @@ -91,7 +92,7 @@ " simulation_graph.inputs.a,\n", " simulation_graph.inputs.b,\n", " simulation_graph.inputs.c,\n", - " )\n", + " ),\n", ")" ] }, @@ -110,12 +111,12 @@ "metadata": {}, "outputs": [], "source": [ + "from tierkreis.aer_worker import submit_single\n", "from tierkreis.pytket_worker import (\n", " append_pauli_measurement_impl,\n", - " optimise_phase_gadgets,\n", " expectation,\n", + " optimise_phase_gadgets,\n", ")\n", - "from tierkreis.aer_worker import submit_single\n", "\n", "\n", "class SubmitInputs(NamedTuple):\n", @@ -161,7 +162,7 @@ "from tierkreis.controller.data.models import TKR\n", "\n", "pauli_strings_list, parameters_list = simulation_graph.task(\n", - " unzip(simulation_graph.inputs.ham)\n", + " 
unzip(simulation_graph.inputs.ham),\n", ")\n", "input_circuits = simulation_graph.map(\n", " lambda x: SubmitInputs(substituted_circuit, x, simulation_graph.const(100)),\n", @@ -190,7 +191,7 @@ "from tierkreis.graphs.fold import FoldFunctionInput\n", "\n", "ComputeTermsInputs = FoldFunctionInput[\n", - " tuple[float, float], float\n", + " tuple[float, float], float,\n", "] # (value, accum) -> new_accum\n", "\n", "\n", @@ -266,15 +267,15 @@ "\n", "from tierkreis.consts import PACKAGE_PATH\n", "from tierkreis.controller.executor.multiple import MultipleExecutor\n", - "from tierkreis.storage import FileStorage\n", "from tierkreis.controller.executor.uv_executor import UvExecutor\n", + "from tierkreis.storage import FileStorage\n", "\n", "storage = FileStorage(UUID(int=102), name=\"hamiltonian\")\n", "example_executor = UvExecutor(\n", - " registry_path=Path().parent / \"example_workers\", logs_path=storage.logs_path\n", + " registry_path=Path().parent / \"example_workers\", logs_path=storage.logs_path,\n", ")\n", "common_executor = UvExecutor(\n", - " registry_path=PACKAGE_PATH.parent / \"tierkreis_workers\", logs_path=storage.logs_path\n", + " registry_path=PACKAGE_PATH.parent / \"tierkreis_workers\", logs_path=storage.logs_path,\n", ")\n", "multi_executor = MultipleExecutor(\n", " common_executor,\n", @@ -335,7 +336,6 @@ "from tierkreis.controller import run_graph\n", "from tierkreis.storage import read_outputs\n", "\n", - "\n", "storage.clean_graph_files()\n", "run_graph(\n", " storage,\n", @@ -344,8 +344,7 @@ " inputs,\n", " polling_interval_seconds=0.2,\n", ")\n", - "output = read_outputs(simulation_graph, storage)\n", - "print(output)" + "output = read_outputs(simulation_graph, storage)" ] } ], diff --git a/docs/source/examples/hello_world.py b/docs/source/examples/hello_world.py index 6f544120e..70596180c 100644 --- a/docs/source/examples/hello_world.py +++ b/docs/source/examples/hello_world.py @@ -1,8 +1,7 @@ -from tierkreis.builder import GraphBuilder -from 
tierkreis.controller.data.models import TKR - from example_workers.hello_world_worker.api.stubs import greet +from tierkreis.builder import GraphBuilder +from tierkreis.controller.data.models import TKR graph = GraphBuilder(inputs_type=TKR[str], outputs_type=TKR[str]) hello = graph.const("Hello ") diff --git a/docs/source/examples/hello_world_graph.ipynb b/docs/source/examples/hello_world_graph.ipynb index 432257c7a..89c47a7ad 100644 --- a/docs/source/examples/hello_world_graph.ipynb +++ b/docs/source/examples/hello_world_graph.ipynb @@ -142,6 +142,7 @@ "outputs": [], "source": [ "from pathlib import Path\n", + "\n", "from tierkreis.cli.run_workflow import run_workflow\n", "\n", "\n", diff --git a/docs/source/examples/hpc.ipynb b/docs/source/examples/hpc.ipynb index 30513e8af..14ccc27ac 100644 --- a/docs/source/examples/hpc.ipynb +++ b/docs/source/examples/hpc.ipynb @@ -32,6 +32,7 @@ "outputs": [], "source": [ "from typing import NamedTuple\n", + "\n", "from tierkreis.controller.data.models import TKR, OpaqueType\n", "\n", "\n", @@ -61,14 +62,15 @@ "metadata": {}, "outputs": [], "source": [ - "from tierkreis.builder import GraphBuilder\n", "from example_workers.substitution_worker.api.stubs import substitute\n", + "\n", + "from tierkreis.aer_worker import submit_single\n", + "from tierkreis.builder import GraphBuilder\n", "from tierkreis.pytket_worker import (\n", " add_measure_all,\n", - " optimise_phase_gadgets,\n", " expectation,\n", + " optimise_phase_gadgets,\n", ")\n", - "from tierkreis.aer_worker import submit_single\n", "\n", "\n", "def symbolic_execution() -> GraphBuilder:\n", @@ -150,6 +152,7 @@ "source": [ "from pathlib import Path\n", "from uuid import UUID\n", + "\n", "from tierkreis.controller.storage.filestorage import ControllerFileStorage\n", "\n", "storage = ControllerFileStorage(\n", @@ -300,8 +303,7 @@ " },\n", " polling_interval_seconds=0.1,\n", ")\n", - "output = read_outputs(symbolic_execution().data, storage)\n", - "print(output)" + "output 
= read_outputs(symbolic_execution().data, storage)" ] } ], diff --git a/docs/source/examples/parallelism.ipynb b/docs/source/examples/parallelism.ipynb index b17e058e2..b0cc1e47f 100644 --- a/docs/source/examples/parallelism.ipynb +++ b/docs/source/examples/parallelism.ipynb @@ -1,357 +1,358 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "88dc868f", - "metadata": {}, - "source": [ - "## Leveraging parallelism through map\n", - "\n", - "One major advantage in workflow systems is the ease of scaling computation horizontally.\n", - "Data-parallel tasks can act independently; In tierkreis this can simply be achieved through the `map` function.\n", - "Each map element will receive exactly one sets of inputs and can therefore be immediately dispatched.\n", - "In this example we will observe the speedup by running multiple independent graphs in parallel.\n", - "\n", - "First we define a simple graph that will run a circuit in two version:\n", - "1. Using the qiskit aer simulator\n", - "2. 
Using the [qulacs](https://github.com/qulacs/qulacs) simulator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ee0cabe0", - "metadata": {}, - "outputs": [], - "source": [ - "%pip install tierkreis pytket qiskit-aer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9e6fae3d", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Literal, NamedTuple\n", - "from tierkreis.builder import GraphBuilder\n", - "from tierkreis.controller.data.models import TKR, OpaqueType\n", - "from tierkreis.builtins.stubs import untuple\n", - "from tierkreis.aer_worker import (\n", - " get_compiled_circuit as aer_compile,\n", - " run_circuit as aer_run,\n", - ")\n", - "from tierkreis.qulacs_worker import (\n", - " get_compiled_circuit as qulacs_compile,\n", - " run_circuit as qulacs_run,\n", - ")\n", - "\n", - "type BackendResult = OpaqueType[\"pytket.backends.backendresult.BackendResult\"] # noqa: F821\n", - "type Circuit = OpaqueType[\"pytket._tket.circuit.Circuit\"] # noqa: F821\n", - "\n", - "\n", - "class SimulateJobInputsSingle(NamedTuple):\n", - " simulator_name: TKR[Literal[\"aer\", \"qulacs\"]]\n", - " circuit_shots: TKR[tuple[Circuit, int]]\n", - " compilation_optimisation_level: TKR[int]\n", - "\n", - "\n", - "def aer_simulate_single():\n", - " g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult])\n", - " circuit_shots = g.task(untuple(g.inputs.circuit_shots))\n", - "\n", - " compiled_circuit = g.task(\n", - " aer_compile(\n", - " circuit=circuit_shots.a,\n", - " optimisation_level=g.inputs.compilation_optimisation_level,\n", - " )\n", - " )\n", - " res = g.task(aer_run(compiled_circuit, circuit_shots.b))\n", - " g.outputs(res)\n", - " return g\n", - "\n", - "\n", - "def qulacs_simulate_single():\n", - " g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult])\n", - " circuit_shots = g.task(untuple(g.inputs.circuit_shots))\n", - "\n", - " compiled_circuit = g.task(\n", - " qulacs_compile(\n", - " 
circuit=circuit_shots.a,\n", - " optimisation_level=g.inputs.compilation_optimisation_level,\n", - " )\n", - " )\n", - " res = g.task(qulacs_run(compiled_circuit, circuit_shots.b))\n", - " g.outputs(res)\n", - " return g" - ] - }, - { - "cell_type": "markdown", - "id": "f1b7e67e", - "metadata": {}, - "source": [ - "So far these are regular graphs that compile and simulate a single circuit.\n", - "We are going to combine them into a single graph taking a parameter to decide which simulator to run using `ifelse`\n", - "Although we we will have two similar subgraphs in the evaluation, this is not a performance detriment as `ifelse` only evaluates lazily. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0a71ddad", - "metadata": {}, - "outputs": [], - "source": [ - "from tierkreis.builtins.stubs import str_eq\n", - "\n", - "\n", - "def compile_simulate_single():\n", - " g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult])\n", - "\n", - " aer_res = g.eval(aer_simulate_single(), g.inputs)\n", - " qulacs_res = g.eval(qulacs_simulate_single(), g.inputs)\n", - " res = g.ifelse(\n", - " g.task(str_eq(g.inputs.simulator_name, g.const(\"aer\"))), aer_res, qulacs_res\n", - " )\n", - "\n", - " g.outputs(res)\n", - " return g" - ] - }, - { - "cell_type": "markdown", - "id": "9bdbe246", - "metadata": {}, - "source": [ - "To make this parallel over multiple circuits we are using the `map` feature in a new graph." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ae2ee842", - "metadata": {}, - "outputs": [], - "source": [ - "class SimulateJobInputs(NamedTuple):\n", - " simulator_name: TKR[Literal[\"aer\", \"qulacs\"]]\n", - " circuits: TKR[list[Circuit]]\n", - " n_shots: TKR[list[int]]\n", - " compilation_optimisation_level: TKR[int]\n", - "\n", - "\n", - "g = GraphBuilder(SimulateJobInputs, TKR[list[BackendResult]])" - ] - }, - { - "cell_type": "markdown", - "id": "1052dd30", - "metadata": {}, - "source": [ - "Each of the `SimulateJobInputsSingle` expects a tuple `(Circuit, n_shots)` which we generate by zipping" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d0c9907a", - "metadata": {}, - "outputs": [], - "source": [ - "from tierkreis.builtins.stubs import tkr_zip\n", - "\n", - "circuits_shots = g.task(tkr_zip(g.inputs.circuits, g.inputs.n_shots))" - ] - }, - { - "cell_type": "markdown", - "id": "5b28a6ba", - "metadata": {}, - "source": [ - "A convenient way to aggregate the inputs is using a map over a lambda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa29dd13", - "metadata": {}, - "outputs": [], - "source": [ - "job_inputs = g.map(\n", - " lambda x: SimulateJobInputsSingle(\n", - " simulator_name=g.inputs.simulator_name,\n", - " circuit_shots=x,\n", - " compilation_optimisation_level=g.inputs.compilation_optimisation_level,\n", - " ),\n", - " circuits_shots,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "f15c2943", - "metadata": {}, - "source": [ - "and finally we can map over the jobs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "685a5d47", - "metadata": {}, - "outputs": [], - "source": [ - "res = g.map(compile_simulate_single(), job_inputs)\n", - "\n", - "g.outputs(res)" - ] - }, - { - "cell_type": "markdown", - "id": "d8b7cabf", - "metadata": {}, - "source": [ - "preparing the storage, executor and inputs" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "b7c2baf2", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "from uuid import UUID\n", - "\n", - "from pytket.qasm.qasm import circuit_from_qasm\n", - "\n", - "from tierkreis.consts import PACKAGE_PATH\n", - "from tierkreis.storage import FileStorage\n", - "from tierkreis.executor import UvExecutor\n", - "\n", - "circuit = circuit_from_qasm(Path().parent / \"data\" / \"ghz_state_n23.qasm\")\n", - "circuits = [circuit] * 3\n", - "n_shots = 1024\n", - "\n", - "storage = FileStorage(UUID(int=107), do_cleanup=True)\n", - "executor = UvExecutor(PACKAGE_PATH / \"..\" / \"tierkreis_workers\", storage.logs_path)\n", - "inputs = {\n", - " \"circuits\": circuits,\n", - " \"n_shots\": [n_shots] * len(circuits),\n", - " \"compilation_optimisation_level\": 2,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "id": "1bca6c67", - "metadata": {}, - "source": [ - "we can now benchmark aer by setting the `simulator_name` input" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93217435", - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "from tierkreis.controller import run_graph\n", - "\n", - "inputs[\"simulator_name\"] = \"aer\"\n", - "print(\"Simulating using aer...\")\n", - "start = time.time()\n", - "run_graph(storage, executor, g, inputs, polling_interval_seconds=0.1)\n", - "print(f\"time taken: {time.time() - start}\")" - ] - }, - { - "cell_type": "markdown", - "id": "606a9c18", - "metadata": {}, - "source": [ - "and" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6f6346a6", - "metadata": {}, - "outputs": [], - "source": [ - "inputs[\"simulator_name\"] = \"qulacs\"\n", - "\n", - "print(\"Simulating using qulacs...\")\n", - "storage.clean_graph_files()\n", - "start = time.time()\n", - "run_graph(storage, executor, g, inputs, polling_interval_seconds=0.1)\n", - "print(f\"time taken: {time.time() - start}\")" - ] - }, - { - "cell_type": 
"markdown", - "id": "c34a812e", - "metadata": {}, - "source": [ - "compared against running the same graph three times:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "39cf4dc4", - "metadata": {}, - "outputs": [], - "source": [ - "start = time.time()\n", - "for circuit in circuits:\n", - " inputs = {\n", - " \"circuit_shots\": (circuit, n_shots),\n", - " \"compilation_optimisation_level\": 2,\n", - " \"simulator_name\": \"aer\",\n", - " }\n", - " storage.clean_graph_files()\n", - " run_graph(\n", - " storage,\n", - " executor,\n", - " compile_simulate_single(),\n", - " inputs,\n", - " polling_interval_seconds=0.1,\n", - " )\n", - "print(f\"time taken: {time.time() - start}\")" - ] - } - ], - "metadata": { - "execution": { - "timeout": 120 - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.11" - } + "cells": [ + { + "cell_type": "markdown", + "id": "88dc868f", + "metadata": {}, + "source": [ + "## Leveraging parallelism through map\n", + "\n", + "One major advantage in workflow systems is the ease of scaling computation horizontally.\n", + "Data-parallel tasks can act independently; In tierkreis this can simply be achieved through the `map` function.\n", + "Each map element will receive exactly one sets of inputs and can therefore be immediately dispatched.\n", + "In this example we will observe the speedup by running multiple independent graphs in parallel.\n", + "\n", + "First we define a simple graph that will run a circuit in two version:\n", + "1. Using the qiskit aer simulator\n", + "2. 
Using the [qulacs](https://github.com/qulacs/qulacs) simulator" + ] }, - "nbformat": 4, - "nbformat_minor": 5 + { + "cell_type": "code", + "execution_count": null, + "id": "ee0cabe0", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install tierkreis pytket qiskit-aer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e6fae3d", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Literal, NamedTuple\n", + "\n", + "from tierkreis.aer_worker import (\n", + " get_compiled_circuit as aer_compile,\n", + ")\n", + "from tierkreis.aer_worker import (\n", + " run_circuit as aer_run,\n", + ")\n", + "from tierkreis.builder import GraphBuilder\n", + "from tierkreis.builtins.stubs import untuple\n", + "from tierkreis.controller.data.models import TKR, OpaqueType\n", + "from tierkreis.qulacs_worker import (\n", + " get_compiled_circuit as qulacs_compile,\n", + ")\n", + "from tierkreis.qulacs_worker import (\n", + " run_circuit as qulacs_run,\n", + ")\n", + "\n", + "type BackendResult = OpaqueType[\"pytket.backends.backendresult.BackendResult\"] # noqa: F821\n", + "type Circuit = OpaqueType[\"pytket._tket.circuit.Circuit\"] # noqa: F821\n", + "\n", + "\n", + "class SimulateJobInputsSingle(NamedTuple):\n", + " simulator_name: TKR[Literal[\"aer\", \"qulacs\"]]\n", + " circuit_shots: TKR[tuple[Circuit, int]]\n", + " compilation_optimisation_level: TKR[int]\n", + "\n", + "\n", + "def aer_simulate_single():\n", + " g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult])\n", + " circuit_shots = g.task(untuple(g.inputs.circuit_shots))\n", + "\n", + " compiled_circuit = g.task(\n", + " aer_compile(\n", + " circuit=circuit_shots.a,\n", + " optimisation_level=g.inputs.compilation_optimisation_level,\n", + " ),\n", + " )\n", + " res = g.task(aer_run(compiled_circuit, circuit_shots.b))\n", + " g.outputs(res)\n", + " return g\n", + "\n", + "\n", + "def qulacs_simulate_single():\n", + " g = GraphBuilder(SimulateJobInputsSingle, 
TKR[BackendResult])\n", + " circuit_shots = g.task(untuple(g.inputs.circuit_shots))\n", + "\n", + " compiled_circuit = g.task(\n", + " qulacs_compile(\n", + " circuit=circuit_shots.a,\n", + " optimisation_level=g.inputs.compilation_optimisation_level,\n", + " ),\n", + " )\n", + " res = g.task(qulacs_run(compiled_circuit, circuit_shots.b))\n", + " g.outputs(res)\n", + " return g" + ] + }, + { + "cell_type": "markdown", + "id": "f1b7e67e", + "metadata": {}, + "source": [ + "So far these are regular graphs that compile and simulate a single circuit.\n", + "We are going to combine them into a single graph taking a parameter to decide which simulator to run using `ifelse`\n", + "Although we we will have two similar subgraphs in the evaluation, this is not a performance detriment as `ifelse` only evaluates lazily. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a71ddad", + "metadata": {}, + "outputs": [], + "source": [ + "from tierkreis.builtins.stubs import str_eq\n", + "\n", + "\n", + "def compile_simulate_single():\n", + " g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult])\n", + "\n", + " aer_res = g.eval(aer_simulate_single(), g.inputs)\n", + " qulacs_res = g.eval(qulacs_simulate_single(), g.inputs)\n", + " res = g.ifelse(\n", + " g.task(str_eq(g.inputs.simulator_name, g.const(\"aer\"))), aer_res, qulacs_res,\n", + " )\n", + "\n", + " g.outputs(res)\n", + " return g" + ] + }, + { + "cell_type": "markdown", + "id": "9bdbe246", + "metadata": {}, + "source": [ + "To make this parallel over multiple circuits we are using the `map` feature in a new graph." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae2ee842", + "metadata": {}, + "outputs": [], + "source": [ + "class SimulateJobInputs(NamedTuple):\n", + " simulator_name: TKR[Literal[\"aer\", \"qulacs\"]]\n", + " circuits: TKR[list[Circuit]]\n", + " n_shots: TKR[list[int]]\n", + " compilation_optimisation_level: TKR[int]\n", + "\n", + "\n", + "g = GraphBuilder(SimulateJobInputs, TKR[list[BackendResult]])" + ] + }, + { + "cell_type": "markdown", + "id": "1052dd30", + "metadata": {}, + "source": [ + "Each of the `SimulateJobInputsSingle` expects a tuple `(Circuit, n_shots)` which we generate by zipping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0c9907a", + "metadata": {}, + "outputs": [], + "source": [ + "from tierkreis.builtins.stubs import tkr_zip\n", + "\n", + "circuits_shots = g.task(tkr_zip(g.inputs.circuits, g.inputs.n_shots))" + ] + }, + { + "cell_type": "markdown", + "id": "5b28a6ba", + "metadata": {}, + "source": [ + "A convenient way to aggregate the inputs is using a map over a lambda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa29dd13", + "metadata": {}, + "outputs": [], + "source": [ + "job_inputs = g.map(\n", + " lambda x: SimulateJobInputsSingle(\n", + " simulator_name=g.inputs.simulator_name,\n", + " circuit_shots=x,\n", + " compilation_optimisation_level=g.inputs.compilation_optimisation_level,\n", + " ),\n", + " circuits_shots,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f15c2943", + "metadata": {}, + "source": [ + "and finally we can map over the jobs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "685a5d47", + "metadata": {}, + "outputs": [], + "source": [ + "res = g.map(compile_simulate_single(), job_inputs)\n", + "\n", + "g.outputs(res)" + ] + }, + { + "cell_type": "markdown", + "id": "d8b7cabf", + "metadata": {}, + "source": [ + "preparing the storage, executor and inputs" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "b7c2baf2", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from uuid import UUID\n", + "\n", + "from pytket.qasm.qasm import circuit_from_qasm\n", + "\n", + "from tierkreis.consts import PACKAGE_PATH\n", + "from tierkreis.executor import UvExecutor\n", + "from tierkreis.storage import FileStorage\n", + "\n", + "circuit = circuit_from_qasm(Path().parent / \"data\" / \"ghz_state_n23.qasm\")\n", + "circuits = [circuit] * 3\n", + "n_shots = 1024\n", + "\n", + "storage = FileStorage(UUID(int=107), do_cleanup=True)\n", + "executor = UvExecutor(PACKAGE_PATH / \"..\" / \"tierkreis_workers\", storage.logs_path)\n", + "inputs = {\n", + " \"circuits\": circuits,\n", + " \"n_shots\": [n_shots] * len(circuits),\n", + " \"compilation_optimisation_level\": 2,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "1bca6c67", + "metadata": {}, + "source": [ + "we can now benchmark aer by setting the `simulator_name` input" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93217435", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "from tierkreis.controller import run_graph\n", + "\n", + "inputs[\"simulator_name\"] = \"aer\"\n", + "start = time.time()\n", + "run_graph(storage, executor, g, inputs, polling_interval_seconds=0.1)" + ] + }, + { + "cell_type": "markdown", + "id": "606a9c18", + "metadata": {}, + "source": [ + "and" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f6346a6", + "metadata": {}, + "outputs": [], + "source": [ + "inputs[\"simulator_name\"] = \"qulacs\"\n", + "\n", + "storage.clean_graph_files()\n", + "start = time.time()\n", + "run_graph(storage, executor, g, inputs, polling_interval_seconds=0.1)" + ] + }, + { + "cell_type": "markdown", + "id": "c34a812e", + "metadata": {}, + "source": [ + "compared against running the same graph three times:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"39cf4dc4", + "metadata": {}, + "outputs": [], + "source": [ + "start = time.time()\n", + "for circuit in circuits:\n", + " inputs = {\n", + " \"circuit_shots\": (circuit, n_shots),\n", + " \"compilation_optimisation_level\": 2,\n", + " \"simulator_name\": \"aer\",\n", + " }\n", + " storage.clean_graph_files()\n", + " run_graph(\n", + " storage,\n", + " executor,\n", + " compile_simulate_single(),\n", + " inputs,\n", + " polling_interval_seconds=0.1,\n", + " )" + ] + } + ], + "metadata": { + "execution": { + "timeout": 120 + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/source/examples/polling_and_dir.ipynb b/docs/source/examples/polling_and_dir.ipynb index f894a1508..97ccd00de 100644 --- a/docs/source/examples/polling_and_dir.ipynb +++ b/docs/source/examples/polling_and_dir.ipynb @@ -32,6 +32,7 @@ "source": [ "from pathlib import Path\n", "from uuid import UUID\n", + "\n", "from tierkreis.storage import FileStorage\n", "\n", "storage = FileStorage(\n", @@ -109,7 +110,7 @@ "\n", "login()\n", "visualize_graph(\n", - " graph\n", + " graph,\n", ") # this spawns a server, you need to manually terminate this cell." 
] }, @@ -132,8 +133,8 @@ }, "outputs": [], "source": [ - "from qnexus import AerConfig\n", "from pytket.qasm.qasm import circuit_from_qasm\n", + "from qnexus import AerConfig\n", "\n", "from tierkreis.controller import run_graph\n", "from tierkreis.storage import read_outputs\n", @@ -156,8 +157,7 @@ " inputs,\n", " polling_interval_seconds=0.1,\n", ")\n", - "res = read_outputs(graph, storage)\n", - "print(res)" + "res = read_outputs(graph, storage)" ] } ], diff --git a/docs/source/examples/qsci.ipynb b/docs/source/examples/qsci.ipynb index f7677e801..5b6bc46f9 100644 --- a/docs/source/examples/qsci.ipynb +++ b/docs/source/examples/qsci.ipynb @@ -90,16 +90,15 @@ "from typing import cast\n", "\n", "import numpy as np\n", - "from pytket._tket.circuit import Circuit\n", - "from pytket.pauli import QubitPauliString\n", - "from pytket.utils.operators import CoeffTypeAccepted, QubitPauliOperator\n", - "\n", "from example_workers.qsci_worker.src.qsci.active_space import get_n_active, get_n_core\n", "from example_workers.qsci_worker.src.qsci.jordan_wigner import (\n", " qubit_mapping_jordan_wigner,\n", ")\n", "from example_workers.qsci_worker.src.qsci.state_prep import perform_state_preparation\n", "from example_workers.qsci_worker.src.qsci.utils import rhf2ghf\n", + "from pytket._tket.circuit import Circuit\n", + "from pytket.pauli import QubitPauliString\n", + "from pytket.utils.operators import CoeffTypeAccepted, QubitPauliOperator\n", "\n", "\n", "def state_prep(\n", @@ -113,36 +112,34 @@ ") -> Circuit:\n", " ham_init_operator = QubitPauliOperator(\n", " cast(\n", - " dict[QubitPauliString, CoeffTypeAccepted],\n", + " \"dict[QubitPauliString, CoeffTypeAccepted]\",\n", " qubit_mapping_jordan_wigner(\n", " *rhf2ghf(\n", " ham_init.h0,\n", " np.array(ham_init.h1),\n", " np.array(ham_init.h2),\n", - " )\n", + " ),\n", " ),\n", - " )\n", + " ),\n", " )\n", " # time-evolve CASCI ground state.\n", " n_core_init = get_n_core(mo_occ, cas_init.n_ele)\n", " n_core_hsim = 
get_n_core(mo_occ, cas_hsim.n_ele)\n", " n_core = n_core_init - n_core_hsim\n", " logging.info(\n", - " f\"mo_occ={mo_occ} n_cas_hsim={cas_hsim.n} n_elecas_hsim={cas_hsim.n_ele}\"\n", + " f\"mo_occ={mo_occ} n_cas_hsim={cas_hsim.n} n_elecas_hsim={cas_hsim.n_ele}\",\n", " )\n", " n_active_hsim = get_n_active(mo_occ, cas_hsim.n, cas_hsim.n_ele)\n", " prepared_circ = Circuit(n_active_hsim * 2)\n", " for i in range(n_core * 2):\n", " prepared_circ.X(i)\n", - " adapt_circ = perform_state_preparation(\n", + " return perform_state_preparation(\n", " reference_state=reference_state,\n", " ham_init=ham_init_operator,\n", " n_cas_init=cas_init.n,\n", " max_iteration=max_iteration_prep,\n", " atol=atol,\n", - " )\n", - "\n", - " return adapt_circ" + " )\n" ] }, { @@ -152,8 +149,8 @@ "metadata": {}, "outputs": [], "source": [ - "from pytket.circuit import Qubit\n", "from example_workers.qsci_worker.src.qsci.utils import make_time_evolution_circuits\n", + "from pytket.circuit import Qubit\n", "\n", "\n", "def circuits_from_hamiltonians(\n", @@ -168,27 +165,27 @@ ") -> list[Circuit]:\n", " ham_init_operator = QubitPauliOperator(\n", " cast(\n", - " dict[QubitPauliString, CoeffTypeAccepted],\n", + " \"dict[QubitPauliString, CoeffTypeAccepted]\",\n", " qubit_mapping_jordan_wigner(\n", " *rhf2ghf(\n", " ham_init.h0,\n", " np.array(ham_init.h1),\n", " np.array(ham_init.h2),\n", - " )\n", + " ),\n", " ),\n", - " )\n", + " ),\n", " )\n", " ham_hsim_operator = QubitPauliOperator(\n", " cast(\n", - " dict[QubitPauliString, CoeffTypeAccepted],\n", + " \"dict[QubitPauliString, CoeffTypeAccepted]\",\n", " qubit_mapping_jordan_wigner(\n", " *rhf2ghf(\n", " ham_hsim.h0,\n", " np.array(ham_hsim.h1),\n", " np.array(ham_hsim.h2),\n", - " )\n", + " ),\n", " ),\n", - " )\n", + " ),\n", " )\n", " # Load the input data.\n", " n_core_init = get_n_core(mo_occ, cas_init.n_ele)\n", @@ -208,19 +205,18 @@ " {\n", " Qubit(qubit.index[0] + 2 * n_core): pauli\n", " for qubit, pauli in qps.map.items()\n", - " 
}\n", + " },\n", " ): coeff\n", " for qps, coeff in ham_init_operator._dict.items()\n", - " }\n", + " },\n", " )\n", - " circuits = make_time_evolution_circuits(\n", + " return make_time_evolution_circuits(\n", " t_step_list,\n", " prepared_circ,\n", " h_hsim=ham_hsim_operator,\n", " h_init=ham_init_shifted,\n", " max_cx_gates=max_cx_gates,\n", - " )\n", - " return circuits" + " )" ] }, { @@ -232,13 +228,12 @@ "source": [ "from collections import Counter\n", "\n", - "from pytket.backends.backendresult import BackendResult\n", - "\n", "from example_workers.qsci_worker.src.qsci.postprocess import (\n", " get_ci_matrix,\n", " postprocess_configs,\n", ")\n", "from example_workers.qsci_worker.src.qsci.utils import get_config_from_cas_init\n", + "from pytket.backends.backendresult import BackendResult\n", "\n", "\n", "def energy_from_results(\n", @@ -254,7 +249,7 @@ " counts[k] += v\n", " phis = list(counts.keys())\n", " phis_init_orig = get_config_from_cas_init(\n", - " mo_occ, cas_init.n, cas_init.n_ele, cas_hsim.n, cas_hsim.n_ele\n", + " mo_occ, cas_init.n, cas_init.n_ele, cas_hsim.n, cas_hsim.n_ele,\n", " )\n", " for p in phis_init_orig:\n", " if p not in phis:\n", @@ -330,10 +325,9 @@ "metadata": {}, "outputs": [], "source": [ + "from tierkreis.aer_worker import submit_single\n", "from tierkreis.builder import GraphBuilder\n", "from tierkreis.controller.data.models import TKR, OpaqueType\n", - "\n", - "from tierkreis.aer_worker import submit_single\n", "from tierkreis.quantinuum_worker import compile_circuit_quantinuum\n", "\n", "\n", @@ -342,8 +336,8 @@ " TKR[OpaqueType[\"pytket.backends.backendresult.BackendResult\"]], # noqa: F821\n", "]:\n", " g = GraphBuilder(\n", - " TKR[OpaqueType[\"pytket._tket.circuit.Circuit\"]], # noqa: F821\n", - " TKR[OpaqueType[\"pytket.backends.backendresult.BackendResult\"]], # noqa: F821\n", + " TKR[OpaqueType[\"pytket._tket.circuit.Circuit\"]],\n", + " TKR[OpaqueType[\"pytket.backends.backendresult.BackendResult\"]],\n", " )\n", 
"\n", " n_shots = g.const(500)\n", @@ -402,9 +396,9 @@ "outputs": [], "source": [ "from example_workers.qsci_worker.api.stubs import ( # noqa: F811\n", - " make_ham,\n", " circuits_from_hamiltonians,\n", " energy_from_results,\n", + " make_ham,\n", " state_prep,\n", ")\n", "\n", @@ -412,13 +406,13 @@ "# Separate tasks 'make_h_init'+'state_pre' and 'make_h_hsim' run in parallel\n", "ham_init = qsci_graph.task(\n", " make_ham(\n", - " qsci_graph.inputs.molecule, qsci_graph.inputs.mo_occ, qsci_graph.inputs.cas_init\n", - " )\n", + " qsci_graph.inputs.molecule, qsci_graph.inputs.mo_occ, qsci_graph.inputs.cas_init,\n", + " ),\n", ")\n", "ham_hsim = qsci_graph.task(\n", " make_ham(\n", - " qsci_graph.inputs.molecule, qsci_graph.inputs.mo_occ, qsci_graph.inputs.cas_hsim\n", - " )\n", + " qsci_graph.inputs.molecule, qsci_graph.inputs.mo_occ, qsci_graph.inputs.cas_hsim,\n", + " ),\n", ")\n", "\n", "adapt_circuit = qsci_graph.task(\n", @@ -430,7 +424,7 @@ " qsci_graph.inputs.mo_occ,\n", " qsci_graph.inputs.cas_init,\n", " qsci_graph.inputs.cas_hsim,\n", - " )\n", + " ),\n", ")\n", "circuits = qsci_graph.task(\n", " circuits_from_hamiltonians(\n", @@ -442,7 +436,7 @@ " qsci_graph.inputs.cas_hsim,\n", " qsci_graph.inputs.mo_occ,\n", " qsci_graph.inputs.max_cx_gates_hsim,\n", - " )\n", + " ),\n", ")\n", "backend_results = qsci_graph.map(_compile_and_run(), circuits)\n", "energy = qsci_graph.task(\n", @@ -452,7 +446,7 @@ " qsci_graph.inputs.mo_occ,\n", " qsci_graph.inputs.cas_init,\n", " qsci_graph.inputs.cas_hsim,\n", - " )\n", + " ),\n", ")\n", "\n", "qsci_graph.outputs(QSCIOutputs(energy))" @@ -491,14 +485,13 @@ "custom_executor = UvExecutor(registry_path=registry_path, logs_path=storage.logs_path)\n", "common_registry_path = PACKAGE_PATH.parent / \"tierkreis_workers\"\n", "common_executor = UvExecutor(\n", - " registry_path=common_registry_path, logs_path=storage.logs_path\n", + " registry_path=common_registry_path, logs_path=storage.logs_path,\n", ")\n", "multi_executor = 
MultipleExecutor(\n", " common_executor,\n", " executors={\"custom\": custom_executor},\n", " assignments={\"qsci_worker\": \"custom\"},\n", ")\n", - "print(\"Starting workflow at location:\", storage.logs_path)\n", "\n", "run_graph(\n", " storage,\n", @@ -537,8 +530,7 @@ " },\n", " polling_interval_seconds=0.01,\n", ")\n", - "output = read_outputs(qsci_graph, storage)\n", - "print(output)" + "output = read_outputs(qsci_graph, storage)" ] } ], diff --git a/docs/source/examples/restart.ipynb b/docs/source/examples/restart.ipynb index 20e5ff0ee..6f7293bec 100644 --- a/docs/source/examples/restart.ipynb +++ b/docs/source/examples/restart.ipynb @@ -23,11 +23,13 @@ "outputs": [], "source": [ "from uuid import UUID\n", + "\n", "from data.typed_eval import typed_eval\n", - "from tierkreis.consts import WORKERS_DIR\n", + "\n", "from tierkreis import run_graph\n", - "from tierkreis.storage import FileStorage\n", + "from tierkreis.consts import WORKERS_DIR\n", "from tierkreis.executor import UvExecutor\n", + "from tierkreis.storage import FileStorage\n", "\n", "storage = FileStorage(UUID(int=205), \"restart_example\", do_cleanup=True)\n", "executor = UvExecutor(WORKERS_DIR, storage.logs_path)\n", diff --git a/docs/source/examples/scipy.ipynb b/docs/source/examples/scipy.ipynb index 3856f88ce..fd662f71f 100644 --- a/docs/source/examples/scipy.ipynb +++ b/docs/source/examples/scipy.ipynb @@ -38,7 +38,6 @@ "\n", "from tierkreis.controller.data.models import TKR, OpaqueType\n", "\n", - "\n", "NDArray = OpaqueType[\"numpy.ndarray\"]\n", "\n", "\n", @@ -62,17 +61,17 @@ "metadata": {}, "outputs": [], "source": [ - "from tierkreis.builder import GraphBuilder\n", - "from tierkreis.controller.data.core import EmptyModel\n", - "\n", "from example_workers.scipy_worker.api.stubs import (\n", - " transpose,\n", - " reshape,\n", - " linspace,\n", " add_point,\n", " eval_point,\n", + " linspace,\n", + " reshape,\n", + " transpose,\n", ")\n", "\n", + "from tierkreis.builder import 
GraphBuilder\n", + "from tierkreis.controller.data.core import EmptyModel\n", + "\n", "sample_graph = GraphBuilder(EmptyModel, ScipyOutputs)\n", "onedim = sample_graph.task(linspace(sample_graph.const(0), sample_graph.const(10)))\n", "\n", @@ -103,6 +102,7 @@ "import os\n", "import pickle\n", "from typing import Annotated\n", + "\n", "import numpy as np\n", "\n", "from tierkreis.controller.data.core import Deserializer, Serializer\n", @@ -157,17 +157,15 @@ "from pathlib import Path\n", "from uuid import UUID\n", "\n", - "\n", - "from tierkreis.storage import FileStorage, read_outputs\n", - "from tierkreis.executor import UvExecutor\n", "from tierkreis import run_graph\n", + "from tierkreis.executor import UvExecutor\n", + "from tierkreis.storage import FileStorage, read_outputs\n", "\n", "storage = FileStorage(UUID(int=207), do_cleanup=True, name=\"scipy_graph\")\n", "executor = UvExecutor(Path().parent / \"example_workers\", storage.logs_path)\n", "run_graph(storage, executor, sample_graph, {})\n", "\n", - "outputs = read_outputs(sample_graph, storage)\n", - "print(outputs)" + "outputs = read_outputs(sample_graph, storage)" ] }, { @@ -215,8 +213,7 @@ "storage.clean_graph_files()\n", "run_graph(storage, executor, sample_graph, {})\n", "\n", - "outputs = read_outputs(sample_graph, storage)\n", - "print(outputs)" + "outputs = read_outputs(sample_graph, storage)" ] } ], diff --git a/docs/source/examples/signing_graph.ipynb b/docs/source/examples/signing_graph.ipynb index 8b9dd0536..8c54c4e4a 100644 --- a/docs/source/examples/signing_graph.ipynb +++ b/docs/source/examples/signing_graph.ipynb @@ -75,8 +75,8 @@ "outputs": [], "source": [ "from pathlib import Path\n", - "from tierkreis.namespace import Namespace\n", "\n", + "from tierkreis.namespace import Namespace\n", "\n", "if __name__ == \"__main__\":\n", " tsp_path = (\n", @@ -125,12 +125,12 @@ "metadata": {}, "outputs": [], "source": [ + "from example_workers.auth_worker.api.stubs import sign, verify\n", + "from 
example_workers.openssl_worker.api.stubs import Outputs, genrsa\n", + "\n", "from tierkreis.builder import GraphBuilder\n", "from tierkreis.models import TKR, EmptyModel\n", "\n", - "from example_workers.auth_worker.api.stubs import sign, verify\n", - "from example_workers.openssl_worker.api.stubs import genrsa, Outputs\n", - "\n", "\n", "def signing_graph():\n", " g = GraphBuilder(EmptyModel, TKR[bool])\n", @@ -165,6 +165,7 @@ "outputs": [], "source": [ "from uuid import UUID\n", + "\n", "from tierkreis.storage import FileStorage\n", "\n", "storage = FileStorage(UUID(int=105))\n", @@ -191,12 +192,12 @@ "metadata": {}, "outputs": [], "source": [ - "from tierkreis.executor import MultipleExecutor, UvExecutor, ShellExecutor\n", + "from tierkreis.executor import MultipleExecutor, ShellExecutor, UvExecutor\n", "\n", "registry_path = Path().parent / \"example_workers\"\n", "uv = UvExecutor(registry_path, storage.logs_path)\n", "shell = ShellExecutor(\n", - " registry_path, storage.workflow_dir\n", + " registry_path, storage.workflow_dir,\n", ") # export_values=True enables passing values via env vars\n", "executor = MultipleExecutor(uv, {\"shell\": shell}, {\"openssl_worker\": \"shell\"})" ] @@ -217,12 +218,11 @@ "metadata": {}, "outputs": [], "source": [ - "from tierkreis.storage import read_outputs\n", "from tierkreis import run_graph\n", + "from tierkreis.storage import read_outputs\n", "\n", "run_graph(storage, executor, signing_graph().get_data(), {})\n", - "is_verified = read_outputs(signing_graph().get_data(), storage)\n", - "print(is_verified)" + "is_verified = read_outputs(signing_graph().get_data(), storage)" ] }, { @@ -263,8 +263,7 @@ "storage.clean_graph_files()\n", "stdinout = StdInOut(registry_path, storage.workflow_dir)\n", "run_graph(storage, stdinout, stdinout_graph().get_data(), {})\n", - "out = read_outputs(stdinout_graph().get_data(), storage)\n", - "print(out)" + "out = read_outputs(stdinout_graph().get_data(), storage)" ] } ], diff --git 
a/docs/source/examples/storage_and_executors.ipynb b/docs/source/examples/storage_and_executors.ipynb index fdec427d6..fb4bb7e15 100644 --- a/docs/source/examples/storage_and_executors.ipynb +++ b/docs/source/examples/storage_and_executors.ipynb @@ -60,8 +60,8 @@ "from tierkreis.builder import GraphBuilder\n", "from tierkreis.controller.data.models import TKR\n", "from tierkreis.quantinuum_worker import (\n", - " get_backend_info,\n", " compile_using_info,\n", + " get_backend_info,\n", " run_circuit,\n", ")\n", "\n", @@ -70,8 +70,8 @@ "compiled_circuit = g.task(compile_using_info(g.inputs, info))\n", "results = g.task(\n", " run_circuit(\n", - " circuit=compiled_circuit, n_shots=g.const(10), device_name=g.const(\"H2-1SC\")\n", - " )\n", + " circuit=compiled_circuit, n_shots=g.const(10), device_name=g.const(\"H2-1SC\"),\n", + " ),\n", ")\n", "g.outputs(results)" ] @@ -94,6 +94,7 @@ "outputs": [], "source": [ "from uuid import UUID\n", + "\n", "from tierkreis.storage import FileStorage\n", "\n", "storage = FileStorage(UUID(int=209), do_cleanup=True, name=\"quantinuum_submission\")" @@ -162,9 +163,11 @@ "outputs": [], "source": [ "from pathlib import Path\n", - "from tierkreis import run_graph\n", + "\n", "from pytket.qasm.qasm import circuit_from_qasm\n", "\n", + "from tierkreis import run_graph\n", + "\n", "circuit = circuit_from_qasm(Path().parent / \"data\" / \"ghz_state_n23.qasm\")\n", "run_graph(storage, executor, g, circuit)" ] @@ -186,8 +189,7 @@ "source": [ "from tierkreis.storage import read_outputs\n", "\n", - "outputs = read_outputs(g, storage)\n", - "print(outputs)" + "outputs = read_outputs(g, storage)" ] } ], diff --git a/docs/source/examples/types_and_defaults.ipynb b/docs/source/examples/types_and_defaults.ipynb index daa9643ec..a76a090dc 100644 --- a/docs/source/examples/types_and_defaults.ipynb +++ b/docs/source/examples/types_and_defaults.ipynb @@ -37,6 +37,7 @@ "outputs": [], "source": [ "from typing import NamedTuple\n", + "\n", "from 
tierkreis.controller.data.models import TKR, OpaqueType\n", "\n", "\n", @@ -66,8 +67,8 @@ "metadata": {}, "outputs": [], "source": [ + "from tierkreis.aer_worker import get_compiled_circuit, submit_single\n", "from tierkreis.builder import GraphBuilder\n", - "from tierkreis.aer_worker import submit_single, get_compiled_circuit\n", "\n", "\n", "def compile_run_single() -> GraphBuilder[IBMInput, IBMOutput]:\n", @@ -77,7 +78,7 @@ " get_compiled_circuit(\n", " circuit=g.inputs.circuit,\n", " optimisation_level=g.const(2),\n", - " )\n", + " ),\n", " )\n", " res = g.task(submit_single(compiled_circuit, g.inputs.n_shots))\n", " g.outputs(IBMOutput(res, g.inputs.n_shots))\n", @@ -100,6 +101,7 @@ "outputs": [], "source": [ "from uuid import UUID\n", + "\n", "from pytket._tket.circuit import Circuit\n", "\n", "from tierkreis.consts import PACKAGE_PATH\n", @@ -131,8 +133,7 @@ " graph,\n", " inputs,\n", ")\n", - "res = read_outputs(graph, storage)\n", - "print(res)" + "res = read_outputs(graph, storage)" ] }, { @@ -169,7 +170,7 @@ " get_compiled_circuit(\n", " circuit=g.inputs.circuit,\n", " optimisation_level=g.inputs.optimisation_level,\n", - " )\n", + " ),\n", " )\n", " res = g.task(submit_single(compiled_circuit, g.inputs.n_shots))\n", " g.outputs(IBMOutput(res, g.inputs.n_shots)) # type: ignore\n", @@ -206,8 +207,7 @@ " graph,\n", " inputs,\n", ")\n", - "res = read_outputs(graph, storage)\n", - "print(res)" + "res = read_outputs(graph, storage)" ] } ], diff --git a/docs/source/examples/worker.ipynb b/docs/source/examples/worker.ipynb index fea2fed19..4d81f734b 100644 --- a/docs/source/examples/worker.ipynb +++ b/docs/source/examples/worker.ipynb @@ -78,6 +78,7 @@ "source": [ "import logging\n", "from pathlib import Path\n", + "\n", "from tierkreis import Worker\n", "\n", "logger = logging.getLogger(__name__)\n", diff --git a/docs/source/tutorial/auth_stubs.py b/docs/source/tutorial/auth_stubs.py index 00dec255b..586188b8c 100644 --- a/docs/source/tutorial/auth_stubs.py 
+++ b/docs/source/tutorial/auth_stubs.py @@ -1,26 +1,27 @@ """Code generated from auth_worker namespace. Please do not edit.""" from typing import NamedTuple + from tierkreis.controller.data.models import TKR class EncryptionResult(NamedTuple): - ciphertext: TKR[str] # noqa: F821 # fmt: skip - time_taken: TKR[float] # noqa: F821 # fmt: skip + ciphertext: TKR[str] # fmt: skip + time_taken: TKR[float] # fmt: skip class SigningResult(NamedTuple): - hex_signature: TKR[str] # noqa: F821 # fmt: skip - time_taken: TKR[float] # noqa: F821 # fmt: skip + hex_signature: TKR[str] # fmt: skip + time_taken: TKR[float] # fmt: skip class encrypt(NamedTuple): - plaintext: TKR[str] # noqa: F821 # fmt: skip - work_factor: TKR[int] # noqa: F821 # fmt: skip + plaintext: TKR[str] # fmt: skip + work_factor: TKR[int] # fmt: skip @staticmethod - def out() -> type[EncryptionResult]: # noqa: F821 # fmt: skip - return EncryptionResult # noqa: F821 # fmt: skip + def out() -> type[EncryptionResult]: # fmt: skip + return EncryptionResult # fmt: skip @property def namespace(self) -> str: @@ -28,13 +29,13 @@ def namespace(self) -> str: class sign(NamedTuple): - private_key: TKR[bytes] # noqa: F821 # fmt: skip - passphrase: TKR[bytes] # noqa: F821 # fmt: skip - message: TKR[str] # noqa: F821 # fmt: skip + private_key: TKR[bytes] # fmt: skip + passphrase: TKR[bytes] # fmt: skip + message: TKR[str] # fmt: skip @staticmethod - def out() -> type[SigningResult]: # noqa: F821 # fmt: skip - return SigningResult # noqa: F821 # fmt: skip + def out() -> type[SigningResult]: # fmt: skip + return SigningResult # fmt: skip @property def namespace(self) -> str: @@ -42,13 +43,13 @@ def namespace(self) -> str: class verify(NamedTuple): - public_key: TKR[bytes] # noqa: F821 # fmt: skip - signature: TKR[str] # noqa: F821 # fmt: skip - message: TKR[str] # noqa: F821 # fmt: skip + public_key: TKR[bytes] # fmt: skip + signature: TKR[str] # fmt: skip + message: TKR[str] # fmt: skip @staticmethod - def out() -> 
type[TKR[bool]]: # noqa: F821 # fmt: skip - return TKR[bool] # noqa: F821 # fmt: skip + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip @property def namespace(self) -> str: diff --git a/docs/source/worker/hello_stubs.py b/docs/source/worker/hello_stubs.py index da89e75a5..5afd6b572 100644 --- a/docs/source/worker/hello_stubs.py +++ b/docs/source/worker/hello_stubs.py @@ -1,16 +1,17 @@ """Code generated from hello_world_worker namespace. Please do not edit.""" from typing import NamedTuple + from tierkreis.controller.data.models import TKR class greet(NamedTuple): - greeting: TKR[str] # noqa: F821 # fmt: skip - subject: TKR[str] # noqa: F821 # fmt: skip + greeting: TKR[str] # fmt: skip + subject: TKR[str] # fmt: skip @staticmethod - def out() -> type[TKR[str]]: # noqa: F821 # fmt: skip - return TKR[str] # noqa: F821 # fmt: skip + def out() -> type[TKR[str]]: # fmt: skip + return TKR[str] # fmt: skip @property def namespace(self) -> str: diff --git a/infra/slurm_local/main.py b/infra/slurm_local/main.py index 1c088799f..346cdae36 100644 --- a/infra/slurm_local/main.py +++ b/infra/slurm_local/main.py @@ -8,9 +8,10 @@ import socket from sys import argv -from tierkreis import Worker from mpi4py import MPI # type: ignore +from tierkreis import Worker + logger = logging.getLogger(__name__) worker = Worker("slurm_mpi_worker") @@ -32,7 +33,6 @@ def mpi_rank_info() -> str | None: size = comm.Get_size() info = _proc_info() all_processes_info = comm.gather(info, root=0) - print(all_processes_info) if rank == 0: return "\n".join( f"Rank {info['rank']} out of {size} on {info['hostname']}." 
diff --git a/tierkreis/pyproject.toml b/tierkreis/pyproject.toml index 07a21cf2c..613cf0789 100644 --- a/tierkreis/pyproject.toml +++ b/tierkreis/pyproject.toml @@ -30,7 +30,38 @@ build-backend = "hatchling.build" [tool.ruff] target-version = "py312" -extend-exclude = [] +extend-exclude = [ + # Ignore worker stubs: + "*_worker.py", + "stubs.py", + "stubs_output.py", + # Ignore docs + "tierkreis/docs/*", +] +[tool.ruff.lint] +select = ["ALL"] +isort.known-first-party = ["tierkreis", "tierkreis_visualizer"] +pydocstyle.convention = "pep257" + +# Ignore specific rules that might be redundant or annoying +ignore = [ + "D203", # Conflict: 1 blank line before class (D203) vs no blank lines (D211) + "D213", # Conflict: Multi-line docstring summary start (D212) vs (D213) + "ISC001", # Single line implicit string concatenation + "S603", # Subprocess calls + "D107", # In favor of documenting classes directly +] + +[tool.ruff.lint.per-file-ignores] +"tierkreis/tests/*" = [ + "S101", # asserts allowed in tests... + "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... + "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() + "PLR2004", # Magic value used in comparison, ... 
+ "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "D", +] + [tool.pyright] include = ["."] diff --git a/tierkreis/tests/cli/test_run_workflow.py b/tierkreis/tests/cli/test_run_workflow.py index 3408d4944..bdd29651c 100644 --- a/tierkreis/tests/cli/test_run_workflow.py +++ b/tierkreis/tests/cli/test_run_workflow.py @@ -1,17 +1,17 @@ -import pytest import json from pathlib import Path -from uuid import UUID from unittest import mock +from uuid import UUID +import pytest -from tierkreis.controller.data.graph import GraphData -from tierkreis.cli.run_workflow import run_workflow from tests.controller.sample_graphdata import simple_eval +from tierkreis.cli.run_workflow import run_workflow +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import ptype_from_bytes -@pytest.fixture() +@pytest.fixture def graph() -> GraphData: return simple_eval() @@ -19,7 +19,7 @@ def graph() -> GraphData: def test_run_workflow(graph: GraphData) -> None: inputs = {} run_workflow(inputs=inputs, graph=graph, run_id=31415) - with open( + with Path.open( Path.home() / ".tierkreis" / "checkpoints" @@ -32,18 +32,24 @@ def test_run_workflow(graph: GraphData) -> None: assert c == 12 -def test_run_workflow_with_output(graph: GraphData, capfd) -> None: +def test_run_workflow_with_output(graph: GraphData, capfd) -> None: # noqa: ANN001 inputs = {} run_workflow(inputs=inputs, graph=graph, run_id=31415, print_output=True) out, _ = capfd.readouterr() - assert "{'simple_eval_output': 12}\n" in out + assert "simple_eval_output: 12\n" in out + + +@pytest.fixture +def _patch_uuid4() -> mock.Mock: + with mock.patch("uuid.uuid4", return_value=UUID(int=31415)) as m: + return m -@mock.patch("uuid.uuid4", return_value=UUID(int=31415)) -def test_run_workflow_default_run_id(_, graph: GraphData) -> None: +@pytest.mark.usefixtures("_patch_uuid4", "graph") +def test_run_workflow_default_run_id(graph: GraphData) -> None: inputs = {} 
run_workflow(inputs=inputs, graph=graph) - with open( + with Path.open( Path.home() / ".tierkreis" / "checkpoints" @@ -61,6 +67,6 @@ def test_run_workflow_uv_executor(graph: GraphData) -> None: inputs=inputs, graph=graph, run_id=31415, - use_uv_worker=True, - registry_path=Path("."), + use_uv_executor=True, + registry_path=Path(), ) diff --git a/tierkreis/tests/cli/test_tkr.py b/tierkreis/tests/cli/test_tkr.py index 75cdd42fa..92dcafa39 100644 --- a/tierkreis/tests/cli/test_tkr.py +++ b/tierkreis/tests/cli/test_tkr.py @@ -1,18 +1,18 @@ import json -import pytest import sys from pathlib import Path from unittest import mock from uuid import UUID -from tierkreis.cli.run import load_graph, _load_inputs +import pytest + +from tests.controller.sample_graphdata import simple_eval +from tierkreis.cli.run import _load_inputs, load_graph from tierkreis.cli.tkr import main from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType from tierkreis.exceptions import TierkreisError -from tests.controller.sample_graphdata import simple_eval - simple_eval_graph = simple_eval() graph_params = [ @@ -21,9 +21,13 @@ ] -@pytest.mark.parametrize("input,graph", graph_params, ids=["load_module", "load_file"]) -def test_load_graph(input: str, graph: GraphData) -> None: - assert load_graph(input) == graph +@pytest.mark.parametrize( + ("inputs", "graph"), + graph_params, + ids=["load_module", "load_file"], +) +def test_load_graph(inputs: str, graph: GraphData) -> None: + assert load_graph(inputs) == graph def test_load_graph_invalid() -> None: @@ -56,10 +60,12 @@ def test_load_graph_invalid() -> None: @pytest.mark.parametrize( - "input,result", input_params, ids=["json_input", "binary_input"] + ("inputs", "result"), + input_params, + ids=["json_input", "binary_input"], ) -def test_load_inputs(input: list[str], result: dict[str, PType]) -> None: - assert _load_inputs(input) == result +def test_load_inputs(inputs: list[str], result: dict[str, 
PType]) -> None: + assert _load_inputs(inputs) == result def test_load_inputs_invalid() -> None: @@ -92,12 +98,12 @@ def test_load_inputs_invalid() -> None: cli_params = [ ( - default_args + ["-f", "tierkreis/tests/cli/data/sample_graph"], + [*default_args, "-f", "tierkreis/tests/cli/data/sample_graph"], {"simple_eval_output": 12}, ), ( - default_args - + [ + [ + *default_args, "-g", "tests.controller.sample_graphdata:factorial", "-i", @@ -110,13 +116,15 @@ def test_load_inputs_invalid() -> None: @pytest.mark.parametrize( - "args,result", cli_params, ids=["simple_eval_cli", "factorial_cli"] + ("args", "result"), + cli_params, + ids=["simple_eval_cli", "factorial_cli"], ) def test_end_to_end(args: list[str], result: dict[str, bytes]) -> None: with mock.patch.object(sys, "argv", args): main() for key, value in result.items(): - with open( + with Path.open( Path.home() / ".tierkreis" / "checkpoints" diff --git a/tierkreis/tests/conftest.py b/tierkreis/tests/conftest.py index afe4ba84f..62a9a2f45 100644 --- a/tierkreis/tests/conftest.py +++ b/tierkreis/tests/conftest.py @@ -1,17 +1,23 @@ import pytest -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( - "--optional", action="store_true", default=False, help="run optional tests" + "--optional", + action="store_true", + default=False, + help="run optional tests", ) -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line("markers", "optional: mark test as optional to run") -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems( + config: pytest.Config, + items: list[pytest.Item], +) -> None: if config.getoption("--optional"): return skip_slow = pytest.mark.skip(reason="need --optional option to run") diff --git a/tierkreis/tests/controller/defaults_graphs.py b/tierkreis/tests/controller/defaults_graphs.py index cffb4c549..f1556a2ef 100644 --- 
a/tierkreis/tests/controller/defaults_graphs.py +++ b/tierkreis/tests/controller/defaults_graphs.py @@ -1,8 +1,8 @@ from typing import NamedTuple from tierkreis.builder import GraphBuilder -from tierkreis.controller.data.models import TKR from tierkreis.builtins.stubs import tkr_range +from tierkreis.controller.data.models import TKR class Inputs(NamedTuple): diff --git a/tierkreis/tests/controller/loop_graphdata.py b/tierkreis/tests/controller/loop_graphdata.py index 46467e1cc..426a87ea1 100644 --- a/tierkreis/tests/controller/loop_graphdata.py +++ b/tierkreis/tests/controller/loop_graphdata.py @@ -1,7 +1,8 @@ from typing import NamedTuple + import tierkreis.builtins.stubs as tkr_builtins -from tierkreis.controller.data.core import EmptyModel from tierkreis.builder import GraphBuilder +from tierkreis.controller.data.core import EmptyModel from tierkreis.controller.data.graph import GraphData from tierkreis.models import TKR @@ -30,7 +31,7 @@ def _loop_body_multiple_acc_untyped() -> GraphData: "acc1": new_acc, "acc2": new_acc2, "acc3": new_acc3, - } + }, ) return g @@ -136,10 +137,10 @@ def _loop_body_scoping() -> GraphBuilder[Scoping, ScopingOut]: one = g.const(1) - next = g.task(tkr_builtins.iadd(g.inputs.current, one)) + next_val = g.task(tkr_builtins.iadd(g.inputs.current, one)) should_continue = g.task(tkr_builtins.neq(g.inputs.end, g.inputs.current)) - g.outputs(ScopingOut(should_continue=should_continue, current=next)) + g.outputs(ScopingOut(should_continue=should_continue, current=next_val)) return g diff --git a/tierkreis/tests/controller/main.py b/tierkreis/tests/controller/main.py index d134ab716..97cca9a55 100644 --- a/tierkreis/tests/controller/main.py +++ b/tierkreis/tests/controller/main.py @@ -6,8 +6,8 @@ # tierkreis = { path = "../../../tierkreis", editable = true } # /// from pathlib import Path -from time import sleep from sys import argv +from time import sleep from tierkreis import Worker diff --git 
a/tierkreis/tests/controller/sample_graphdata.py b/tierkreis/tests/controller/sample_graphdata.py index ee46a88a4..a9f0a7465 100644 --- a/tierkreis/tests/controller/sample_graphdata.py +++ b/tierkreis/tests/controller/sample_graphdata.py @@ -38,10 +38,10 @@ def loop_body() -> GraphData: g = GraphData() a = g.input("loop_acc") one = g.const(1) - N = g.const(10) + n_val: tuple[int, str] = g.const(10) a_plus = g.func("builtins.iadd", {"a": a, "b": one})("value") - pred = g.func("builtins.igt", {"a": N, "b": a_plus})("value") + pred = g.func("builtins.igt", {"a": n_val, "b": a_plus})("value") g.output({"loop_acc": a_plus, "should_continue": pred}) return g @@ -58,11 +58,11 @@ def simple_loop() -> GraphData: def simple_map() -> GraphData: g = GraphData() six = g.const(6) - Ns_const = g.const(list(range(21))) - Ns = g.func("builtins.unfold_values", {Labels.VALUE: Ns_const}) + n_consts = g.const(list(range(21))) + n_vals = g.func("builtins.unfold_values", {Labels.VALUE: n_consts}) doubler_const = g.const(doubler_plus()) - m = g.map(doubler_const, {"doubler_input": Ns("*"), "intercept": six}) + m = g.map(doubler_const, {"doubler_input": n_vals("*"), "intercept": six}) folded = g.func("builtins.fold_values", {"values_glob": m("*")}) g.output({"value": folded(Labels.VALUE)}) return g @@ -82,11 +82,11 @@ def maps_in_series() -> GraphData: g = GraphData() zero = g.const(0) - Ns_const = g.const(list(range(21))) - Ns = g.func("builtins.unfold_values", {Labels.VALUE: Ns_const}) + n_consts = g.const(list(range(21))) + n_vals = g.func("builtins.unfold_values", {Labels.VALUE: n_consts}) doubler_const = g.const(doubler_plus()) - m = g.map(doubler_const, {"doubler_input": Ns("*"), "intercept": zero}) + m = g.map(doubler_const, {"doubler_input": n_vals("*"), "intercept": zero}) m2 = g.map(doubler_const, {"doubler_input": m("*"), "intercept": zero}) folded = g.func("builtins.fold_values", {"values_glob": m2("*")}) @@ -97,11 +97,11 @@ def maps_in_series() -> GraphData: def 
map_with_str_keys() -> GraphData: g = GraphData() zero = g.const(0) - Ns_const = g.const({"one": 1, "two": 2, "three": 3}) - Ns = g.func("builtins.unfold_dict", {Labels.VALUE: Ns_const}) + n_consts = g.const({"one": 1, "two": 2, "three": 3}) + n_vals = g.func("builtins.unfold_dict", {Labels.VALUE: n_consts}) doubler_const = g.const(doubler_plus()) - m = g.map(doubler_const, {"doubler_input": Ns("*"), "intercept": zero}) + m = g.map(doubler_const, {"doubler_input": n_vals("*"), "intercept": zero}) folded = g.func("builtins.fold_dict", {"values_glob": m("*")}) g.output({"value": folded(Labels.VALUE)}) return g diff --git a/tierkreis/tests/controller/test_codegen.py b/tierkreis/tests/controller/test_codegen.py index 7c011b850..6d3996ace 100644 --- a/tierkreis/tests/controller/test_codegen.py +++ b/tierkreis/tests/controller/test_codegen.py @@ -1,5 +1,7 @@ from types import NoneType + import pytest + from tierkreis.codegen import format_generic_type from tierkreis.controller.data.types import PType from tierkreis.idl.models import GenericType @@ -21,8 +23,10 @@ ] -@pytest.mark.parametrize("ttype,expected", formats) -def test_format_ttype(ttype: type[PType], expected: str): +@pytest.mark.parametrize(("ttype", "expected"), formats) +def test_format_ttype(ttype: type[PType], expected: str) -> None: generic_type = GenericType.from_type(ttype) - assert format_generic_type(generic_type, False, False) == expected + assert ( + format_generic_type(generic_type, include_bound=False, is_tkr=False) == expected + ) diff --git a/tierkreis/tests/controller/test_eagerifelse.py b/tierkreis/tests/controller/test_eagerifelse.py index 4bfe0d6f3..6afd9af51 100644 --- a/tierkreis/tests/controller/test_eagerifelse.py +++ b/tierkreis/tests/controller/test_eagerifelse.py @@ -1,20 +1,20 @@ import json -import pytest from pathlib import Path from uuid import UUID +import pytest + from tests.controller.sample_graphdata import ( simple_eagerifelse, simple_ifelse, ) - from tierkreis.controller 
import run_graph +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import PType from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.executor.uv_executor import UvExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage -from tierkreis.controller.data.graph import GraphData def eagerifelse_long_running() -> GraphData: @@ -34,8 +34,8 @@ def eagerifelse_long_running() -> GraphData: params = [({"pred": True}, 1), ({"pred": False}, 2)] -@pytest.mark.parametrize("input, output", params) -def test_eagerifelse_long_running(input: dict[str, PType], output: int) -> None: +@pytest.mark.parametrize(("inputs", "output"), params) +def test_eagerifelse_long_running(inputs: dict[str, PType], output: int) -> None: g = eagerifelse_long_running() storage = ControllerFileStorage(UUID(int=150), name="eagerifelse_long_running") @@ -43,7 +43,7 @@ def test_eagerifelse_long_running(input: dict[str, PType], output: int) -> None: executor = UvExecutor(registry_path=registry_path, logs_path=storage.logs_path) storage.clean_graph_files() - run_graph(storage, executor, g, input, n_iterations=20000) + run_graph(storage, executor, g, inputs, n_iterations=20000) actual_output = json.loads(storage.read_output(Loc(), "simple_eagerifelse_output")) assert actual_output == output @@ -58,7 +58,7 @@ def test_eagerifelse_nodes() -> None: assert storage.is_node_finished(Loc("-.N4")) -def test_ifelse_nodes(): +def test_ifelse_nodes() -> None: g = simple_ifelse() storage = ControllerFileStorage(UUID(int=152), name="simple_if_else") executor = ShellExecutor(Path("./python/examples/launchers"), storage.workflow_dir) diff --git a/tierkreis/tests/controller/test_graphdata.py b/tierkreis/tests/controller/test_graphdata.py index fa9d0ac89..92c2407f2 100644 --- a/tierkreis/tests/controller/test_graphdata.py +++ 
b/tierkreis/tests/controller/test_graphdata.py @@ -1,10 +1,11 @@ import pytest -from tierkreis.exceptions import TierkreisError + from tierkreis.controller.data.graph import GraphData +from tierkreis.exceptions import TierkreisError -def test_only_one_output(): +def test_only_one_output() -> None: + g = GraphData() + g.output({"one": g.const(1)}) with pytest.raises(TierkreisError): - g = GraphData() - g.output({"one": g.const(1)}) g.output({"two": g.const(2)}) diff --git a/tierkreis/tests/controller/test_graphdata_storage.py b/tierkreis/tests/controller/test_graphdata_storage.py index b54b7b06c..85a0f4424 100644 --- a/tierkreis/tests/controller/test_graphdata_storage.py +++ b/tierkreis/tests/controller/test_graphdata_storage.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize( - ["node_location_str", "graph", "target"], + ("node_location_str", "graph", "target"), [ ("-.N0", simple_eval(), Const(0, outputs={"value": [3]})), ("-.N4.M0", simple_map(), Eval((-1, "body"), {})), @@ -33,7 +33,7 @@ def test_read_nodedef(node_location_str: str, graph: GraphData, target: str) -> @pytest.mark.parametrize( - ["node_location_str", "graph", "port", "target"], + ("node_location_str", "graph", "port", "target"), [ ("-.N0", simple_eval(), "value", b"0"), ("-.N2", simple_eval(), "value", b"Graph"), @@ -42,7 +42,10 @@ def test_read_nodedef(node_location_str: str, graph: GraphData, target: str) -> ], ) def test_read_output( - node_location_str: str, graph: GraphData, port: PortID, target: str + node_location_str: str, + graph: GraphData, + port: PortID, + target: str, ) -> None: loc = Loc(node_location_str) storage = GraphDataStorage(UUID(int=0), graph) @@ -58,14 +61,16 @@ def test_raises() -> None: @pytest.mark.parametrize( - ["node_location_str", "graph", "target"], + ("node_location_str", "graph", "target"), [ ("-.N0", simple_eval(), ["value"]), ("-.N4.M0", simple_map(), ["0"]), ], ) def test_read_output_ports( - node_location_str: str, graph: GraphData, target: str + node_location_str: 
str, + graph: GraphData, + target: str, ) -> None: loc = Loc(node_location_str) storage = GraphDataStorage(UUID(int=0), graph) @@ -74,7 +79,7 @@ def test_read_output_ports( @pytest.mark.parametrize( - ["node_location_str", "graph", "target"], + ("node_location_str", "graph", "target"), [ ("-.N0", simple_eval(), Const(0, outputs={"value": [3]})), ("-.N3.N1", simple_eval(), Input("intercept", outputs={"intercept": [4]})), @@ -103,7 +108,9 @@ def test_read_output_ports( ], ) def test_graph_node_from_loc( - node_location_str: str, graph: GraphData, target: str + node_location_str: str, + graph: GraphData, + target: str, ) -> None: loc = Loc(node_location_str) node_def, _ = graph_node_from_loc(loc, graph) diff --git a/tierkreis/tests/controller/test_locs.py b/tierkreis/tests/controller/test_locs.py index 52051cf1f..91765cc9f 100644 --- a/tierkreis/tests/controller/test_locs.py +++ b/tierkreis/tests/controller/test_locs.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize( - ["node_location", "loc_str"], + ("node_location", "loc_str"), [ (node_location_1, "-.N1.L0.N3.L2.N0.M7.N10"), (node_location_2, "-.N0.L0.N3.N8.N0"), @@ -35,7 +35,7 @@ (node_location_4, "-"), ], ) -def test_to_from_str(node_location: Loc, loc_str: str): +def test_to_from_str(node_location: Loc, loc_str: str) -> None: node_location_str = str(node_location) assert node_location_str == loc_str @@ -44,7 +44,7 @@ def test_to_from_str(node_location: Loc, loc_str: str): @pytest.mark.parametrize( - ["node_location", "loc_str"], + ("node_location", "loc_str"), [ (node_location_1, "-.N1.L0.N3.L2.N0.M7"), (node_location_2, "-.N0.L0.N3.N8"), @@ -61,7 +61,7 @@ def test_parent(node_location: Loc, loc_str: str) -> None: @pytest.mark.parametrize( - ["node_location", "node_step", "loc_str"], + ("node_location", "node_step", "loc_str"), [ (node_location_1, ("N", 1), "-.L0.N3.L2.N0.M7.N10"), (node_location_2, ("N", 0), "-.L0.N3.N8.N0"), @@ -77,7 +77,7 @@ def test_pop_first(node_location: Loc, node_step: NodeStep, loc_str: str) 
-> Non @pytest.mark.parametrize( - ["node_location", "node_step", "loc_str"], + ("node_location", "node_step", "loc_str"), [ (node_location_1, ("N", 10), "-.N1.L0.N3.L2.N0.M7"), (node_location_2, ("N", 0), "-.N0.L0.N3.N8"), @@ -157,7 +157,7 @@ def test_pop_last_multiple() -> None: @pytest.mark.parametrize( - ["node_location", "index"], + ("node_location", "index"), [ (node_location_1, 10), (node_location_2, 0), @@ -171,7 +171,7 @@ def test_get_last_index(node_location: Loc, index: int) -> None: @pytest.mark.parametrize( - ["node_location", "expected"], + ("node_location", "expected"), [ ( node_location_1, @@ -201,5 +201,5 @@ def test_get_last_index(node_location: Loc, index: int) -> None: (node_location_4, [Loc()]), ], ) -def test_partial_paths(node_location: Loc, expected: list[Loc]): +def test_partial_paths(node_location: Loc, expected: list[Loc]) -> None: assert expected == node_location.partial_locs() diff --git a/tierkreis/tests/controller/test_models.py b/tierkreis/tests/controller/test_models.py index fc96f8c48..d08206cc5 100644 --- a/tierkreis/tests/controller/test_models.py +++ b/tierkreis/tests/controller/test_models.py @@ -1,9 +1,11 @@ from types import NoneType from typing import NamedTuple + import pytest + +from tests.controller.test_types import ptypes from tierkreis.controller.data.models import PModel, dict_from_pmodel, portmapping from tierkreis.controller.data.types import PType -from tests.controller.test_types import ptypes @portmapping @@ -20,7 +22,7 @@ class NamedPModel(NamedTuple): @pytest.mark.parametrize("pmodel", ptypes) -def test_dict_from_pmodel_unnested(pmodel: PModel): +def test_dict_from_pmodel_unnested(pmodel: PModel) -> None: assert dict_from_pmodel(pmodel) == {"value": pmodel} @@ -50,6 +52,6 @@ def test_dict_from_pmodel_unnested(pmodel: PModel): pmodels = [(named_p_model, named_p_model_expected)] -@pytest.mark.parametrize("pmodel,expected", pmodels) -def test_dict_from_pmodel_nested(pmodel: PModel, expected: dict[str, PType]): 
+@pytest.mark.parametrize(("pmodel", "expected"), pmodels) +def test_dict_from_pmodel_nested(pmodel: PModel, expected: dict[str, PType]) -> None: assert dict_from_pmodel(pmodel) == expected diff --git a/tierkreis/tests/controller/test_read_loop_trace.py b/tierkreis/tests/controller/test_read_loop_trace.py index f949d976e..5ac829b64 100644 --- a/tierkreis/tests/controller/test_read_loop_trace.py +++ b/tierkreis/tests/controller/test_read_loop_trace.py @@ -1,24 +1,23 @@ from pathlib import Path -from typing import Any, Type from uuid import UUID import pytest + from tests.controller.loop_graphdata import loop_multiple_acc, loop_multiple_acc_untyped from tierkreis.controller import run_graph +from tierkreis.controller.data.graph import GraphData from tierkreis.controller.executor.in_memory_executor import InMemoryExecutor from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage -from tierkreis.controller.data.graph import GraphData from tierkreis.storage import read_loop_trace - return_value = [ {"acc1": x, "acc2": y, "acc3": z} - for x, y, z in zip(range(1, 7), range(2, 13, 2), range(3, 19, 3)) + for x, y, z in zip(range(1, 7), range(2, 13, 2), range(3, 19, 3), strict=False) ] -params: list[tuple[GraphData, Any, str, int]] = [ +params: list[tuple[GraphData, list[dict[str, int]], str, int]] = [ ( loop_multiple_acc_untyped(), return_value, @@ -42,17 +41,17 @@ @pytest.mark.parametrize("storage_class", storage_classes, ids=storage_ids) -@pytest.mark.parametrize("graph,output,name,id", params, ids=ids) +@pytest.mark.parametrize(("graph", "output", "name", "workflow_id"), params, ids=ids) def test_read_loop_trace( - storage_class: Type[ControllerFileStorage | ControllerInMemoryStorage], + storage_class: type[ControllerFileStorage | ControllerInMemoryStorage], graph: GraphData, - output: Any, + output: 
list[dict[str, int]], name: str, - id: int, -): + workflow_id: int, +) -> None: g = graph - storage = storage_class(UUID(int=id), name=name) - executor = ShellExecutor(Path("./python/examples/launchers"), Path("")) + storage = storage_class(UUID(int=workflow_id), name=name) + executor = ShellExecutor(Path("./python/examples/launchers"), Path()) if isinstance(storage, ControllerInMemoryStorage): executor = InMemoryExecutor(Path("./tierkreis/tierkreis"), storage=storage) storage.clean_graph_files() diff --git a/tierkreis/tests/controller/test_restart.py b/tierkreis/tests/controller/test_restart.py index 8d04abbab..7f6f522ca 100644 --- a/tierkreis/tests/controller/test_restart.py +++ b/tierkreis/tests/controller/test_restart.py @@ -10,7 +10,7 @@ from tierkreis.storage import FileStorage, read_outputs -def test_restart(): +def test_restart() -> None: storage = FileStorage(UUID(int=300), "test_restart") storage.clean_graph_files() executor = UvExecutor(Path(__file__).parent.parent / "workers", storage.logs_path) diff --git a/tierkreis/tests/controller/test_resume.py b/tierkreis/tests/controller/test_resume.py index c1309fa75..69cb1159a 100644 --- a/tierkreis/tests/controller/test_resume.py +++ b/tierkreis/tests/controller/test_resume.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Type +from typing import Any from uuid import UUID import pytest @@ -46,7 +46,12 @@ from tierkreis.storage import read_outputs param_data: list[ - tuple[GraphData | GraphBuilder, Any, str, dict[str, PType] | PType] + tuple[ + GraphData | GraphBuilder, + dict[str, PType] | PType, + str, + dict[str, PType] | PType, + ] ] = [ (simple_eval(), {"simple_eval_output": 12}, "simple_eval", {}), (simple_loop(), 10, "simple_loop", {}), @@ -169,17 +174,21 @@ @pytest.mark.parametrize("storage_class", storage_classes, ids=storage_ids) -@pytest.mark.parametrize("graph,output,name,id,inputs", params, ids=ids) -def test_resume( - storage_class: Type[ControllerFileStorage | 
ControllerInMemoryStorage], +@pytest.mark.parametrize( + ("graph", "output", "name", "workflow_id", "inputs"), + params, + ids=ids, +) +def test_resume( # noqa: PLR0913 + storage_class: type[ControllerFileStorage | ControllerInMemoryStorage], graph: GraphData, - output: Any, + output: dict[str, PType] | PType, name: str, - id: int, + workflow_id: int, inputs: dict[str, PType] | PType, -): +) -> None: g = graph - storage = storage_class(UUID(int=id), name=name) + storage = storage_class(UUID(int=workflow_id), name=name) test_workers_path = Path(__file__).parent.parent / "test_workers" executor = UvExecutor(test_workers_path, storage.logs_path) if isinstance(storage, ControllerInMemoryStorage): @@ -206,13 +215,19 @@ def test_resume( @pytest.mark.parametrize( - "graph,output,name,id,inputs", with_worker_params, ids=with_worker_ids + ("graph", "output", "name", "id", "inputs"), + with_worker_params, + ids=with_worker_ids, ) def test_resume_with_worker( - graph: GraphData, output: Any, name: str, id: int, inputs: dict[str, PType] | PType -): + graph: GraphData, + output: dict[str, PType] | PType, + name: str, + workflow_id: int, + inputs: dict[str, PType] | PType, +) -> None: g = graph - storage = ControllerFileStorage(UUID(int=id), name=name) + storage = ControllerFileStorage(UUID(int=workflow_id), name=name) test_workers_path = Path(__file__).parent.parent / "workers" executor = UvExecutor(test_workers_path, storage.logs_path) storage.clean_graph_files() diff --git a/tierkreis/tests/controller/test_types.py b/tierkreis/tests/controller/test_types.py index f261d9e61..4beca82fc 100644 --- a/tierkreis/tests/controller/test_types.py +++ b/tierkreis/tests/controller/test_types.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime from types import NoneType, UnionType -from typing import Mapping, Sequence, TypeVar +from typing import TypeVar from uuid import UUID -from pydantic import BaseModel + 
import pytest +from pydantic import BaseModel + from tierkreis.controller.data.types import ( PType, bytes_from_ptype, @@ -96,7 +99,7 @@ def from_list(cls, args: list) -> "DummyListConvertible": @pytest.mark.parametrize("ptype", ptypes) -def test_bytes_roundtrip(ptype: PType): +def test_bytes_roundtrip(ptype: PType) -> None: bs = bytes_from_ptype(ptype) new_type = ptype_from_bytes(bs, type(ptype)) assert ptype == new_type @@ -117,7 +120,7 @@ def test_bytes_roundtrip(ptype: PType): @pytest.mark.parametrize("annotated_ptype", annotated_ptypes) -def test_annotated_bytes_roundtrip(annotated_ptype: tuple[PType, type]): +def test_annotated_bytes_roundtrip(annotated_ptype: tuple[PType, type]) -> None: ptype, annotation = annotated_ptype bs = bytes_from_ptype(ptype) new_type = ptype_from_bytes(bs, annotation) @@ -125,12 +128,12 @@ def test_annotated_bytes_roundtrip(annotated_ptype: tuple[PType, type]): @pytest.mark.parametrize("ptype", type_list) -def test_ptype_from_annotation(ptype: type[PType]): +def test_ptype_from_annotation(ptype: type[PType]) -> None: assert is_ptype(ptype) @pytest.mark.parametrize("ptype", fail_list) -def test_ptype_from_annotation_fails(ptype: type[PType]): +def test_ptype_from_annotation_fails(ptype: type[PType]) -> None: assert not is_ptype(ptype) @@ -138,13 +141,13 @@ def test_ptype_from_annotation_fails(ptype: type[PType]): T = TypeVar("T") generic_types = [] -generic_types.append((list[T], {str(T)})) # type: ignore -generic_types.append((list[S | T], {str(S), str(T)})) # type: ignore -generic_types.append((list[list[list[T]]], {str(T)})) # type: ignore -generic_types.append((tuple[S, T], {str(S), str(T)})) # type: ignore -generic_types.append((UntupledModel[S, T], {str(S), str(T)})) # type: ignore +generic_types.append((list[T], {str(T)})) # type: ignore[valid-type] +generic_types.append((list[S | T], {str(S), str(T)})) # type: ignore[valid-type] +generic_types.append((list[list[list[T]]], {str(T)})) # type: ignore[valid-type] 
+generic_types.append((tuple[S, T], {str(S), str(T)})) # type: ignore[valid-type] +generic_types.append((UntupledModel[S, T], {str(S), str(T)})) # type: ignore[valid-type] -@pytest.mark.parametrize("ptype,generics", generic_types) -def test_generic_types(ptype: type[PType], generics: set[type[PType]]): +@pytest.mark.parametrize(("ptype", "generics"), generic_types) +def test_generic_types(ptype: type[PType], generics: set[type[PType]]) -> None: assert generics_in_ptype(ptype) == generics diff --git a/tierkreis/tests/controller/typed_graphdata.py b/tierkreis/tests/controller/typed_graphdata.py index 088ac92f4..328c94f82 100644 --- a/tierkreis/tests/controller/typed_graphdata.py +++ b/tierkreis/tests/controller/typed_graphdata.py @@ -25,14 +25,14 @@ class DoublerOutput(NamedTuple): value: TKR[int] -def typed_doubler(): +def typed_doubler() -> GraphBuilder[TKR[int], TKR[int]]: g = GraphBuilder(TKR[int], TKR[int]) out = g.task(itimes(a=g.const(2), b=g.inputs)) g.outputs(out) return g -def typed_doubler_plus_multi(): +def typed_doubler_plus_multi() -> GraphBuilder[DoublerInput, DoublerOutput]: g = GraphBuilder(DoublerInput, DoublerOutput) mul = g.task(itimes(a=g.inputs.x, b=g.const(2))) out = g.task(iadd(a=mul, b=g.inputs.intercept)) @@ -40,7 +40,7 @@ def typed_doubler_plus_multi(): return g -def typed_doubler_plus(): +def typed_doubler_plus() -> GraphBuilder[DoublerInput, TKR[int]]: g = GraphBuilder(DoublerInput, TKR[int]) mul = g.task(itimes(a=g.inputs.x, b=g.const(2))) out = g.task(iadd(a=mul, b=g.inputs.intercept)) @@ -52,7 +52,7 @@ class TypedEvalOutputs(NamedTuple): typed_eval_output: TKR[int] -def typed_eval(): +def typed_eval() -> GraphBuilder[EmptyModel, TypedEvalOutputs]: g = GraphBuilder(EmptyModel, TypedEvalOutputs) e = g.eval(typed_doubler_plus(), DoublerInput(x=g.const(6), intercept=g.const(0))) g.outputs(TypedEvalOutputs(typed_eval_output=e)) @@ -68,7 +68,7 @@ class LoopBodyOutput(NamedTuple): should_continue: TKR[bool] -def loop_body(): +def loop_body() 
-> GraphBuilder[LoopBodyInput, LoopBodyOutput]: g = GraphBuilder(LoopBodyInput, LoopBodyOutput) a_plus = g.task(iadd(a=g.inputs.loop_acc, b=g.const(1))) pred = g.task(igt(a=g.const(10), b=a_plus)) @@ -76,21 +76,21 @@ def loop_body(): return g -def typed_loop(): +def typed_loop() -> GraphBuilder[EmptyModel, TKR[int]]: g = GraphBuilder(EmptyModel, TKR[int]) loop = g.loop(loop_body(), LoopBodyInput(loop_acc=g.const(6))) g.outputs(loop.loop_acc) return g -def typed_map_simple(): +def typed_map_simple() -> GraphBuilder[TKR[list[int]], TKR[list[int]]]: g = GraphBuilder(TKR[list[int]], TKR[list[int]]) m = g.map(typed_doubler(), g.inputs) g.outputs(m) return g -def typed_map(): +def typed_map() -> GraphBuilder[TKR[list[int]], TKR[list[int]]]: g = GraphBuilder(TKR[list[int]], TKR[list[int]]) ins = g.map(lambda n: DoublerInput(x=n, intercept=g.const(6)), g.inputs) m = g.map(typed_doubler_plus(), ins) @@ -98,7 +98,7 @@ def typed_map(): return g -def typed_destructuring(): +def typed_destructuring() -> GraphBuilder[TKR[list[int]], TKR[list[int]]]: g = GraphBuilder(TKR[list[int]], TKR[list[int]]) ins = g.map(lambda n: DoublerInput(x=n, intercept=g.const(6)), g.inputs) m = g.map(typed_doubler_plus_multi(), ins) @@ -107,7 +107,7 @@ def typed_destructuring(): return g -def tuple_untuple(): +def tuple_untuple() -> GraphBuilder[EmptyModel, TKR[int]]: g = GraphBuilder(EmptyModel, TKR[int]) t = g.task(tkr_tuple(g.const(1), g.const(2))) ut = g.task(untuple(t)) @@ -115,7 +115,7 @@ def tuple_untuple(): return g -def factorial(): +def factorial() -> GraphBuilder[TKR[int], TKR[int]]: g = GraphBuilder(TKR[int], TKR[int]) pred = g.task(igt(g.inputs, g.const(1))) n_minus_one = g.task(iadd(g.const(-1), g.inputs)) @@ -130,7 +130,7 @@ class GCDInput(NamedTuple): b: TKR[int] -def gcd(): +def gcd() -> GraphBuilder[GCDInput, TKR[int]]: g = GraphBuilder(GCDInput, TKR[int]) pred = g.task(igt(g.inputs.b, g.const(0))) @@ -141,21 +141,21 @@ def gcd(): return g -def tkr_conj(): +def tkr_conj() -> 
GraphBuilder[TKR[complex], TKR[complex]]: g = GraphBuilder(TKR[complex], TKR[complex]) z = g.task(conjugate(g.inputs)) g.outputs(z) return g -def tkr_list_conj(): +def tkr_list_conj() -> GraphBuilder[TKR[list[complex]], TKR[list[complex]]]: g = GraphBuilder(TKR[list[complex]], TKR[list[complex]]) zs = g.map(tkr_conj(), g.inputs) g.outputs(zs) return g -def eval_body_is_from_worker(): +def eval_body_is_from_worker() -> GraphBuilder[TKR[int], TKR[int]]: g = GraphBuilder(TKR[int], TKR[int]) graph = g.task(doubler_plus_graph()) graph_ref = TypedGraphRef(graph.value_ref(), TKR[int], TKR[int]) diff --git a/tierkreis/tests/errors/failing_worker/main.py b/tierkreis/tests/errors/failing_worker/main.py index 0365613f5..466644bb1 100644 --- a/tierkreis/tests/errors/failing_worker/main.py +++ b/tierkreis/tests/errors/failing_worker/main.py @@ -1,5 +1,7 @@ +# noqa: INP001 import logging -from sys import argv +import sys + from tierkreis import Worker logger = logging.getLogger(__name__) @@ -9,7 +11,8 @@ @worker.task() def fail() -> int: logger.error("Raising an error now...") - raise ValueError("Worker failed!") + msg = "Worker failed!" 
+ raise ValueError(msg) @worker.task() @@ -19,8 +22,8 @@ def wont_fail() -> int: @worker.task() def exit_code_1() -> int: - exit(1) + sys.exit(1) if __name__ == "__main__": - worker.app(argv) + worker.app(sys.argv) diff --git a/tierkreis/tests/errors/test_error.py b/tierkreis/tests/errors/test_error.py index 1c4da439b..1dd598580 100644 --- a/tierkreis/tests/errors/test_error.py +++ b/tierkreis/tests/errors/test_error.py @@ -1,7 +1,9 @@ -import pytest from pathlib import Path from uuid import UUID +import pytest + +from tests.errors.failing_worker.stubs import exit_code_1, fail, wont_fail from tierkreis.builder import GraphBuilder from tierkreis.controller import run_graph from tierkreis.controller.data.core import EmptyModel @@ -9,29 +11,28 @@ from tierkreis.controller.data.models import TKR from tierkreis.controller.executor.uv_executor import UvExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage -from tests.errors.failing_worker.stubs import fail, wont_fail, exit_code_1 from tierkreis.exceptions import TierkreisError -def will_fail_graph(): +def will_fail_graph() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.task(fail())) return graph -def wont_fail_graph(): +def wont_fail_graph() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.task(wont_fail())) return graph -def fail_in_eval(): +def fail_in_eval() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.eval(will_fail_graph(), EmptyModel())) return graph -def non_zero_exit_code(): +def non_zero_exit_code() -> GraphBuilder[EmptyModel, TKR[int]]: graph = GraphBuilder(EmptyModel, TKR[int]) graph.outputs(graph.task(exit_code_1())) return graph @@ -44,7 +45,7 @@ def test_raise_error() -> None: storage.clean_graph_files() with pytest.raises(TierkreisError): run_graph(storage, executor, g.get_data(), {}, n_iterations=1000) - assert 
storage.node_has_error(Loc("-.N0")) + assert storage.node_has_error(Loc("-.N0")) def test_raises_no_error() -> None: @@ -63,7 +64,7 @@ def test_nested_error() -> None: storage.clean_graph_files() with pytest.raises(TierkreisError): run_graph(storage, executor, g.get_data(), {}, n_iterations=1000) - assert (storage.logs_path.parent / "-/errors").exists() + assert (storage.logs_path.parent / "-/_error").exists() def test_non_zero_exit_code() -> None: @@ -73,4 +74,4 @@ def test_non_zero_exit_code() -> None: storage.clean_graph_files() with pytest.raises(TierkreisError): run_graph(storage, executor, g.get_data(), {}, n_iterations=1000) - assert (storage.logs_path.parent / "-/_error").exists() + assert (storage.logs_path.parent / "-/_error").exists() diff --git a/tierkreis/tests/executor/test_hpc_executor.py b/tierkreis/tests/executor/test_hpc_executor.py index 84b287887..5f47f70ba 100644 --- a/tierkreis/tests/executor/test_hpc_executor.py +++ b/tierkreis/tests/executor/test_hpc_executor.py @@ -1,6 +1,9 @@ from pathlib import Path from uuid import UUID + import pytest + +from tests.executor.stubs import mpi_rank_info from tierkreis.builder import GraphBuilder from tierkreis.controller import run_graph from tierkreis.controller.data.graph import GraphData @@ -12,8 +15,6 @@ ) from tierkreis.controller.executor.hpc.slurm import SLURMExecutor from tierkreis.controller.storage.filestorage import ControllerFileStorage - -from tests.executor.stubs import mpi_rank_info from tierkreis.storage import read_outputs @@ -28,7 +29,9 @@ def job_spec() -> JobSpec: return JobSpec( job_name="test_job", account="test_usr", - command="--allow-run-as-root /root/.local/bin/uv run /slurm_mpi_worker/main.py ", + command=( + "--allow-run-as-root /root/.local/bin/uv run /slurm_mpi_worker/main.py " + ), resource=ResourceSpec(nodes=2, memory_gb=None), walltime="00:15:00", mpi=MpiSpec(max_proc_per_node=1), @@ -47,7 +50,7 @@ def test_slurm_with_mpi() -> None: do_cleanup=True, ) sbatch = str( - 
Path(__file__).parent.parent.parent.parent / "infra/slurm_local/sbatch" + Path(__file__).parent.parent.parent.parent / "infra/slurm_local/sbatch", ) executor = SLURMExecutor( spec=job_spec(), diff --git a/tierkreis/tests/idl/__init__.py b/tierkreis/tests/idl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tierkreis/tests/idl/namespace1.py b/tierkreis/tests/idl/namespace1.py index 069b6c5f0..38c12cfe1 100644 --- a/tierkreis/tests/idl/namespace1.py +++ b/tierkreis/tests/idl/namespace1.py @@ -1,6 +1,7 @@ from typing import NamedTuple -from tierkreis.controller.data.models import portmapping + from tierkreis import Worker +from tierkreis.controller.data.models import portmapping from tierkreis.controller.data.types import PType worker = Worker("TestNamespace") diff --git a/tierkreis/tests/idl/stubs_output.py b/tierkreis/tests/idl/stubs_output.py index 6d674d76f..4845b7157 100644 --- a/tierkreis/tests/idl/stubs_output.py +++ b/tierkreis/tests/idl/stubs_output.py @@ -1,47 +1,48 @@ """Code generated from TestNamespace namespace. 
Please do not edit.""" from typing import NamedTuple, Protocol + from tierkreis.controller.data.models import TKR from tierkreis.controller.data.types import PType, Struct class A(NamedTuple): - age: TKR[int] # noqa: F821 # fmt: skip - name: TKR[dict[str, str]] # noqa: F821 # fmt: skip + age: TKR[int] # fmt: skip + name: TKR[dict[str, str]] # fmt: skip class B(Struct, Protocol): - age: int # noqa: F821 # fmt: skip - name: dict[str, str] # noqa: F821 # fmt: skip + age: int # fmt: skip + name: dict[str, str] # fmt: skip class C[T: PType](Struct, Protocol): - a: list[int] # noqa: F821 # fmt: skip - b: "B" # noqa: F821 # fmt: skip - included: "IncludedType" # noqa: F821 # fmt: skip - ol: "list[ListItem]" # noqa: F821 # fmt: skip - t: "T" # noqa: F821 # fmt: skip + a: list[int] # fmt: skip + b: "B" # fmt: skip + included: "IncludedType" # fmt: skip + ol: "list[ListItem]" # fmt: skip + t: "T" # fmt: skip class IncludedType(Struct, Protocol): - nested: "NestedType" # noqa: F821 # fmt: skip + nested: "NestedType" # fmt: skip class ListItem(Struct, Protocol): - i: int # noqa: F821 # fmt: skip + i: int # fmt: skip class NestedType(Struct, Protocol): - city: str # noqa: F821 # fmt: skip + city: str # fmt: skip class foo(NamedTuple): - a: TKR[int] # noqa: F821 # fmt: skip - b: TKR[str] # noqa: F821 # fmt: skip + a: TKR[int] # fmt: skip + b: TKR[str] # fmt: skip @staticmethod - def out() -> type[A]: # noqa: F821 # fmt: skip - return A # noqa: F821 # fmt: skip + def out() -> type[A]: # fmt: skip + return A # fmt: skip @property def namespace(self) -> str: @@ -50,8 +51,8 @@ def namespace(self) -> str: class bar(NamedTuple): @staticmethod - def out() -> type[TKR[B]]: # noqa: F821 # fmt: skip - return TKR[B] # noqa: F821 # fmt: skip + def out() -> type[TKR[B]]: # fmt: skip + return TKR[B] # fmt: skip @property def namespace(self) -> str: @@ -59,11 +60,11 @@ def namespace(self) -> str: class z[T: PType](NamedTuple): - c: TKR[C[T]] # noqa: F821 # fmt: skip + c: TKR[C[T]] # fmt: skip 
@staticmethod - def out() -> type[TKR[C[T]]]: # noqa: F821 # fmt: skip - return TKR[C[T]] # noqa: F821 # fmt: skip + def out() -> type[TKR[C[T]]]: # fmt: skip + return TKR[C[T]] # fmt: skip @property def namespace(self) -> str: diff --git a/tierkreis/tests/idl/test_idl.py b/tierkreis/tests/idl/test_idl.py index f2999d3e9..6ff6fa490 100644 --- a/tierkreis/tests/idl/test_idl.py +++ b/tierkreis/tests/idl/test_idl.py @@ -1,10 +1,12 @@ from pathlib import Path + import pytest + +import tests.idl.namespace1 from tierkreis.exceptions import TierkreisError from tierkreis.idl.models import GenericType -from tierkreis.namespace import Namespace from tierkreis.idl.type_symbols import type_symbol -import tests.idl.namespace1 +from tierkreis.namespace import Namespace type_symbols = [ ("uint8", GenericType(int, [])), @@ -13,7 +15,8 @@ ( "Record>", GenericType( - dict, [GenericType(str, []), GenericType(list, [GenericType(str, [])])] + dict, + [GenericType(str, []), GenericType(list, [GenericType(str, [])])], ), ), ( @@ -28,17 +31,17 @@ ), ] type_symbols_for_failure = ["decimal", "unknown", "duration"] -dir = Path(__file__).parent -typespecs = [(dir / "namespace1.tsp", tests.idl.namespace1.expected_namespace)] +current_dir = Path(__file__).parent +typespecs = [(current_dir / "namespace1.tsp", tests.idl.namespace1.expected_namespace)] -@pytest.mark.parametrize("type_symb,expected", type_symbols) -def test_type_t(type_symb: str, expected: type): +@pytest.mark.parametrize(("type_symb", "expected"), type_symbols) +def test_type_t(type_symb: str, expected: type) -> None: assert (expected, "") == type_symbol(type_symb) -@pytest.mark.parametrize("path,expected", typespecs) -def test_namespace(path: Path, expected: Namespace): +@pytest.mark.parametrize(("path", "expected"), typespecs) +def test_namespace(path: Path, expected: Namespace) -> None: namespace = Namespace.from_spec_file(path) assert namespace.stubs() == expected.stubs() @@ -49,6 +52,6 @@ def test_namespace(path: Path, 
expected: Namespace): @pytest.mark.parametrize("type_symb", type_symbols_for_failure) -def test_parser_fail(type_symb: str): +def test_parser_fail(type_symb: str) -> None: with pytest.raises(TierkreisError): type_symbol(type_symb) diff --git a/tierkreis/tierkreis/__init__.py b/tierkreis/tierkreis/__init__.py index 615e11994..03f08a080 100644 --- a/tierkreis/tierkreis/__init__.py +++ b/tierkreis/tierkreis/__init__.py @@ -1,5 +1,7 @@ +"""Tierkreis main package.""" + +from tierkreis.controller import run_graph from tierkreis.labels import Labels from tierkreis.worker.worker import Worker -from tierkreis.controller import run_graph __all__ = ["Labels", "Worker", "run_graph"] diff --git a/tierkreis/tierkreis/builder.py b/tierkreis/tierkreis/builder.py index 0d6969e64..cbc01079d 100644 --- a/tierkreis/tierkreis/builder.py +++ b/tierkreis/tierkreis/builder.py @@ -1,18 +1,23 @@ +"""Typed graph builder for Tierkreis workflows.""" + +from __future__ import annotations + +from collections.abc import Callable from dataclasses import dataclass from inspect import isclass -from typing import Any, Callable, NamedTuple, Protocol, overload, runtime_checkable +from typing import Any, NamedTuple, Protocol, overload, runtime_checkable from tierkreis.controller.data.core import EmptyModel +from tierkreis.controller.data.graph import GraphData, ValueRef from tierkreis.controller.data.models import ( TKR, TModel, TNamedModel, dict_from_tmodel, - model_fields, init_tmodel, + model_fields, ) from tierkreis.controller.data.types import PType -from tierkreis.controller.data.graph import GraphData, ValueRef @dataclass @@ -24,27 +29,72 @@ class TList[T: TModel]: @runtime_checkable class Function[Out](TNamedModel, Protocol): + """A worker function type. + + :abstract: + """ + @property - def namespace(self) -> str: ... + def namespace(self) -> str: + """The namespace name. + + :return: The namespace name. + :rtype: str + """ + ... @staticmethod - def out() -> type[Out]: ... 
+ def out() -> type[Out]: + """Return the output type of the function. + + :return: The output type. + :rtype: type[Out] + """ + ... @dataclass class TypedGraphRef[Ins: TModel, Outs: TModel]: + """A typed tierkreis graph. + + :attr graph_ref: The graph reference. + :attr outputs_type: The output type of the graph. + :attr inputs_type: The input type of the graph. + """ + graph_ref: ValueRef outputs_type: type[Outs] inputs_type: type[Ins] class LoopOutput(TNamedModel, Protocol): + """Protocol for loop output models to ensure should continue.""" + @property - def should_continue(self) -> TKR[bool]: ... + def should_continue(self) -> TKR[bool]: + """The loop continuation port. + + :return: The continuation port value. + :rtype: TKR[bool] + """ + ... + + +def script(script_name: str, script_input: TKR[bytes]) -> Function[TKR[bytes]]: + """Add a script to the graph. + + A shell script or binary with a single input and output. + Inputs are provided from the standard input and outputs to the standard output. + :param script_name: The name of the script. + :type script_name: str + :param script_input: The input to the script. + :type script_input: TKR[bytes] + :return: The output of the script. + :rtype: Function[TKR[bytes]] + """ -def script(script_name: str, input: TKR[bytes]) -> Function[TKR[bytes]]: - class exec_script(NamedTuple): + class exec_script(NamedTuple): # noqa: N801 input: TKR[bytes] @staticmethod @@ -55,10 +105,18 @@ def out() -> type[TKR[bytes]]: def namespace(self) -> str: return script_name - return exec_script(input=input) + return exec_script(input=script_input) class GraphBuilder[Inputs: TModel, Outputs: TModel]: + """Class to construct typed workflow graphs. + + :attr data: The underlying graph data. + :attr inputs_type: The input type of the graph. + :attr inputs: The inputs to the graph. + :attr outputs_type: The output type of the graph. 
+ """ + outputs_type: type inputs: Inputs @@ -66,7 +124,7 @@ def __init__( self, inputs_type: type[Inputs] = EmptyModel, outputs_type: type[Outputs] = EmptyModel, - ): + ) -> None: self.data = GraphData() self.inputs_type = inputs_type self.outputs_type = outputs_type @@ -74,37 +132,97 @@ def __init__( self.inputs = init_tmodel(self.inputs_type, inputs) def get_data(self) -> GraphData: + """Return the underlying graph from the builder. + + :return: The graph. + :rtype: GraphData + """ return self.data def ref(self) -> TypedGraphRef[Inputs, Outputs]: + """Return a reference of the typed graph. + + :return: The ref of the typed graph. + :rtype: TypedGraphRef[Inputs, Outputs] + """ return TypedGraphRef((-1, "body"), self.outputs_type, self.inputs_type) - def outputs(self, outputs: Outputs): + def outputs(self, outputs: Outputs) -> None: + """Set output nodes of a graph. + + :param outputs: The output nodes. + :type outputs: Outputs + """ self.data.output(inputs=dict_from_tmodel(outputs)) def const[T: PType](self, value: T) -> TKR[T]: + """Add a constant node to the graph. + + :return: The constant value. + :rtype: TKR[T] + """ idx, port = self.data.const(value) return TKR[T](idx, port) def ifelse[A: PType, B: PType]( - self, pred: TKR[bool], if_true: TKR[A], if_false: TKR[B] + self, + pred: TKR[bool], + if_true: TKR[A], + if_false: TKR[B], ) -> TKR[A] | TKR[B]: + """Add an if-else node to the graph. + + This will be evaluated lazily. + The values can be returned from an eval node or another graph. + + :param pred: The predicate value. + :type pred: TKR[bool] + :param if_true: The value if the predicate is true. + :type if_true: TKR[A] + :param if_false: The value if the predicate is false. + :type if_false: TKR[B] + :return: The outputs of the if-else expression. 
+ :rtype: TKR[A] | TKR[B] + """ idx, port = self.data.if_else( - pred.value_ref(), if_true.value_ref(), if_false.value_ref() + pred.value_ref(), + if_true.value_ref(), + if_false.value_ref(), )("value") return TKR(idx, port) def eifelse[A: PType, B: PType]( - self, pred: TKR[bool], if_true: TKR[A], if_false: TKR[B] + self, + pred: TKR[bool], + if_true: TKR[A], + if_false: TKR[B], ) -> TKR[A] | TKR[B]: + """Add an eager if-else node to the graph. + + This will be evaluated eagerly. + The values can be returned from an eval node or another graph. + + :param pred: The predicate value. + :type pred: TKR[bool] + :param if_true: The value if the predicate is true. + :type if_true: TKR[A] + :param if_false: The value if the predicate is false. + :type if_false: TKR[B] + :return: The outputs of the if-else expression. + :rtype: TKR[A] | TKR[B] + """ idx, port = self.data.eager_if_else( - pred.value_ref(), if_true.value_ref(), if_false.value_ref() + pred.value_ref(), + if_true.value_ref(), + if_false.value_ref(), )("value") return TKR(idx, port) def _graph_const[A: TModel, B: TModel]( - self, graph: "GraphBuilder[A, B]" + self, + graph: GraphBuilder[A, B], ) -> TypedGraphRef[A, B]: + # TODO @philipp-seitz: Turn this into a public method? idx, port = self.data.const(graph.data.model_dump()) return TypedGraphRef[A, B]( graph_ref=(idx, port), @@ -112,48 +230,104 @@ def _graph_const[A: TModel, B: TModel]( inputs_type=graph.inputs_type, ) - def task[Out: TModel](self, f: Function[Out]) -> Out: - name = f"{f.namespace}.{f.__class__.__name__}" - ins = dict_from_tmodel(f) - idx, _ = self.data.func(name, ins)("dummy") - OutModel = f.out() + def task[Out: TModel](self, func: Function[Out]) -> Out: + """Add a worker task node to the graph. + + :param func: The worker function. + :type func: Function[Out] + :return: The outputs of the task. 
+ :rtype: Out + """ + name = f"{func.namespace}.{func.__class__.__name__}" + inputs = dict_from_tmodel(func) + idx, _ = self.data.func(name, inputs)("dummy") + OutModel = func.out() # noqa: N806 outputs = [(idx, x) for x in model_fields(OutModel)] return init_tmodel(OutModel, outputs) @overload - def eval[A: TModel, B: TModel](self, body: TypedGraphRef[A, B], a: A) -> B: ... + def eval[A: TModel, B: TModel]( + self, + body: TypedGraphRef[A, B], + eval_inputs: A, + ) -> B: ... @overload - def eval[A: TModel, B: TModel](self, body: "GraphBuilder[A, B]", a: A) -> B: ... def eval[A: TModel, B: TModel]( - self, body: "GraphBuilder[A,B] | TypedGraphRef", a: Any + self, + body: GraphBuilder[A, B], + eval_inputs: A, + ) -> B: ... + def eval[A: TModel, B: TModel]( + self, + body: GraphBuilder[A, B] | TypedGraphRef, + eval_inputs: Any, ) -> Any: + """Add a evaluation node to the graph. + + This will evaluate a nested graph with the given inputs. + + :param body: The graph to evaluate. + :type body: TypedGraphRef[A, B] | GraphBuilder[A, B], + where A are the input type and B the output type of the graph. + :param eval_inputs: The inputs to the graph. + :type eval_inputs: A + :return: The outputs of the evaluation. + :rtype: B + """ if isinstance(body, GraphBuilder): body = self._graph_const(body) - idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(a))("dummy") + idx, _ = self.data.eval(body.graph_ref, dict_from_tmodel(eval_inputs))("dummy") outputs = [(idx, x) for x in model_fields(body.outputs_type)] return init_tmodel(body.outputs_type, outputs) @overload def loop[A: TModel, B: LoopOutput]( - self, body: TypedGraphRef[A, B], a: A, name: str | None = None + self, + body: TypedGraphRef[A, B], + loop_inputs: A, + name: str | None = None, ) -> B: ... @overload def loop[A: TModel, B: LoopOutput]( - self, body: "GraphBuilder[A, B]", a: A, name: str | None = None + self, + body: GraphBuilder[A, B], + loop_inputs: A, + name: str | None = None, ) -> B: ... 
def loop[A: TModel, B: LoopOutput]( self, - body: "TypedGraphRef[A, B] |GraphBuilder[A, B]", - a: A, + body: TypedGraphRef[A, B] | GraphBuilder[A, B], + loop_inputs: A, name: str | None = None, ) -> B: + """Add a loop node to the graph. + + This will loop over the given graph until the `should_continue` output is false. + To trace intermediate values, use the name attribute in conjunction with + read_loop_trace. + + :param body: The graph to loop. + :type body: TypedGraphRef[A, B] | GraphBuilder[A, B], + where A are the input type and B the output type of the graph. + :param loop_inputs: The inputs to the loop graph. + :type loop_inputs: A + :param name: An optional name for the loop. + :type name: str | None + :return: The outputs of the loop. + :rtype: B + """ if isinstance(body, GraphBuilder): body = self._graph_const(body) - g = body.graph_ref - idx, _ = self.data.loop(g, dict_from_tmodel(a), "should_continue", name)( - "dummy" + graph = body.graph_ref + idx, _ = self.data.loop( + graph, + dict_from_tmodel(loop_inputs), + "should_continue", + name, + )( + "dummy", ) outputs = [(idx, x) for x in model_fields(body.outputs_type)] return init_tmodel(body.outputs_type, outputs) @@ -164,27 +338,33 @@ def _unfold_list[T: PType](self, ref: TKR[list[T]]) -> TList[TKR[T]]: return TList(TKR[T](idx, "*")) def _fold_list[T: PType](self, refs: TList[TKR[T]]) -> TKR[list[T]]: - value_ref = (refs._value.node_index, refs._value.port_id) + value_ref = (refs._value.node_index, refs._value.port_id) # noqa: SLF001 idx, _ = self.data.func("builtins.fold_values", {"values_glob": value_ref})( - "dummy" + "dummy", ) return TKR[list[T]](idx, "value") def _map_fn_single_in[A: PType, B: TModel]( - self, aes: TKR[list[A]], body: Callable[[TKR[A]], B] - ) -> "TList[B]": - tlist = self._unfold_list(aes) - return TList(body(TKR(tlist._value.node_index, "*"))) + self, + map_inputs: TKR[list[A]], + body: Callable[[TKR[A]], B], + ) -> TList[B]: + tlist = self._unfold_list(map_inputs) + return 
TList(body(TKR(tlist._value.node_index, "*"))) # noqa: SLF001 def _map_fn_single_out[A: TModel, B: PType]( - self, aes: TList[A], body: Callable[[A], TKR[B]] + self, + map_inputs: TList[A], + body: Callable[[A], TKR[B]], ) -> TKR[list[B]]: - return self._fold_list(TList(body(aes._value))) + return self._fold_list(TList(body(map_inputs._value))) # noqa: SLF001 def _map_graph_full[A: TModel, B: TModel]( - self, aes: TList[A], body: TypedGraphRef[A, B] + self, + map_inputs: TList[A], + body: TypedGraphRef[A, B], ) -> TList[B]: - ins = dict_from_tmodel(aes._value) + ins = dict_from_tmodel(map_inputs._value) # noqa: SLF001 idx, _ = self.data.map(body.graph_ref, ins)("x") refs = [(idx, s + "-*") for s in model_fields(body.outputs_type)] @@ -194,23 +374,25 @@ def _map_graph_full[A: TModel, B: TModel]( def map[A: PType, B: TNamedModel]( self, body: ( - Callable[[TKR[A]], B] | TypedGraphRef[TKR[A], B] | "GraphBuilder[TKR[A], B]" + Callable[[TKR[A]], B] | TypedGraphRef[TKR[A], B] | GraphBuilder[TKR[A], B] ), - aes: TKR[list[A]], + map_inputs: TKR[list[A]], ) -> TList[B]: ... @overload def map[A: TNamedModel, B: PType]( self, body: ( - Callable[[A], TKR[B]] | TypedGraphRef[A, TKR[B]] | "GraphBuilder[A, TKR[B]]" + Callable[[A], TKR[B]] | TypedGraphRef[A, TKR[B]] | GraphBuilder[A, TKR[B]] ), - aes: TList[A], + map_inputs: TList[A], ) -> TKR[list[B]]: ... @overload def map[A: TNamedModel, B: TNamedModel]( - self, body: TypedGraphRef[A, B] | "GraphBuilder[A, B]", aes: TList[A] + self, + body: TypedGraphRef[A, B] | GraphBuilder[A, B], + map_inputs: TList[A], ) -> TList[B]: ... @overload @@ -219,30 +401,42 @@ def map[A: PType, B: PType]( body: ( Callable[[TKR[A]], TKR[B]] | TypedGraphRef[TKR[A], TKR[B]] - | "GraphBuilder[TKR[A], TKR[B]]" + | GraphBuilder[TKR[A], TKR[B]] ), - aes: TKR[list[A]], + map_inputs: TKR[list[A]], ) -> TKR[list[B]]: ... 
def map( - self, body: TypedGraphRef | Callable | "GraphBuilder", aes: TKR | TList + self, + body: TypedGraphRef | Callable | GraphBuilder, + map_inputs: TKR | TList, ) -> Any: + """Add a map node to the graph. + + :param body: The graph to map over. + :type body: TypedGraphRef | Callable | GraphBuilder + :param map_inputs: The values to map over. + :type map_inputs: TKR | TList + :return: The outputs of the map. + :rtype: Any + """ if isinstance(body, GraphBuilder): body = self._graph_const(body) if isinstance(body, Callable): - if isinstance(aes, TList): - return self._map_fn_single_out(aes, body) - elif isinstance(aes, TKR): - return self._map_fn_single_in(aes, body) + if isinstance(map_inputs, TList): + return self._map_fn_single_out(map_inputs, body) + if isinstance(map_inputs, TKR): + return self._map_fn_single_in(map_inputs, body) - if isinstance(aes, TKR): - aes = self._unfold_list(aes) + if isinstance(map_inputs, TKR): + map_inputs = self._unfold_list(map_inputs) - out = self._map_graph_full(aes, body) + out = self._map_graph_full(map_inputs, body) if not isclass(body.outputs_type) or not issubclass( - body.outputs_type, TNamedModel + body.outputs_type, + TNamedModel, ): out = self._fold_list(out) diff --git a/tierkreis/tierkreis/builtins/__init__.py b/tierkreis/tierkreis/builtins/__init__.py index e69de29bb..19467cbcc 100644 --- a/tierkreis/tierkreis/builtins/__init__.py +++ b/tierkreis/tierkreis/builtins/__init__.py @@ -0,0 +1 @@ +"""Built-in Tierkreis worker and stubs for basic operations.""" diff --git a/tierkreis/tierkreis/builtins/main.py b/tierkreis/tierkreis/builtins/main.py index ed8b4b120..e20d5f656 100644 --- a/tierkreis/tierkreis/builtins/main.py +++ b/tierkreis/tierkreis/builtins/main.py @@ -1,11 +1,15 @@ +"""Built-in Tierkreis tasks for basic operations.""" + +import statistics +from collections.abc import Sequence from logging import getLogger from pathlib import Path from random import randint -import statistics from sys import argv from 
time import sleep -from typing import NamedTuple, Sequence +from typing import NamedTuple +from tierkreis import Worker from tierkreis.controller.data.location import WorkerCallArgs from tierkreis.controller.data.models import portmapping from tierkreis.controller.data.types import ( @@ -13,10 +17,8 @@ bytes_from_ptype, ptype_from_bytes, ) -from tierkreis.worker.worker import TierkreisWorkerError from tierkreis.worker.storage.protocol import WorkerStorage -from tierkreis import Worker - +from tierkreis.worker.worker import TierkreisWorkerError logger = getLogger(__name__) @@ -25,151 +27,405 @@ @worker.task() def iadd(a: int, b: int) -> int: - logger.debug(f"iadd {a} {b}") + """Add two integers a+b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: The sum of the two integers. + :rtype: int + """ + logger.debug("iadd %s %s", a, b) return a + b @worker.task() -def add(a: int | float, b: int | float) -> int | float: +def add(a: float, b: float) -> int | float: + """Add two float like values a+b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: The sum of the two values. + :rtype: int | float + """ return a + b @worker.task() def isubtract(a: int, b: int) -> int: + """Subtract two integers a-b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: The difference of the two integers. + :rtype: int + """ return a - b @worker.task() -def subtract(a: int | float, b: int | float) -> int | float: +def subtract(a: float, b: float) -> int | float: + """Subtract two float like values a-b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: The difference of the two values. 
+ :rtype: int | float + """ return a - b @worker.task() def itimes(a: int, b: int) -> int: - logger.debug(f"itimes {a} {b}") + """Multiply two integers a*b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: The product of the two integers. + :rtype: int + """ + logger.debug("itimes %s %s", a, b) return a * b @worker.task() -def times(a: int | float, b: int | float) -> int | float: +def times(a: float, b: float) -> int | float: + """Multiply two float like values a*b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: The product of the two values. + :rtype: int | float + """ return a * b @worker.task() -def divide(a: int | float, b: int | float) -> float: +def divide(a: float, b: float) -> float: + """Divide two float like values a/b. + + :param a: The dividend. + :type a: float + :param b: The divisor. + :type b: float + :return: The quotient of the two values. + :rtype: float + """ return a / b @worker.task() def idivide(a: int, b: int) -> int: + """Integer division of two integers a//b. + + :param a: The dividend. + :type a: int + :param b: The divisor. + :type b: int + :return: The integer quotient of the two integers. + :rtype: int + """ return a // b @worker.task() def igt(a: int, b: int) -> bool: - logger.debug(f"igt {a} {b}") + """Check if integer a is greater than integer b. + + :param a: The first integer. + :type a: int + :param b: The second integer. + :type b: int + :return: True if a > b, False otherwise. + :rtype: bool + """ + logger.debug("igt %s %s", a, b) return a > b @worker.task() -def gt(a: int | float, b: int | float) -> bool: +def gt(a: float, b: float) -> bool: + """Check if value a is greater than value b. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a > b, False otherwise. 
+ :rtype: bool + """ return a > b +@worker.task() +def lt(a: float, b: float) -> bool: + """Check if value a is less than value b. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a < b, False otherwise. + :rtype: bool + """ + return a < b + + @worker.task() def conjugate(z: complex) -> complex: + """Return the complex conjugate of z. + + :param z: The complex number. + :type z: complex + :return: The complex conjugate of z. + :rtype: complex + """ return z.conjugate() @worker.task() -def eq(a: int | float, b: int | float) -> bool: +def eq(a: float, b: float) -> bool: + """Check if two float like values are equal. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a == b, False otherwise. + :rtype: bool + """ return a == b @worker.task() -def neq(a: int | float, b: int | float) -> bool: +def neq(a: float, b: float) -> bool: + """Check if two float like values are not equal. + + :param a: The first value. + :type a: float + :param b: The second value. + :type b: float + :return: True if a != b, False otherwise. + :rtype: bool + """ return a != b @worker.task() def ipow(a: int, b: int) -> int: + """Raise integer a to the power of integer b. + + :param a: The base integer. + :type a: int + :param b: The exponent integer. + :type b: int + :return: The result of a**b. + :rtype: int + """ return a**b @worker.task() -def pow(a: int | float, b: int | float) -> int | float: +def tkr_pow(a: float, b: float) -> int | float: + """Raise value a to the power of value b. + + Returns an int if both inputs are integers, otherwise a float. + + :param a: The base value. + :type a: float + :param b: The exponent value. + :type b: float + :return: The result of a**b. + :rtype: int | float + """ return a**b @worker.task() -def tkr_abs(a: int | float) -> int | float: +def tkr_abs(a: float) -> int | float: + """Return the absolute value of a float like value. 
+ + :param a: The value. + :type a: float + :return: The absolute value of a. + :rtype: int | float + """ return abs(a) @worker.task() -def tkr_round(a: float | int) -> int: +def tkr_round(a: float) -> int: + """Round a float to the nearest integer. + + :param a: The float value to round. + :type a: float + :return: The rounded integer. + :rtype: int + """ return round(a) @worker.task() -def neg(a: bool) -> bool: +def neg(*, a: bool) -> bool: + """Negate a boolean value. + + :param a: The boolean value. + :type a: bool + :return: The negated boolean value. + :rtype: bool + """ return not a @worker.task() -def trk_and(a: bool, b: bool) -> bool: - logger.debug(f"and {a} {b}") +def tkr_and(*, a: bool, b: bool) -> bool: + """Return the logical AND of two boolean values. + + :param a: The first boolean value. + :type a: bool + :param b: The second boolean value. + :type b: bool + :return: The logical AND of a and b. + :rtype: bool + """ + logger.debug("and %s %s", a, b) return a and b @worker.task() -def trk_or(a: bool, b: bool) -> bool: - logger.debug(f"and {a} {b}") +def tkr_or(*, a: bool, b: bool) -> bool: + """Return the logical OR of two boolean values. + + :param a: The first boolean value. + :type a: bool + :param b: The second boolean value. + :type b: bool + :return: The logical OR of a and b. + :rtype: bool + """ + logger.debug("and %s %s", a, b) return a or b @worker.task() def tkr_id[T: PType](value: T) -> T: - logger.debug(f"id {value}") + """Return the input value unchanged (identity function). + + :param value: The value to return. + :type value: T + :return: The same value. + :rtype: T + """ + logger.debug("id %s", value) return value @worker.task() -def append[T](v: list[T], a: T) -> list[T]: # noqa: E741 +def append[T](v: list[T], a: T) -> list[T]: + """Append an element to a list and return the modified list. + + :param v: The list to append to. + :type v: list[T] + :param a: The element to append. 
+ :type a: T + :return: The list with the element appended. + :rtype: list[T] + """ v.append(a) return v @portmapping class Headed[T: PType](NamedTuple): + """A tuple containing a head element and the rest of the list.""" + head: T rest: list[T] @worker.task() -def head[T: PType](v: list[T]) -> Headed[T]: # noqa: E741 +def head[T: PType](v: list[T]) -> Headed[T]: + """Return the first element and remaining elements of a list. + + :param v: The list. + :type v: list[T] + :return: A Headed tuple containing the first element and the rest of the list. + :rtype: Headed[T] + """ head, rest = v[0], v[1:] return Headed(head=head, rest=rest) @worker.task() def tkr_len[A](v: list[A]) -> int: + """Return the length of a list. + + :param v: The list. + :type v: list[A] + :return: The number of elements in the list. + :rtype: int + """ logger.info("len: %s", v) return len(v) @worker.task() def str_eq(a: str, b: str) -> bool: + """Check if two strings are equal. + + :param a: The first string. + :type a: str + :param b: The second string. + :type b: str + :return: True if the strings are equal, False otherwise. + :rtype: bool + """ return a == b @worker.task() def str_neq(a: str, b: str) -> bool: + """Check if two strings are not equal. + + :param a: The first string. + :type a: str + :param b: The second string. + :type b: str + :return: True if the strings are not equal, False otherwise. + :rtype: bool + """ return a != b @worker.primitive_task() def fold_values(args: WorkerCallArgs, storage: WorkerStorage) -> None: + """Fold multiple values from storage into a single list. + + Reads values from storage matching a glob pattern (values_glob) + and combines them into a single list output at the specified output path. + + :param args: The worker call arguments containing the glob pattern and output path. + :type args: WorkerCallArgs + :param storage: The worker storage for reading and writing values. 
+ :type storage: WorkerStorage + """ values_glob = storage.glob(str(args.inputs["values_glob"])) values_glob.sort(key=lambda x: int(Path(x).name.split("-")[-1])) bs = [storage.read_input(Path(value)) for value in values_glob] @@ -179,50 +435,106 @@ def fold_values(args: WorkerCallArgs, storage: WorkerStorage) -> None: @worker.primitive_task() def unfold_values(args: WorkerCallArgs, storage: WorkerStorage) -> None: + """Unfold a single list value into multiple individual values in storage. + + Reads a list from storage and writes each element to a separate storage location. + + :param args: The worker call arguments containing input value and output directory. + :type args: WorkerCallArgs + :param storage: The worker storage for reading and writing values. + :type storage: WorkerStorage + :raises TierkreisWorkerError: If the input is not a list or sequence. + """ value_list = ptype_from_bytes(storage.read_input(args.inputs["value"])) match value_list: case list() | Sequence(): for i, v in enumerate(value_list): storage.write_output(args.output_dir / str(i), bytes_from_ptype(v)) case _: - raise TierkreisWorkerError(f"Expected list found {value_list}") + msg = f"Expected list found {value_list}" + raise TierkreisWorkerError(msg) @worker.task() def concat(lhs: str, rhs: str) -> str: + """Concatenate two strings lhs+rhs. + + :param lhs: The first string. + :type lhs: str + :param rhs: The second string. + :type rhs: str + :return: The concatenated string. + :rtype: str + """ return lhs + rhs @worker.task() def tkr_zip[U, V](a: list[U], b: list[V]) -> list[tuple[U, V]]: - return list(zip(a, b)) + """Zip two lists together into a list of tuples. + + :param a: The first list. + :type a: list[U] + :param b: The second list. + :type b: list[V] + :return: A list of tuples pairing elements from both lists. 
+ :rtype: list[tuple[U, V]] + """ + return list(zip(a, b, strict=False)) @portmapping class Unzipped[U: PType, V: PType](NamedTuple): + """A tuple containing two lists resulting from unzipping.""" + a: list[U] b: list[V] @worker.task() def unzip[U: PType, V: PType](value: list[tuple[U, V]]) -> Unzipped[U, V]: - value_a, value_b = map(list, zip(*value)) + """Unzip a list of tuples into two separate lists. + + :param value: The list of tuples to unzip. + :type value: list[tuple[U, V]] + :return: An Unzipped tuple containing two lists. + :rtype: Unzipped[U, V] + """ + value_a, value_b = map(list, zip(*value, strict=False)) return Unzipped(a=value_a, b=value_b) @worker.task() def tkr_tuple[U, V](a: U, b: V) -> tuple[U, V]: + """Create a tuple from two values. + + :param a: The first value. + :type a: U + :param b: The second value. + :type b: V + :return: A tuple containing both values. + :rtype: tuple[U, V] + """ return (a, b) @portmapping class Untupled[U: PType, V: PType](NamedTuple): + """A tuple containing two unpacked values.""" + a: U b: V @worker.task() def untuple[U: PType, V: PType](value: tuple[U, V]) -> Untupled[U, V]: + """Unpack a tuple of two elements into separate values. + + :param value: The tuple to unpack. + :type value: tuple[U, V] + :return: An Untupled tuple containing the two unpacked values. + :rtype: Untupled[U, V] + """ logger.info("untuple: %s", value) value_a, value_b = value return Untupled(a=value_a, b=value_b) @@ -230,83 +542,209 @@ def untuple[U: PType, V: PType](value: tuple[U, V]) -> Untupled[U, V]: @worker.task() def mean(values: list[float]) -> float: + """Calculate the arithmetic mean of a list of floats. + + :param values: The list of float values. + :type values: list[float] + :return: The mean of the values. + :rtype: float + """ return statistics.mean(values) @worker.task() def mod(a: int, b: int) -> int: + """Return the modulo of two integers a % b. + + :param a: The dividend. + :type a: int + :param b: The divisor. 
+ :type b: int + :return: The remainder of a divided by b. + :rtype: int + """ return a % b @worker.task() def tkr_range(start: int, stop: int, step: int = 1) -> list[int]: + """Return a list of ints in the range. + + As pythons range(). + + :param start: Start of the range (inclusive). + :type start: int + :param stop: End of the range (exclusive). + :type stop: int + :param step: Step size, defaults to 1. + :type step: int, optional + :return: A list of integers in the specified range. + :rtype: list[int] + """ return list(range(start, stop, step)) @worker.task() def rand_int(a: int, b: int) -> int: - return randint(a, b) + """Return a random integer between a and b (inclusive). + + :param a: The lower bound (inclusive). + :type a: int + :param b: The upper bound (inclusive). + :type b: int + :return: A random integer between a and b. + :rtype: int + """ + return randint(a, b) # noqa: S311 @worker.task() def tkr_sleep(delay_seconds: float) -> bool: + """Sleep for a specified number of seconds. + + :param delay_seconds: The number of seconds to sleep. + :type delay_seconds: float + :return: True after the sleep completes. + :rtype: bool + """ sleep(delay_seconds) return True @worker.task() def tkr_encode(string: str) -> bytes: + """Encode a string to bytes using UTF-8 encoding. + + :param string: The string to encode. + :type string: str + :return: The UTF-8 encoded bytes. + :rtype: bytes + """ return string.encode() @worker.task() -def tkr_decode(bytes: bytes) -> str: - return bytes.decode() +def tkr_decode(value_bytes: bytes) -> str: + """Decode bytes to a string using UTF-8 decoding. + + :param value_bytes: The bytes to decode. + :type value_bytes: bytes + :return: The decoded string. + :rtype: str + """ + return value_bytes.decode() @worker.task() def tkr_all[T: PType](values: Sequence[T]) -> bool: + """Check if all elements in a sequence are truthy. + + :param values: The sequence of values. 
+ :type values: Sequence[T] + :return: True if all elements are truthy, False otherwise. + :rtype: bool + """ return all(values) @worker.task() def tkr_any[T: PType](values: Sequence[T]) -> bool: + """Check if any element in a sequence is truthy. + + :param values: The sequence of values. + :type values: Sequence[T] + :return: True if any element is truthy, False otherwise. + :rtype: bool + """ return any(values) @worker.task() def tkr_reversed[T: PType](values: list[T]) -> list[T]: + """Return a reversed copy of a list. + + :param values: The list to reverse. + :type values: list[T] + :return: A new list with elements in reverse order. + :rtype: list[T] + """ return list(reversed(values)) @worker.task() def tkr_extend[T: PType](first: list[T], second: list[T]) -> list[T]: + """Extend a list with elements from another list. + + :param first: The list to extend. + :type first: list[T] + :param second: The list of elements to add. + :type second: list[T] + :return: The extended list. + :rtype: list[T] + """ first.extend(second) return first @worker.task() def concat_lists[U: PType, V: PType](first: list[U], second: list[V]) -> list[U | V]: + """Concatenate two lists of potentially different types. + + :param first: The first list. + :type first: list[U] + :param second: The second list. + :type second: list[V] + :return: A concatenated list containing elements from both lists. + :rtype: list[U | V] + """ return first + second @worker.task() -def tkr_str(value: int | float | bool) -> str: +def tkr_str(*, value: float | bool) -> str: + """Convert a float or bool value to a string. + + :param value: The value to convert. + :type value: float | bool + :return: The string representation of the value. + :rtype: str + """ return str(value) @worker.task() -def tkr_int(value: int | float | bool | str) -> int: +def tkr_int(*, value: float | bool | str) -> int: + """Convert a float, bool, or string value to an integer. + + :param value: The value to convert. 
+ :type value: float | bool | str + :return: The integer representation of the value. + :rtype: int + """ return int(value) @worker.task() def sum_list(values: list[int | float]) -> int | float: + """Sum all elements in a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The sum of all elements. + :rtype: int | float + """ return sum(values) @worker.task() def prod_list(values: list[int | float]) -> int | float: + """Calculate the product of all elements in a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The product of all elements. + :rtype: int | float + """ prod = 1 for v in values: prod *= v @@ -315,26 +753,61 @@ def prod_list(values: list[int | float]) -> int | float: @worker.task() def max_item(values: list[int | float]) -> int | float: + """Return the maximum element from a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The maximum value in the list. + :rtype: int | float + """ return max(values) @worker.task() def min_item(values: list[int | float]) -> int | float: + """Return the minimum element from a list of numbers. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: The minimum value in the list. + :rtype: int | float + """ return min(values) @worker.task() def sort_number_list(values: list[int | float]) -> list[int | float]: + """Sort a list of numbers in ascending order. + + :param values: The list of numeric values. + :type values: list[int | float] + :return: A sorted list of numeric values. + :rtype: list[int | float] + """ return sorted(values) @worker.task() def sort_string_list(values: list[str]) -> list[str]: + """Sort a list of strings in ascending order. + + :param values: The list of strings. + :type values: list[str] + :return: A sorted list of strings. 
+ :rtype: list[str] + """ return sorted(values) @worker.task() def flatten[T: PType](values: list[list[T]]) -> list[T]: + """Flatten a list of lists into a single list. + + :param values: The list of lists to flatten. + :type values: list[list[T]] + :return: A flattened list containing all elements. + :rtype: list[T] + """ out = [] for sub in values: out.extend(sub) @@ -343,11 +816,29 @@ def flatten[T: PType](values: list[list[T]]) -> list[T]: @worker.task() def take[T: PType](values: list[T], n: int) -> list[T]: + """Return the first n elements of a list. + + :param values: The list. + :type values: list[T] + :param n: The number of elements to take. + :type n: int + :return: A list containing the first n elements. + :rtype: list[T] + """ return values[:n] @worker.task() def drop[T: PType](values: list[T], n: int) -> list[T]: + """Drop the first n elements of a list and return the rest. + + :param values: The list. + :type values: list[T] + :param n: The number of elements to drop. + :type n: int + :return: A list with the first n elements removed. 
+ :rtype: list[T] + """ return values[n:] diff --git a/tierkreis/tierkreis/builtins/stubs.py b/tierkreis/tierkreis/builtins/stubs.py index ed0968f7f..a4d537305 100644 --- a/tierkreis/tierkreis/builtins/stubs.py +++ b/tierkreis/tierkreis/builtins/stubs.py @@ -34,12 +34,12 @@ def namespace(self) -> str: class add(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -60,12 +60,12 @@ def namespace(self) -> str: class subtract(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -86,12 +86,12 @@ def namespace(self) -> str: class times(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -99,8 +99,8 @@ def namespace(self) -> str: class divide(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # 
fmt: skip + b: TKR[float] # fmt: skip @staticmethod def out() -> type[TKR[float]]: # noqa: F821 # fmt: skip @@ -138,8 +138,21 @@ def namespace(self) -> str: class gt(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip + + @staticmethod + def out() -> type[TKR[bool]]: # fmt: skip + return TKR[bool] # fmt: skip + + @property + def namespace(self) -> str: + return "builtins" + + +class lt(NamedTuple): + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip @@ -163,8 +176,8 @@ def namespace(self) -> str: class eq(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip @@ -176,8 +189,8 @@ def namespace(self) -> str: class neq(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip @@ -201,13 +214,13 @@ def namespace(self) -> str: return "builtins" -class pow(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip - b: TKR[Union[int, float]] # noqa: F821 # fmt: skip +class tkr_pow(NamedTuple): + a: TKR[float] # fmt: skip + b: TKR[float] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -215,11 +228,11 @@ def namespace(self) -> str: class tkr_abs(NamedTuple): - a: TKR[Union[int, float]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip @staticmethod - def out() 
-> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -227,7 +240,7 @@ def namespace(self) -> str: class tkr_round(NamedTuple): - a: TKR[Union[float, int]] # noqa: F821 # fmt: skip + a: TKR[float] # fmt: skip @staticmethod def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip @@ -250,9 +263,9 @@ def namespace(self) -> str: return "builtins" -class trk_and(NamedTuple): - a: TKR[bool] # noqa: F821 # fmt: skip - b: TKR[bool] # noqa: F821 # fmt: skip +class tkr_and(NamedTuple): + a: TKR[bool] # fmt: skip + b: TKR[bool] # fmt: skip @staticmethod def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip @@ -263,9 +276,9 @@ def namespace(self) -> str: return "builtins" -class trk_or(NamedTuple): - a: TKR[bool] # noqa: F821 # fmt: skip - b: TKR[bool] # noqa: F821 # fmt: skip +class tkr_or(NamedTuple): + a: TKR[bool] # fmt: skip + b: TKR[bool] # fmt: skip @staticmethod def out() -> type[TKR[bool]]: # noqa: F821 # fmt: skip @@ -491,7 +504,7 @@ def namespace(self) -> str: class tkr_decode(NamedTuple): - bytes: TKR[bytes] # noqa: F821 # fmt: skip + value_bytes: TKR[bytes] # fmt: skip @staticmethod def out() -> type[TKR[str]]: # noqa: F821 # fmt: skip @@ -556,8 +569,8 @@ class concat_lists[U: PType, V: PType](NamedTuple): second: TKR[list[V]] # noqa: F821 # fmt: skip @staticmethod - def out() -> type[TKR[list[Union[U, V]]]]: # noqa: F821 # fmt: skip - return TKR[list[Union[U, V]]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[Union[U, V]]]]: # fmt: skip + return TKR[list[Union[U, V]]] # fmt: skip @property def namespace(self) -> str: @@ -565,7 +578,7 @@ def namespace(self) -> str: class tkr_str(NamedTuple): - value: TKR[Union[int, float, bool]] # noqa: F821 # fmt: skip + value: TKR[Union[float, bool]] # fmt: skip @staticmethod def out() -> type[TKR[str]]: # noqa: F821 # 
fmt: skip @@ -577,7 +590,7 @@ def namespace(self) -> str: class tkr_int(NamedTuple): - value: TKR[Union[int, float, bool, str]] # noqa: F821 # fmt: skip + value: TKR[Union[float, bool, str]] # fmt: skip @staticmethod def out() -> type[TKR[int]]: # noqa: F821 # fmt: skip @@ -589,11 +602,11 @@ def namespace(self) -> str: class sum_list(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -601,11 +614,11 @@ def namespace(self) -> str: class prod_list(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -613,11 +626,11 @@ def namespace(self) -> str: class max_item(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] # noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -625,11 +638,11 @@ def namespace(self) -> str: class min_item(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[Union[int, float]]]: # noqa: F821 # fmt: skip - return TKR[Union[int, float]] 
# noqa: F821 # fmt: skip + def out() -> type[TKR[Union[int, float]]]: # fmt: skip + return TKR[Union[int, float]] # fmt: skip @property def namespace(self) -> str: @@ -637,11 +650,11 @@ def namespace(self) -> str: class sort_number_list(NamedTuple): - values: TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + values: TKR[list[Union[int, float]]] # fmt: skip @staticmethod - def out() -> type[TKR[list[Union[int, float]]]]: # noqa: F821 # fmt: skip - return TKR[list[Union[int, float]]] # noqa: F821 # fmt: skip + def out() -> type[TKR[list[Union[int, float]]]]: # fmt: skip + return TKR[list[Union[int, float]]] # fmt: skip @property def namespace(self) -> str: diff --git a/tierkreis/tierkreis/cli/__init__.py b/tierkreis/tierkreis/cli/__init__.py index e69de29bb..0c81fcc6b 100644 --- a/tierkreis/tierkreis/cli/__init__.py +++ b/tierkreis/tierkreis/cli/__init__.py @@ -0,0 +1 @@ +"""The Tierkreis CLI.""" diff --git a/tierkreis/tierkreis/cli/project.py b/tierkreis/tierkreis/cli/project.py index 2797f5458..57088468a 100644 --- a/tierkreis/tierkreis/cli/project.py +++ b/tierkreis/tierkreis/cli/project.py @@ -1,13 +1,14 @@ +"""CLI for project related operations.""" + import argparse import os -from pathlib import Path -import subprocess import shutil - +import subprocess +from pathlib import Path from tierkreis.cli.templates import ( - external_worker_idl, default_graph, + external_worker_idl, python_worker_main, python_worker_pyproject, python_worker_workspace_pyproject, @@ -19,6 +20,7 @@ def parse_args( parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: + """Parse the arguments for the init subcommand.""" init_subparsers = parser.add_subparsers( dest="init_type", help="Initialize tierkreis related structures", @@ -26,15 +28,16 @@ def parse_args( ) project = init_subparsers.add_parser( "project", - description="Initialize and manages project wide options." - " Please make sure to set up a python project first, e.g. 
by executing `uv init`.", + description="Initialize and manages project wide options. " + "Make sure to set up a python project first, e.g. by executing `uv init`.", help="Initializes a new tierkreis project and manages project wide options.", ) project.add_argument( "--default-checkpoint-directory", - help="Overwrites the default checkpoint directory and sets the environment variable TKR_DIR for the current shell." - "If you want to persist this behavior add it to your systems environment. e.g. export TKR_DIR=... ", + help="""Overwrites the default checkpoint directory and sets the environment + variable TKR_DIR for the current shell. If you want to persist this behavior + add it to your systems environment. e.g. export TKR_DIR=... """, type=Path, default=Path.home() / ".tierkreis/checkpoints", ) @@ -42,11 +45,12 @@ def parse_args( "--project-directory", help="Sets the default project directory. ", type=Path, - default=Path("."), + default=Path(), ) project.add_argument( "--worker-directory", - help="Overwrites the default worker directory. Defaults to /workers.", + help="Overwrites the default worker directory." + "Defaults to /workers.", type=Path, default=Path("./tkr") / "workers", ) @@ -62,17 +66,23 @@ def parse_args( ) worker.add_argument( "--worker-directory", - help="Overwrites the default worker directory. Defaults to /workers.", + help="Overwrites the default worker directory." + "Defaults to /workers.", type=str, default=Path("./tkr") / "workers", ) worker.add_argument( "--external", - help="Set this flag for non-python workers. This will generate an IDL file instead of python related files.", + help="Set this flag for non-python workers." 
+ " This will generate an IDL file instead of python related files.", action="store_true", ) worker.add_argument( - "-n", "--name", required=True, help="The name of the new worker", type=str + "-n", + "--name", + required=True, + help="The name of the new worker", + type=str, ) stubs = init_subparsers.add_parser("stubs", help="Generates worker stubs with UV.") stubs.add_argument( @@ -90,37 +100,38 @@ def parse_args( return parser -def _gen_worker(worker_name: str, worker_dir: Path, external: bool = False) -> None: +def _gen_worker(worker_name: str, worker_dir: Path, *, external: bool = False) -> None: base_dir = worker_dir / worker_name base_dir.mkdir(exist_ok=True) - with open(base_dir / "README.md", "w+", encoding="utf-8") as fh: + with Path.open(base_dir / "README.md", "w+", encoding="utf-8") as fh: fh.write(f"# {worker_name} \n") - with open(base_dir / "pyproject.toml", "w+", encoding="utf-8") as fh: - fh.write(python_worker_workspace_pyproject(worker_name)) + with Path.open(base_dir / "pyproject.toml", "w+", encoding="utf-8") as fh: + fh.write(python_worker_workspace_pyproject(worker_name, external=external)) api_dir = base_dir / "api" src_dir = base_dir / "src" api_dir.mkdir(exist_ok=True) - with open(api_dir / "pyproject.toml", "w+", encoding="utf-8") as fh: + with Path.open(api_dir / "pyproject.toml", "w+", encoding="utf-8") as fh: fh.write(python_worker_pyproject(worker_name, kind="api")) - with open(api_dir / "README.md", "w+", encoding="utf-8") as fh: + with Path.open(api_dir / "README.md", "w+", encoding="utf-8") as fh: fh.write(f"# {worker_name}-api \n") src_dir.mkdir(exist_ok=True) if external: - with open(src_dir / f"{worker_name}.tsp", "w+", encoding="utf-8") as fh: + with Path.open(src_dir / f"{worker_name}.tsp", "w+", encoding="utf-8") as fh: fh.write(external_worker_idl(worker_name)) return - with open(src_dir / "main.py", "w+", encoding="utf-8") as fh: + with Path.open(src_dir / "main.py", "w+", encoding="utf-8") as fh: 
fh.write(python_worker_main(worker_name)) - with open(src_dir / "pyproject.toml", "w+", encoding="utf-8") as fh: + with Path.open(src_dir / "pyproject.toml", "w+", encoding="utf-8") as fh: fh.write(python_worker_pyproject(worker_name, kind="src")) - with open(src_dir / "README.md", "w+", encoding="utf-8") as fh: + with Path.open(src_dir / "README.md", "w+", encoding="utf-8") as fh: fh.write(f"# {worker_name}-src \n") def _gen_stubs(worker_directory: Path, stubs_name: str) -> None: uv_path = shutil.which("uv") if uv_path is None: - raise TierkreisError("uv is required to use this feature.") + msg = "uv is required to use this feature." + raise TierkreisError(msg) for worker in worker_directory.iterdir(): if not worker.is_dir(): continue @@ -129,11 +140,14 @@ def _gen_stubs(worker_directory: Path, stubs_name: str) -> None: namespace.write_stubs(idl.parent / stubs_name) else: subprocess.run( - [uv_path, "run", "src/main.py", "--stubs-path", stubs_name], cwd=worker + [uv_path, "run", "src/main.py", "--stubs-path", stubs_name], + cwd=worker, + check=True, ) def run_args(args: argparse.Namespace) -> None: + """Run the project initialization according to the args.""" if args.init_type == "project": worker_name = "example_worker" worker_dir = Path(args.worker_directory) @@ -145,40 +159,38 @@ def run_args(args: argparse.Namespace) -> None: worker_dir.mkdir(exist_ok=True, parents=True) _gen_worker(worker_name, worker_dir) graphs_dir.mkdir(exist_ok=True, parents=True) - with open(graphs_dir / "main.py", "w+", encoding="utf-8") as fh: + with Path.open(graphs_dir / "main.py", "w+", encoding="utf-8") as fh: fh.write(default_graph(worker_name)) os.environ["TKR_DIR"] = str(args.default_checkpoint_directory) _gen_stubs(worker_dir, "./api/api.py") - print(f"""Successfully generated project in '{args.project_directory}'. - -To run the sample graph use "python -m tkr.graphs.main". 
-Or import the function into a top level script with: - -from tkr.graphs.main import main -main() - -It is highly recommended to add this to your project definition e.g. pyproject.toml. -""") elif args.init_type == "worker": Path(args.worker_directory).mkdir(exist_ok=True, parents=True) - _gen_worker(args.name, Path(args.worker_directory), args.external) + _gen_worker(args.name, Path(args.worker_directory), external=args.external) elif args.init_type == "stubs": _gen_stubs(Path(args.worker_directory), args.api_file_name) class TierkreisInitCli: + """Tierkeirs cli for the `init` subcommand. + + Used to initialize tkr projects.. + """ + @staticmethod def add_subcommand( main_parser: argparse._SubParsersAction, ) -> None: + """Add the init subcommand.""" parser = main_parser.add_parser( "init", description="Initializes the tierkreis project resources", - help="Initializes the tierkreis project resources. Run `tkr init --help` for more information.", + help="Initializes the tierkreis project resources. 
Run `tkr init --help`" + " for more information.", ) parser = parse_args(parser) parser.set_defaults(func=TierkreisInitCli.execute) @staticmethod def execute(args: argparse.Namespace) -> None: + """Execute the init subcommand.""" run_args(args) diff --git a/tierkreis/tierkreis/cli/run.py b/tierkreis/tierkreis/cli/run.py index 45f1d0d34..861487c69 100644 --- a/tierkreis/tierkreis/cli/run.py +++ b/tierkreis/tierkreis/cli/run.py @@ -1,32 +1,51 @@ +"""Tierkreis CLI main entrypoint.""" + from __future__ import annotations -import argparse import importlib import json import logging import sys from pathlib import Path -from typing import Any, Callable +from typing import TYPE_CHECKING from tierkreis.cli.run_workflow import run_workflow from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType, ptype_from_bytes from tierkreis.exceptions import TierkreisError +if TYPE_CHECKING: + import argparse + import types + from collections.abc import Callable + +logger = logging.getLogger(__name__) + -def _import_from_path(module_name: str, file_path: str) -> Any: - spec = importlib.util.spec_from_file_location(module_name, file_path) # type: ignore - module = importlib.util.module_from_spec(spec) # type: ignore +def _import_from_path(module_name: str, file_path: str) -> types.ModuleType: + """Import a graph when supplied as a path to a python file.""" + spec = importlib.util.spec_from_file_location(module_name, file_path) # type: ignore[no-untyped-call] + module = importlib.util.module_from_spec(spec) # type: ignore[no-untyped-call] sys.modules[module_name] = module spec.loader.exec_module(module) return module def load_graph(graph_input: str) -> GraphData: + """Load a graph from an argument string. + + Loads a graph similar to how python runs modules with "-m" + + :param graph_input: The argument string specifying the graph. + :type graph_input: str + :raises TierkreisError: If the argument string is invalid. 
+ :return: The loaded graph data. + :rtype: GraphData + """ if ":" not in graph_input: - raise TierkreisError(f"Invalid argument: {graph_input}") + msg = f"Invalid argument: {graph_input}" + raise TierkreisError(msg) module_name, function_name = graph_input.split(":") - print(f"Loading graph from module '{module_name}' and function '{function_name}'") if ".py" in module_name: module = _import_from_path("graph_module", module_name) else: @@ -37,15 +56,17 @@ def load_graph(graph_input: str) -> GraphData: def _load_inputs(input_files: list[str]) -> dict[str, PType]: + """Load the inputs to a graph.""" if len(input_files) == 1 and input_files[0].endswith(".json"): - with open(input_files[0], "r") as fh: + with Path.open(Path(input_files[0])) as fh: return {k: json.dumps(v).encode() for k, v in json.load(fh).items()} inputs = {} for input_file in input_files: if ":" not in input_file: - raise TierkreisError(f"Invalid argument: {input_file}") + msg = f"Invalid argument: {input_file}" + raise TierkreisError(msg) key, value = input_file.split(":") - with open(value, "rb") as fh: + with Path.open(Path(value), "rb") as fh: inputs[key] = ptype_from_bytes(fh.read()) return inputs @@ -53,22 +74,31 @@ def _load_inputs(input_files: list[str]) -> dict[str, PType]: def parse_args( main_parser: argparse._SubParsersAction[argparse.ArgumentParser], ) -> argparse.ArgumentParser: + """Parse the arguments for the 'run' subcommand. + + :param main_parser: The main parser to add the subcommand to. + :type main_parser: argparse._SubParsersAction[argparse.ArgumentParser] + :return: The parser for the 'run' subcommand. + :rtype: argparse.ArgumentParser + """ parser = main_parser.add_parser( name="run", - description="Runs tierkreis graphs from the cli.", - help="Runs tierkreis graphs. 
Run `tkr run --help` for more information.", + description="Tierkreis: a workflow engine for quantum HPC.", ) graph = parser.add_mutually_exclusive_group(required=True) graph.add_argument( - "-f", "--from-file", type=Path, help="Load a graph from a .json file" + "-f", + "--from-file", + type=Path, + help="Load a graph from a .json file", ) graph.add_argument( "-g", "--graph-location", help="Fully qualifying name of a Callable () -> GraphData. " - + "Example: tierkreis.cli.sample_graph:simple_eval" - + "Or a path to a python file and function." - + "Example: docs/source/examples/hello_world.py:graph", + "Example: tierkreis.cli.sample_graph:simple_eval" + "Or a path to a python file and function." + "Example: examples/hello_world/hello_world_graph.py:hello_graph", type=str, ) parser.add_argument( @@ -77,10 +107,13 @@ def parse_args( nargs="*", help="Graph inputs:" "Either a single .json file or a key value list port1:path1 port2:path2" - + "where path is a binary file.", + "where path is a binary file.", ) parser.add_argument( - "--run-id", default=None, type=int, help="Set a workflow run id" + "--run-id", + default=None, + type=int, + help="Set a workflow run id", ) parser.add_argument("--name", default=None, type=str, help="Set a workflow name") parser.add_argument( @@ -92,7 +125,10 @@ def parse_args( ) parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( - "--registry-path", default=None, type=Path, help="Location of executable tasks." + "--registry-path", + default=None, + type=Path, + help="Location of executable tasks.", ) parser.add_argument( "-o", @@ -127,41 +163,50 @@ def parse_args( return parser -def run_workflow_args(args: argparse.Namespace): +def run_workflow_args(args: argparse.Namespace) -> None: + """Run a tierkreis workflow according to the run command. + + :param args: The arguments parsed from tkr run. 
+ :type args: argparse.Namespace + """ if args.verbose: args.log_level = logging.DEBUG if args.graph_location is not None: graph = load_graph(args.graph_location) else: - with open(args.from_file, "r") as fh: + with Path.open(args.from_file) as fh: graph = ptype_from_bytes(fh.read().encode(), GraphData) - if args.input_files is not None: - inputs = _load_inputs(args.input_files) - else: - inputs = {} - print(inputs) + inputs = _load_inputs(args.input_files) if args.input_files is not None else {} run_workflow( graph, inputs, name=args.name, run_id=args.run_id, + log_level=args.log_level, registry_path=args.registry_path, - use_uv_worker=args.uv, n_iterations=args.n_iterations, polling_interval_seconds=args.polling_interval_seconds, print_output=args.print_output, + use_uv_executor=args.uv, ) class TierkreisRunCli: + """Tierkeirs cli for the `run` subcommand. + + Used to run graphs with tkr run ... + """ + @staticmethod def add_subcommand( main_parser: argparse._SubParsersAction[argparse.ArgumentParser], ) -> None: + """Add the run subcommand.""" parser = parse_args(main_parser) parser.set_defaults(func=TierkreisRunCli.execute) @staticmethod def execute(args: argparse.Namespace) -> None: + """Execute the run subcommand.""" run_workflow_args(args) diff --git a/tierkreis/tierkreis/cli/run_workflow.py b/tierkreis/tierkreis/cli/run_workflow.py index 884304257..8c2d2d9ca 100644 --- a/tierkreis/tierkreis/cli/run_workflow.py +++ b/tierkreis/tierkreis/cli/run_workflow.py @@ -1,41 +1,67 @@ -from pathlib import Path -import uuid +"""Implementation to run a workflow.""" + import logging +import uuid +from pathlib import Path from tierkreis.controller import run_graph from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.types import PType -from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.executor.uv_executor import UvExecutor 
+from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.storage import read_outputs logger = logging.getLogger(__name__) -def run_workflow( +def run_workflow( # noqa: PLR0913 graph: GraphData, inputs: dict[str, PType], name: str | None = None, run_id: int | None = None, log_level: int | str = logging.INFO, registry_path: Path | None = None, - print_output: bool = False, - use_uv_worker: bool = False, n_iterations: int = 10000, polling_interval_seconds: float = 0.1, + *, + print_output: bool = False, + use_uv_executor: bool = False, ) -> None: - """Run a workflow.""" + """Run a workflow. + + Wrapper for :py:func:`tierkreis.controller.run_graph.run_graph` to run a workflow. + Adds some sensible defaults. + + :param graph: The graph to run. + :type graph: GraphData + :param inputs: The inputs to the workflow. + :type inputs: dict[str, PType] + :param name: The name of the workflow, defaults to None + :type name: str | None, optional + :param run_id: The run ID of the workflow, defaults to None + :type run_id: int | None, optional + :param log_level: The log level for the workflow, defaults to logging.INFO + :type log_level: int | str, optional + :param registry_path: The worker registry, defaults to Path(__file__).parent + :type registry_path: Path | None, optional + :param print_output: Whether to print final outputs, defaults to False + :type print_output: bool, optional + :param use_uv_executor: Use the UV executor instead of ShellExecutor + , defaults to False + :type use_uv_executor: bool, optional + :param n_iterations: The maximum number of iterations, defaults to 10000 + :type n_iterations: int, optional + :param polling_interval_seconds: The controller tickrate, defaults to 0.1 + :type polling_interval_seconds: float, optional + """ logger.setLevel(log_level) - if run_id is None: - workflow_id = uuid.uuid4() - else: - workflow_id = uuid.UUID(int=run_id) + workflow_id = uuid.uuid4() if run_id is None else 
uuid.UUID(int=run_id) logger.info("Workflow ID is %s", workflow_id) storage = ControllerFileStorage(workflow_id, name=name, do_cleanup=True) if registry_path is None: registry_path = Path(__file__).parent - if use_uv_worker: + if use_uv_executor: executor = UvExecutor(registry_path=registry_path, logs_path=storage.logs_path) else: executor = ShellExecutor(registry_path, storage.workflow_dir) @@ -51,5 +77,9 @@ def run_workflow( polling_interval_seconds, ) if print_output: - res = read_outputs(graph, storage) - print(res) + all_outputs = read_outputs(graph, storage) + if isinstance(all_outputs, dict): + for output_name, output_value in all_outputs.items(): + print(f"{output_name}: {output_value!r}") # noqa: T201 + else: + print(f"value: {all_outputs!r}") # noqa: T201 diff --git a/tierkreis/tierkreis/cli/templates.py b/tierkreis/tierkreis/cli/templates.py index 1811f9743..f1b22f9db 100644 --- a/tierkreis/tierkreis/cli/templates.py +++ b/tierkreis/tierkreis/cli/templates.py @@ -1,7 +1,16 @@ +"""String template for the project initialization.""" + from typing import Literal def python_worker_main(worker_name: str) -> str: + """Generate a python morker main.py. + + :param worker_name: The name of the worker. + :type worker_name: str + :return: The generated main.py content. + :rtype: str + """ worker_name = worker_name.replace("-", "_") return f"""from sys import argv @@ -25,7 +34,21 @@ def main(): """ -def python_worker_workspace_pyproject(worker_name: str, external: bool = False) -> str: +def python_worker_workspace_pyproject( + worker_name: str, + *, + external: bool = False, +) -> str: + """Generate the pyproject.toml for the worker workspace. + + :param worker_name: Name of the worker. + :type worker_name: str + :param external: Whether the worker is external (not-python worker), + defaults to False + :type external: bool, optional + :return: The generated pyproject.toml content. 
+ :rtype: str + """ worker_name = worker_name.replace("_", "-") template = f"""[project] name = "tkr-{worker_name}" @@ -63,8 +86,20 @@ def python_worker_workspace_pyproject(worker_name: str, external: bool = False) def python_worker_pyproject( - worker_name: str, kind: Literal["api", "src"] = "api" + worker_name: str, + kind: Literal["api", "src"] = "api", ) -> str: + """Generate the pyproject.toml for the worker. + + Either for the api directory (only stubs) used during build time or + the src directory (the actual worker implementation) used during runtime. + + :param worker_name: Name of the worker. + :type worker_name: str + :param kind: Either "api" or "src", defaults to "api" + :type kind: Literal['api', 'src'], optional, + :rtype: str + """ worker_name = worker_name.replace("_", "-") template = f"""[project] name = "tkr-{worker_name}-{kind}" @@ -94,10 +129,17 @@ def python_worker_pyproject( def external_worker_idl(worker_name: str) -> str: + """Generate a typespec file for a worker. + + :param worker_name: The name of the worker. + :type worker_name: str + :return: The generated typespec content. + :rtype: str + """ return f"""model YourModel {{ value: int }} - + interface {worker_name} {{ your_function(value: int): YourModel; }} @@ -106,11 +148,18 @@ def external_worker_idl(worker_name: str) -> str: def default_graph(worker_name: str) -> str: + """Generate a default graph example using a worker. + + :param worker_name: The name of the worker. + :type worker_name: str + :return: The generated main graph content. 
+ :rtype: str + """ worker_name = worker_name.replace("-", "_") return f"""from typing import NamedTuple from pathlib import Path from uuid import UUID - + from tierkreis.builder import GraphBuilder from tierkreis.controller import run_graph from tierkreis.controller.data.models import TKR, OpaqueType @@ -126,13 +175,13 @@ class GraphInputs(NamedTuple): class GraphOutputs(NamedTuple): value: TKR[int] - + def your_graph() -> GraphBuilder[GraphInputs, GraphOutputs]: g = GraphBuilder(GraphInputs, GraphOutputs) out = g.task(your_worker_task(g.inputs.value)) g.outputs(GraphOutputs(value=out)) return g - + def main() -> None: graph = your_graph() storage = FileStorage(workflow_id=UUID(int=12345), name="your_graph") diff --git a/tierkreis/tierkreis/cli/tkr.py b/tierkreis/tierkreis/cli/tkr.py index df7d92dd3..2e4b4f85d 100644 --- a/tierkreis/tierkreis/cli/tkr.py +++ b/tierkreis/tierkreis/cli/tkr.py @@ -1,27 +1,35 @@ +"""Tierkreis CLI main entrypoint.""" + +from __future__ import annotations + import argparse import logging import sys -from tierkreis.cli.run import TierkreisRunCli from tierkreis.cli.project import TierkreisInitCli +from tierkreis.cli.run import TierkreisRunCli + +logger = logging.getLogger(__name__) def main() -> None: + """Run the main entry point for the tkr cli.""" parser = argparse.ArgumentParser( prog="tkr", - description="Tierkreis: a workflow engine for quantum HPC. This is the main tierkreis command-line tool.", + description="Tierkreis: a workflow engine for quantum HPC." 
+ "This is the main tierkreis command-line tool.", ) subparser = parser.add_subparsers(title="subcommands") TierkreisRunCli.add_subcommand(subparser) TierkreisInitCli.add_subcommand(subparser) try: - from tierkreis_visualization.cli import TierkreisVizCli + from tierkreis_visualization.cli import TierkreisVizCli # noqa: PLC0415 TierkreisVizCli.add_subcommand(subparser) except ImportError: - logging.warning("Could not import Tierkreis Visualization CLI") - logging.warning( - "To install it, please run 'pip install tierkreis-visualization'" + logger.warning("Could not import Tierkreis Visualization CLI") + logger.warning( + "To install it, please run 'pip install tierkreis-visualization'", ) args = parser.parse_args(args=None if sys.argv[1:] else ["--help"]) args.func(args) diff --git a/tierkreis/tierkreis/codegen.py b/tierkreis/tierkreis/codegen.py index 1f43d5ee2..1d589b186 100644 --- a/tierkreis/tierkreis/codegen.py +++ b/tierkreis/tierkreis/codegen.py @@ -1,5 +1,9 @@ +"""Code generation utilities for Tierkreis stubs.""" + from inspect import isclass + from pydantic import BaseModel + from tierkreis.controller.data.types import ( DictConvertible, ListConvertible, @@ -12,11 +16,19 @@ def format_ptype(ptype: type | str) -> str: + """Format a ptype to a string. + + :param ptype: The type to format. + :type ptype: type | str + :return: The formatted string representation of the type. + :rtype: str + """ if isinstance(ptype, str): return ptype if isclass(ptype) and issubclass( - ptype, (DictConvertible, ListConvertible, NdarraySurrogate, BaseModel) + ptype, + (DictConvertible, ListConvertible, NdarraySurrogate, BaseModel), ): return f'OpaqueType["{ptype.__module__}.{ptype.__qualname__}"]' @@ -27,8 +39,22 @@ def format_ptype(ptype: type | str) -> str: def format_generic_type( - generictype: GenericType | str, include_bound: bool, is_tkr: bool + generictype: GenericType | str, + *, + include_bound: bool, + is_tkr: bool, ) -> str: + """Format a generic type to a string. 
+ + :param generictype: The generic type to format. + :type generictype: GenericType | str + :param include_bound: Whether to include the bound. + :type include_bound: bool + :param is_tkr: Whether the type is a TKR type. + :type is_tkr: bool + :return: The formatted string representation of the generic type. + :rtype: str + """ bound_str = ": PType" if include_bound else "" if isinstance(generictype, str): out = generictype + bound_str @@ -36,45 +62,84 @@ def format_generic_type( origin_str = format_ptype(generictype.origin) - generics = [format_generic_type(x, include_bound, False) for x in generictype.args] + generics = [ + format_generic_type(x, include_bound=include_bound, is_tkr=False) + for x in generictype.args + ] generics_str = f"[{', '.join(generics)}]" if generictype.args else "" out = f"{origin_str}{generics_str}" return f"TKR[{out}]" if is_tkr else out -def format_typed_arg(typed_arg: TypedArg, is_portmaping: bool) -> str: - type_str = format_generic_type(typed_arg.t, False, not is_portmaping) - should_quote = typed_arg.t.included_structs() and is_portmaping +def format_typed_arg(typed_arg: TypedArg, *, is_portmapping: bool) -> str: + """Format a typed argument to a string. + + :param typed_arg: The typed argument. + :type typed_arg: TypedArg + :param is_portmapping: Wheter the argument is a portmapping. + :type is_portmapping: bool + :return: The formatted string representation of the typed argument. + :rtype: str + """ + type_str = format_generic_type( + typed_arg.t, + include_bound=False, + is_tkr=not is_portmapping, + ) + should_quote = typed_arg.t.included_structs() and is_portmapping type_str = f'"{type_str}"' if should_quote else type_str default_str = " | None = None " if typed_arg.has_default else "" return f"{typed_arg.name}: {type_str}{default_str} {NO_QA_STR}" def format_model(model: Model) -> str: + """Format a model to a string. + + :param model: The model to format. 
+ :type model: Model + :return: The formatted string representation of the model. + :rtype: str + """ is_portmapping = model.is_portmapping - outs = [format_typed_arg(x, not is_portmapping) for x in model.decls] + outs = [format_typed_arg(x, is_portmapping=not is_portmapping) for x in model.decls] outs.sort() outs_str = "\n ".join(outs) bases = ["NamedTuple"] if is_portmapping else ["Struct", "Protocol"] + bases_str = ", ".join(bases) + generic_type_str = format_generic_type(model.t, include_bound=True, is_tkr=False) return f""" -class {format_generic_type(model.t, True, False)}({", ".join(bases)}): +class {generic_type_str}({bases_str}): {outs_str} """ def format_method(namespace_name: str, fn: Method) -> str: - ins = [format_typed_arg(x, False) for x in fn.args] + """Format a method to a string. + + :param namespace_name: The function namespace. + :type namespace_name: str + :param fn: The method to format. + :type fn: Method + :return: The formatted string representation of the method. 
+ :rtype: str + """ + ins = [format_typed_arg(x, is_portmapping=False) for x in fn.args] ins_str = "\n ".join(ins) class_name = format_generic_type( - fn.return_type, False, not fn.return_type_is_portmapping + fn.return_type, + include_bound=False, + is_tkr=not fn.return_type_is_portmapping, ) bases = ["NamedTuple"] - return f"""class {format_generic_type(fn.name, True, False)}({", ".join(bases)}): + class_def = format_generic_type(fn.name, include_bound=True, is_tkr=False) + bases_str = ", ".join(bases) + + return f"""class {class_def}({bases_str}): {ins_str} @staticmethod diff --git a/tierkreis/tierkreis/consts.py b/tierkreis/tierkreis/consts.py index 2b3ffdab7..03039a037 100644 --- a/tierkreis/tierkreis/consts.py +++ b/tierkreis/tierkreis/consts.py @@ -1,5 +1,6 @@ -from pathlib import Path +"""Tierkreis constant definitions.""" +from pathlib import Path PACKAGE_PATH = Path(__file__).parent.parent TESTS_PATH = PACKAGE_PATH / "tests" diff --git a/tierkreis/tierkreis/controller/__init__.py b/tierkreis/tierkreis/controller/__init__.py index 5a7dbc2d3..e621f561b 100644 --- a/tierkreis/tierkreis/controller/__init__.py +++ b/tierkreis/tierkreis/controller/__init__.py @@ -1,5 +1,13 @@ +"""The workflow controller for Tierkreis. + +This is the main functionality controlling the execution of Tierkreis graphs. +It provides the main entry point for running a graph, + and the main loop for resuming a graph until completion. 
+""" + import logging from time import sleep +from typing import TYPE_CHECKING from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import Eval, GraphData @@ -8,11 +16,13 @@ from tierkreis.controller.data.types import PType, bytes_from_ptype, ptype_from_bytes from tierkreis.controller.executor.protocol import ControllerExecutor from tierkreis.controller.start import NodeRunData, start, start_nodes -from tierkreis.logger_setup import set_tkr_logger from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.controller.storage.walk import walk_node -from tierkreis.controller.data.core import PortID, ValueRef from tierkreis.exceptions import TierkreisError +from tierkreis.logger_setup import set_tkr_logger + +if TYPE_CHECKING: + from tierkreis.controller.data.core import PortID, ValueRef root_loc = Loc("") logger = logging.getLogger(__name__) @@ -25,18 +35,43 @@ def run_graph[A: TModel, B: TModel]( graph_inputs: dict[str, PType] | PType, n_iterations: int = 10000, polling_interval_seconds: float = 0.01, + *, enable_logging: bool = True, ) -> None: + """Start a graph execution. + + Kicks of the execution by writing the graph inputs to storage. + Also marks the top-level eval wrapping the graph asa ready. + + :param storage: The storage backend for the controller. + :type storage: ControllerStorage + :param executor: The executor backend for the controller. + :type executor: ControllerExecutor + :param g: The graph to run. + :type g: GraphData | GraphBuilder[A, B] + :param graph_inputs: The inputs to the graph. + If a single PType is provided, it will be provided as the input "value". 
+ :type graph_inputs: dict[str, PType] | PType + :param n_iterations: The maximum number of iterations to run the graph, + defaults to 10000 + :type n_iterations: int, optional + :param polling_interval_seconds: The polling interval in seconds, defaults to 0.01 + :type polling_interval_seconds: float, optional + :param enable_logging: Whether to enable logging, defaults to True + :type enable_logging: bool, optional + :raises TierkreisError: If the graph encounters errors during execution. + """ if isinstance(g, GraphBuilder): g = g.get_data() if not isinstance(graph_inputs, dict): graph_inputs = {"value": graph_inputs} - remaining_inputs = g.remaining_inputs({k for k in graph_inputs.keys()}) + remaining_inputs = g.remaining_inputs(set(graph_inputs.keys())) if len(remaining_inputs) > 0: logger.warning( - f"Some inputs were not provided: {remaining_inputs}. " - "Tasks will use default values if available." + "Some inputs were not provided: %s" + "Tasks will use default values if available.", + remaining_inputs, ) storage.write_metadata(Loc("")) @@ -52,7 +87,7 @@ def run_graph[A: TModel, B: TModel]( k: (-1, k) for k, _ in graph_inputs.items() if k != "body" } node_run_data = NodeRunData(Loc(), Eval((-1, "body"), inputs), []) - start(storage, executor, node_run_data, enable_logging) + start(storage, executor, node_run_data) resume_graph(storage, executor, n_iterations, polling_interval_seconds) @@ -62,6 +97,24 @@ def resume_graph( n_iterations: int = 10000, polling_interval_seconds: float = 0.01, ) -> None: + """Resume a graph after initial start. + + This iteratively walks the graph to find new nodes to start. + A node is ready to start once all its inputs are available. + Starts from constructing the dependencies by starting from the output node + and walking backwards. + + :param storage: The storage backend for the controller. + :type storage: ControllerStorage + :param executor: The executor backend for the controller. 
+ :type executor: ControllerExecutor + :param n_iterations: The maximum number of iterations to run the graph, + defaults to 10000 + :type n_iterations: int, optional + :param polling_interval_seconds: The polling interval in seconds, defaults to 0.01 + :type polling_interval_seconds: float, optional + :raises TierkreisError: If the graph encounters errors during execution. + """ message = storage.read_output(Loc().N(-1), "body") graph = ptype_from_bytes(message, GraphData) @@ -73,18 +126,18 @@ def resume_graph( node_errors = "\n".join(x for x in walk_results.errored) storage.write_node_errors(Loc(), node_errors) - print("\n\nGraph finished with errors.\n\n") - + logger.error("\n\nGraph finished with errors.\n\n") for error_loc in walk_results.errored: - print(storage.read_errors(error_loc)) - print(f"Node: '{error_loc}' encountered an error.") - print( - f"Stderr information is available at {storage._worker_logs_path(error_loc)}." + logger.error(storage.read_errors(error_loc)) + logger.error("Node: '%s' encountered an error.", error_loc) + logger.error( + "Stderr information is available at %s.", + storage._worker_logs_path(error_loc), # noqa: SLF001 ) - print("\n\n") - print("--- Tierkreis graph errors above this line. ---\n\n") - raise TierkreisError("Graph encountered errors") + logger.error("--- Tierkreis graph errors above this line. 
---") + msg = "Graph encountered errors" + raise TierkreisError(msg) start_nodes(storage, executor, walk_results.inputs_ready) if storage.is_node_finished(Loc()): diff --git a/tierkreis/tierkreis/controller/consts.py b/tierkreis/tierkreis/controller/consts.py index 65bf1757d..9df05b531 100644 --- a/tierkreis/tierkreis/controller/consts.py +++ b/tierkreis/tierkreis/controller/consts.py @@ -1,5 +1,6 @@ -import os +"""Controller constants.""" + from pathlib import Path BODY_PORT = "body" -PACKAGE_PATH = Path(os.path.dirname(os.path.realpath(__file__))) +PACKAGE_PATH = Path(__file__).resolve().parent diff --git a/tierkreis/tierkreis/controller/data/__init__.py b/tierkreis/tierkreis/controller/data/__init__.py index e69de29bb..2ba1472be 100644 --- a/tierkreis/tierkreis/controller/data/__init__.py +++ b/tierkreis/tierkreis/controller/data/__init__.py @@ -0,0 +1 @@ +"""Core data structures for typing and constructing graphs.""" diff --git a/tierkreis/tierkreis/controller/data/core.py b/tierkreis/tierkreis/controller/data/core.py index 1aa3422f7..42f80a62a 100644 --- a/tierkreis/tierkreis/controller/data/core.py +++ b/tierkreis/tierkreis/controller/data/core.py @@ -1,8 +1,16 @@ +"""Core types in tierkreis. + +- PortID = str, name of an output on a node +- NodeIndex = int, index of node in the graph list +- ValueRef = tuple[NodeIndex, PortID] reference of a value; + uniquely identified by the node and its output. +""" + +from collections.abc import Callable from dataclasses import dataclass from typing import ( Annotated, Any, - Callable, Literal, NamedTuple, Protocol, @@ -12,14 +20,14 @@ runtime_checkable, ) - PortID = str NodeIndex = int ValueRef = tuple[NodeIndex, PortID] SerializationFormat = Literal["bytes", "json", "unknown"] -class EmptyModel(NamedTuple): ... 
+class EmptyModel(NamedTuple): +    """A model without content."""   @runtime_checkable @@ -27,33 +35,82 @@ class RestrictedNamedTuple[T](Protocol):     """A NamedTuple whose members are restricted to being of type T."""      def _asdict(self) -> dict[str, T]: ... -    def __getitem__(self, key: SupportsIndex, /) -> T: ... +    def __getitem__(self, key: SupportsIndex, /) -> T: +        """Access the indexed element as in a tuple.""" +        ...   @dataclass class Serializer: +    """Serializer for tkr values. + +    :fields: +        serializer (Callable[[Any], Any]): A function taking a value producing a +            serialized version of it. +        serialization_method (Literal): Indicator of serializer type of +            ["bytes", "json", "unknown"], defaults to "bytes". +    """ +     serializer: Callable[[Any], Any]     serialization_method: SerializationFormat = "bytes"   @dataclass class Deserializer: +    """Deserializer for tkr values. + +    :fields: +        deserializer (Callable[[Any], Any]): A function taking a serialized +            value producing a deserialized value. +        serialization_method (Literal): Indicator of serializer type of +            ["bytes", "json", "unknown"], defaults to "bytes". +    """ +     deserializer: Callable[[Any], Any]     serialization_method: SerializationFormat = "bytes"   def get_t_from_args[T](t: type[T], hint: type | None) -> T | None: +    """Get the possible type generic T from a type. + +    :return: The generic hint T if it exists on the value. +        Either from its annotation or type hint. +    :rtype: T | None +    """     if hint is None or get_origin(hint) is not Annotated:         return None      for arg in get_args(hint):         if isinstance(arg, t):            return arg +    return None + +def get_serializer(hint: type | None) -> Serializer | None: +    """Get the serializer for an annotated type.  -def get_serializer(hint: type | None): +    This is relevant for annotated worker types +    AnnotatedType = Annotated[BaseType, ser, deser] +    worker_fn: AnnotatedType -> AnnotatedType + +    :param hint: The type to get the serializer for.
+    :type hint: type | None +    :return: The serializer if one is annotated. +    :rtype: Serializer | None +    """     return get_t_from_args(Serializer, hint)  -def get_deserializer(hint: type | None): +def get_deserializer(hint: type | None) -> Deserializer | None: +    """Get the deserializer for an annotated type. + +    This is relevant for annotated worker types +    AnnotatedType = Annotated[BaseType, ser, deser] +    worker_fn: AnnotatedType -> AnnotatedType + +    :param hint: The type to get the deserializer for. +    :type hint: type | None +    :return: The deserializer if one is annotated. +    :rtype: Deserializer | None +    """     return get_t_from_args(Deserializer, hint) diff --git a/tierkreis/tierkreis/controller/data/graph.py b/tierkreis/tierkreis/controller/data/graph.py index 1891dbe64..646e934b3 100644 --- a/tierkreis/tierkreis/controller/data/graph.py +++ b/tierkreis/tierkreis/controller/data/graph.py @@ -1,6 +1,18 @@ +"""Graph and node definitions. + +(Computational) graphs are the underlying data structure for workflows in tierkreis. +A Graph is comprised of nodes (atomic operations) and edges (their values). +Nodes have named inputs referencing a previously computed value in the graph; +and named outputs referencing an id to look for the respective value. +Inputs and outputs are called ports. +The graph is constructed by mapping inputs of a node (by name) to the +outputs of a previous node. +""" + import logging +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Callable, Literal, assert_never +from typing import Any, Literal, assert_never  from pydantic import BaseModel, RootModel @@ -14,12 +26,24 @@ @dataclass class NodeDefBase: -    outputs: dict[PortID, list[NodeIndex]] = field(default_factory=dict, kw_only=True)     """Map each out-port to the list of nodes that use it."""  +    outputs: dict[PortID, list[NodeIndex]] = field(default_factory=dict, kw_only=True) +  @dataclass class Func(NodeDefBase): +    """A function node.
+ + Defines a task which is run by a worker on an executor. + + :fields: + function_name (str): The function to run. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + outputs (dict[PortID, NodeIndex]): Typically not used on functions, + for the sake of simplifying NodeData. + """ + function_name: str inputs: dict[PortID, ValueRef] type: Literal["function"] = field(default="function") @@ -27,6 +51,18 @@ class Func(NodeDefBase): @dataclass class Eval(NodeDefBase): + """An eval node. + + Evaluates a nested graph. + Necessary for higher order operations. + + :fields: + graph (ValueRef): The reference to a nested graph body. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the nested graph. + """ + graph: ValueRef inputs: dict[PortID, ValueRef] type: Literal["eval"] = field(default="eval") @@ -34,6 +70,21 @@ class Eval(NodeDefBase): @dataclass class Loop(NodeDefBase): + """A loop node. + + Evaluates a nested graph iteratively. + Inputs are updated from the previous iteration. + Loops until continue_port value evaluates to false. + + :fields: + body (ValueRef): The reference to a nested graph body. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + continue_port: PortID: A named boolean port as stopping criterion. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the nested graph. + name (str | None): Used as debug data for loop tracing. + """ + body: ValueRef inputs: dict[PortID, ValueRef] continue_port: PortID # The port that specifies if the loop should continue. @@ -43,6 +94,20 @@ class Loop(NodeDefBase): @dataclass class Map(NodeDefBase): + """A map node. + + Evaluates a nested graph concurrently for a set of values on one port. + Maps have a * input which indicates the value to map over. 
+ Typically this is done by fold map a b c unfold [...] where a b c are the arbitrary + but fixed inputs of the map. + + :fields: + body (ValueRef): The reference to a nested graph body. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + outputs (dict[PortID, NodeIndex]): Typically not used on functions, + for the sake of simplifying NodeData. + """ + body: ValueRef inputs: dict[PortID, ValueRef] type: Literal["map"] = field(default="map") @@ -50,39 +115,91 @@ class Map(NodeDefBase): @dataclass class Const(NodeDefBase): + """A constant node. + + :fields: + value (Any): The constant value + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + Typically "value" or "body" + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the nested graphs. + """ + value: Any - inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) + inputs: dict[PortID, ValueRef] = field(default_factory=dict) type: Literal["const"] = field(default="const") @dataclass class IfElse(NodeDefBase): + """A lazy if else node. + + :fields: + pred (ValueRef): Ref to a boolean value dictating which branch to evaluate. + if_true (ValueRef): Branch to evaluate when pred is true. + if_false (ValueRef): Branch to evaluate when pred is false. + inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. + Typically pred and values for the branches. + outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective + output nodes (by index) in the branches. + """ + pred: ValueRef if_true: ValueRef if_false: ValueRef - inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) + inputs: dict[PortID, ValueRef] = field(default_factory=dict) type: Literal["ifelse"] = field(default="ifelse") @dataclass class EagerIfElse(NodeDefBase): + """An eager if else node. + + :fields: + pred (ValueRef): Ref to a boolean value dictating which value to forward. 
+        if_true (ValueRef): Branch to forward when pred is true. +        if_false (ValueRef): Branch to forward when pred is false. +        inputs (dict[PortID, ValueRef]): The mapping of inputs to their values. +            Typically pred and values for the branches. +        outputs (dict[PortID, NodeIndex]): Mapping from outer output names to respective +            output nodes (by index) in the branches. +    """ +     pred: ValueRef     if_true: ValueRef     if_false: ValueRef -    inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) +    inputs: dict[PortID, ValueRef] = field(default_factory=dict)     type: Literal["eifelse"] = field(default="eifelse")   @dataclass class Input(NodeDefBase): +    """An input node. + +    :fields: +        name (str): The name of the input value. +        inputs (dict[PortID, ValueRef]): The mapping of inputs to their values, +            typically a single element. +        outputs (dict[PortID, NodeIndex]): Typically not used on inputs, +            for the sake of simplifying NodeData. +    """ +     name: str -    inputs: dict[PortID, ValueRef] = field(default_factory=lambda: {}) +    inputs: dict[PortID, ValueRef] = field(default_factory=dict)     type: Literal["input"] = field(default="input")   @dataclass class Output(NodeDefBase): +    """An output node. + +    :fields: +        inputs (dict[PortID, ValueRef]): The mapping of inputs to their values, +            typically a single element (e.g. computation -> output) +        outputs (dict[PortID, NodeIndex]): Typically only forwards itself. +    """ +     inputs: dict[PortID, ValueRef]     type: Literal["output"] = field(default="output") @@ -92,7 +209,19 @@ class Output(NodeDefBase):   def in_edges(node: NodeDef) -> dict[PortID, ValueRef]: -    parents = {k: v for k, v in node.inputs.items()} +    """Find the incoming edges of a node. + +    Finds all the defined inputs and adds the special constructions: +    - Graph body for map, loop, eval +    - Predicate for ifelse +    - All nodes for eager if else + +    :param node: The node to evaluate. +    :type node: NodeDef +    :return: Mapping of port names to value references.
+ :rtype: dict[PortID, ValueRef] + """ + parents = dict(node.inputs.items()) match node.type: case "eval": @@ -114,6 +243,24 @@ def in_edges(node: NodeDef) -> dict[PortID, ValueRef]: class GraphData(BaseModel): + """The model of a computational graph. + + Encapsulates the entire computation. + Nodes are stored in a list, where the NodeIndex points to a unique node. + Graphs have a single output which can be a Struct of multiple fields. + + :fields: + nodes (list[NodeDef]): The list of nodes in a graph. + fixed_inputs (dict[PortID, OutputLoc]): A dict of fixed inputs for the graph. + They have values defined at construction time. + graph_inputs: (set[PortID]): A set of user defined inputs at runtime. + graph_output_idx (NodeIndex | None): The index of the output node. + Graphs must have exactly one output to run. + named_nodes (dict[str, NodeIndex]): Mapping of node names to their index in the + list. This is used for debug information. + + """ + nodes: list[NodeDef] = [] fixed_inputs: dict[PortID, OutputLoc] = {} graph_inputs: set[PortID] = set() @@ -121,19 +268,55 @@ class GraphData(BaseModel): named_nodes: dict[str, NodeIndex] = {} def input(self, name: str) -> ValueRef: + """Add an input name. + + :param name: The name of the input. + :type name: str + :return: The reference to that value. + :rtype: ValueRef + """ return self.add(Input(name))(name) def const(self, value: PType) -> ValueRef: + """Add a constant value. + + :param value: The value to add. + :type value: PType + :return: The reference to that value. + :rtype: ValueRef + """ return self.add(Const(value))("value") def func( - self, function_name: str, inputs: dict[PortID, ValueRef] + self, + function_name: str, + inputs: dict[PortID, ValueRef], ) -> Callable[[PortID], ValueRef]: + """Add a function node (task). + + :param function_name: The name of the function. + :type function_name: str + :param inputs: The mapping of the input values. 
+ :type inputs: dict[PortID, ValueRef] + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Func(function_name, inputs)) def eval( - self, graph: ValueRef, inputs: dict[PortID, ValueRef] + self, + graph: ValueRef, + inputs: dict[PortID, ValueRef], ) -> Callable[[PortID], ValueRef]: + """Add an eval node. + + :param graph: The nested graph to evaluate. + :type graph: ValueRef + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Eval(graph, inputs)) def loop( @@ -143,6 +326,19 @@ def loop( continue_port: PortID, name: str | None = None, ) -> Callable[[PortID], ValueRef]: + """Add a loop node. + + :param body: The graph to loop over. + :type body: ValueRef + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :param continue_port: The termination criterion port. + :type continue_port: PortID + :param name: Name of the loop for tracing, defaults to None + :type name: str | None, optional + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Loop(body, inputs, continue_port, name=name)) def map( @@ -150,25 +346,83 @@ def map( body: ValueRef, inputs: dict[PortID, ValueRef], ) -> Callable[[PortID], ValueRef]: + """Add a map node. + + :param body: The graph to map over. + :type body: ValueRef + :param inputs: The mapping of the input values. + :type inputs: dict[PortID, ValueRef] + :return: A function returning index given an output. + :rtype: Callable[[PortID], ValueRef] + """ return self.add(Map(body, inputs)) - def if_else(self, pred: ValueRef, if_true: ValueRef, if_false: ValueRef): + def if_else( + self, + pred: ValueRef, + if_true: ValueRef, + if_false: ValueRef, + ) -> Callable[[PortID], ValueRef]: + """Add an lazy if else node. 
+ +    :param pred: The reference to conditional value. +    :type pred: ValueRef +    :param if_true: The graph/value for the true branch. +    :type if_true: ValueRef +    :param if_false: The graph/value for the false branch. +    :type if_false: ValueRef +    :return: A function returning index given an output. +    :rtype: Callable[[PortID], ValueRef] +    """         return self.add(IfElse(pred, if_true, if_false))  -    def eager_if_else(self, pred: ValueRef, if_true: ValueRef, if_false: ValueRef): +    def eager_if_else( +        self, +        pred: ValueRef, +        if_true: ValueRef, +        if_false: ValueRef, +    ) -> Callable[[PortID], ValueRef]: +        """Add an eager if else node. + +        :param pred: The reference to conditional value. +        :type pred: ValueRef +        :param if_true: The graph/value for the true branch. +        :type if_true: ValueRef +        :param if_false: The graph/value for the false branch. +        :type if_false: ValueRef +        :return: A function returning index given an output. +        :rtype: Callable[[PortID], ValueRef] +        """         return self.add(EagerIfElse(pred, if_true, if_false))      def output(self, inputs: dict[PortID, ValueRef]) -> None: +        """Add an output node. + +        Computation -> output. + +        :param inputs: The inputs of the output node. +        :type inputs: dict[PortID, ValueRef] +        """         _ = self.add(Output(inputs))      def add(self, node: NodeDef) -> Callable[[PortID], ValueRef]: +        """Add a node to the graph. + +        :param node: The node to add. +        :type node: NodeDef +        :raises TierkreisError: If multiple outputs are added. +        :return: A function given the output name of a node returns +            the index of the node it corresponds to.
+        :rtype: Callable[[PortID], ValueRef] +        """         idx = len(self.nodes)         self.nodes.append(node)         match node.type:             case "output":                 if self.graph_output_idx is not None: +                    msg = f"Graph already has output at index {self.graph_output_idx}"                     raise TierkreisError( -                        f"Graph already has output at index {self.graph_output_idx}" +                        msg,                     )                 self.graph_output_idx = idx @@ -188,22 +442,42 @@ def add(self, node: NodeDef) -> Callable[[PortID], ValueRef]:         return lambda k: (idx, k)      def output_idx(self) -> NodeIndex: +        """Find the index of the graph output node. + +        :raises TierkreisError: If the graph has no output. +        :raises TierkreisError: If the node at the index is not an output. +        :return: The index for the output node in self.nodes +        :rtype: NodeIndex +        """         idx = self.graph_output_idx         if idx is None: -            raise TierkreisError("Graph has no output index.") +            msg = "Graph has no output index." +            raise TierkreisError(msg)          node = self.nodes[idx]         if node.type != "output": -            raise TierkreisError(f"Expected output node at {idx} found {node}") +            msg = f"Expected output node at {idx} found {node}" +            raise TierkreisError(msg)          return idx      def remaining_inputs(self, provided_inputs: set[PortID]) -> set[PortID]: +        """Find the inputs for which no values are provided. + +        :param provided_inputs: The list of already provided inputs. +        :type provided_inputs: set[PortID] +        :raises TierkreisError: If provided inputs would overwrite fixed inputs. +        :return: A set of input names which don't have an associated value. +        :rtype: set[PortID] +        """         fixed_inputs = set(self.fixed_inputs.keys())         if fixed_inputs & provided_inputs: -            raise TierkreisError( +            msg = (                 f"Fixed inputs {fixed_inputs}" -                + f" should not intersect provided inputs {provided_inputs}." +                f" should not intersect provided inputs {provided_inputs}."
+            ) +            raise TierkreisError( +                msg,             )          actual_inputs = fixed_inputs.union(provided_inputs) @@ -214,15 +488,35 @@ def graph_node_from_loc(     node_location: Loc,     graph: GraphData, ) -> tuple[NodeDef, GraphData]: -    """Assumes the first part of a loc can be found in current graph""" +    """Find the node definition and graph of a nested graph given a loc. + +    Nested graph nodes are not indexed in their parent as they are +    represented by a single node. E.g. g_1.eval(const(g_2)) will only produce a single +    index although g_2 can contain many nodes. +    Locs on the other hand contain this information e.g -.N0.L0.N-1 is a virtual eval +    node. +    This function recursively steps through nested graph definitions like this to find +    a graph according to a flat loc. +    Assumes the first part of a loc can be found in current graph. + +    :param node_location: The loc to search for. +    :type node_location: Loc +    :param graph: The current graph to search in. +    :type graph: GraphData +    :raises TierkreisError: On an empty graph or a malformed Loc +    :return: The node containing a graph and the graph itself. +    :rtype: tuple[NodeDef, GraphData] +    """     if len(graph.nodes) == 0: -        raise TierkreisError("Cannot convert location to node. Reason: Empty Graph") +        msg = "Cannot convert location to node. 
Reason: Empty Graph" + raise TierkreisError(msg) if node_location == "-": return Eval((-1, "body"), {}), graph step, remaining_location = node_location.pop_first() if isinstance(step, str): - raise TierkreisError("Cannot convert location: Reason: Malformed Loc") + msg = "Cannot convert location: Reason: Malformed Loc" + raise TierkreisError(msg) (_, node_id) = step if node_id == -1: return Eval((-1, "body"), {}), graph @@ -236,7 +530,7 @@ def graph_node_from_loc( case "loop" | "map": graph = _unwrap_graph(graph.nodes[node.body[0]], node.type) _, remaining_location = remaining_location.pop_first() # Remove the M0/L0 - if len(remaining_location.steps()) < 2: + if len(remaining_location.steps()) <= 1: return Eval((-1, "body"), node.inputs, outputs=node.outputs), graph node, graph = graph_node_from_loc(remaining_location, graph) @@ -251,9 +545,12 @@ def graph_node_from_loc( def _unwrap_graph(node: NodeDef, node_type: str) -> GraphData: """Safely unwraps a const nodes GraphData.""" if not isinstance(node, Const): - raise TierkreisError( + msg = ( f"Cannot convert location to node. Reason: {node_type} does not wrap const" ) + raise TierkreisError( + msg, + ) match node.value: case GraphData() as graph: return graph @@ -263,6 +560,7 @@ def _unwrap_graph(node: NodeDef, node_type: str) -> GraphData: return GraphData(**data) case _: + msg = "Cannot convert location to node. Reason: const value is not a graph" raise TierkreisError( - "Cannot convert location to node. 
Reason: const value is not a graph" + msg, ) diff --git a/tierkreis/tierkreis/controller/data/location.py b/tierkreis/tierkreis/controller/data/location.py index 1e5bbd8b3..3fb8cd683 100644 --- a/tierkreis/tierkreis/controller/data/location.py +++ b/tierkreis/tierkreis/controller/data/location.py @@ -1,46 +1,117 @@ +"""Data structures for node locations in the controller.""" + from logging import getLogger from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, Self, assert_never from pydantic import BaseModel, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from tierkreis.controller.data.core import PortID -from typing_extensions import assert_never -from tierkreis.controller.data.core import NodeIndex +from tierkreis.controller.data.core import NodeIndex, PortID from tierkreis.exceptions import TierkreisError logger = getLogger(__name__) class WorkerCallArgs(BaseModel): + """The arguments for a worker call. + + Contains all information. + This will be provided as first + + :fields: + function_name (str): The name of the function to call. + inputs (dict[str, Path]): The input paths for the worker. + Each input is provided as a path to a file containing the input data. + outputs (dict[str, Path]): The output paths for the worker. + Each output is provided as a path to a file where the worker + should write the output data. + output_dir (Path): The directory for the worker outputs. + done_path (Path): The path to touch to indicate finished successfully. + error_path (Path): The path to touch to indicate finished with error. + logs_path (Path | None): The path to write logs to. 
+ """ + function_name: str inputs: dict[str, Path] outputs: dict[str, Path] output_dir: Path done_path: Path error_path: Path - logs_path: Optional[Path] + logs_path: Path | None NodeStep = Literal["-"] | tuple[Literal["N", "L", "M"], NodeIndex] +MIN_LENGTH = 2 + class Loc(str): - def __new__(cls, k: str = "-") -> "Loc": - return super(Loc, cls).__new__(cls, k) + """The Loc(ation) of a node in storage. + + This is a string that encodes the path to the node in the graph. + """ + + __slots__ = [] + + def __new__(cls, k: str = "-") -> Self: + """Construct a new Loc. - def N(self, idx: int) -> "Loc": + :param k: The location string., defaults to "-" + :type k: str, optional + :return: The representation of self. + :rtype: Self + """ + return super().__new__(cls, k) + + def N(self, idx: int) -> "Loc": # noqa: N802 + """Append a regular node. + + Regular nodes are all but loops and maps. + + :param idx: The index of the node to append. + :type idx: int + :return: The new location. + :rtype: Loc + """ return Loc(str(self) + f".N{idx}") - def L(self, idx: int) -> "Loc": + def L(self, idx: int) -> "Loc": # noqa: N802 + """Append a loop node. + + L is one iteration of a loop. + + :param idx: The index of the loop iteration to append. + :type idx: int + :return: The new Location. + :rtype: Loc + """ return Loc(str(self) + f".L{idx}") - def M(self, idx: int) -> "Loc": + def M(self, idx: int) -> "Loc": # noqa: N802 + """Append a map node. + + M is one element of a map. + + :param idx: The index of the map iteration to append. + :type idx: int + :return: The new location. + :rtype: Loc + """ return Loc(str(self) + f".M{idx}") @staticmethod def from_steps(steps: list[NodeStep]) -> "Loc": + """Construct a location from steps. + + Steps are separated by '.' and consist of + a node type and index. + + :param steps: The list of steps to construct the location from. + :type steps: list[NodeStep] + :return: The new location. 
+ :rtype: Loc + """ loc = "" for step in steps.copy(): match step: @@ -51,6 +122,11 @@ def from_steps(steps: list[NodeStep]) -> "Loc": return Loc(loc) def parent(self) -> "Loc | None": + """Return the parent of the node. + + :return: The parent location, or none if it is a root. + :rtype: Loc | None + """ steps = self.steps() if not steps: return None @@ -69,6 +145,15 @@ def parent(self) -> "Loc | None": assert_never(last_step) def steps(self) -> list[NodeStep]: + """Deconstruct self into steps. + + Steps are separated by '.' and consist of + a node type and index. + + :raises TierkreisError: On a malformed location string. + :return: The list of steps. + :rtype: list[NodeStep] + """ if self == "": return [] @@ -84,42 +169,80 @@ def steps(self) -> list[NodeStep]: case ("M", idx_str): steps.append(("M", int(idx_str))) case _: - raise TierkreisError(f"Invalid Loc: {self}") + msg = f"Invalid Loc: {self}" + raise TierkreisError(msg) return steps @classmethod def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler + cls, + source_type: Any, # noqa: ANN401 inherited from pydantic + handler: GetCoreSchemaHandler, ) -> CoreSchema: + """Make Loc work with pydantic.""" return core_schema.no_info_after_validator_function(cls, handler(str)) def pop_first(self) -> tuple[NodeStep, "Loc"]: + """Pop the first step of the location. + + The remaining location is still valid. + + e.g. for Loc("N0.L1.N3"), this returns ("N", 0) and Loc("L1.N3"). + + :raises TierkreisError: On a malformed location string. + :return: The first step and the remaining location. 
+        :rtype: tuple[NodeStep, Loc] +        """         if self == "-":             return "-", Loc("")          steps = self.steps() -        if len(steps) < 2: -            raise TierkreisError("Malformed Loc") +        if len(steps) < MIN_LENGTH: +            msg = "Malformed Loc" +            raise TierkreisError(msg)          first = steps.pop(1)         if first == "-": -            raise TierkreisError("Malformed Loc") +            msg = "Malformed Loc" +            raise TierkreisError(msg)          return first, Loc.from_steps(steps)      def pop_last(self) -> tuple[NodeStep, "Loc"]: +        """Pop the last step of the location. + +        The remaining location is still valid. + +        e.g. for Loc("N0.L1.N3"), this returns ("N", 3) and Loc("N0.L1"). + +        :raises TierkreisError: On a malformed location string. +        :return: The last step and the remaining location. +        :rtype: tuple[NodeStep, Loc] +        """         if self == "-":             return "-", Loc("")          steps = self.steps() -        if len(steps) < 2: -            raise TierkreisError("Malformed Loc") +        if len(steps) < MIN_LENGTH: +            msg = "Malformed Loc" +            raise TierkreisError(msg)          last = steps.pop(-1)         if last == "-": -            raise TierkreisError("Malformed Loc") +            msg = "Malformed Loc" +            raise TierkreisError(msg)          return last, Loc.from_steps(steps)      def peek(self) -> NodeStep: +        """Get the last step without removing. + +        :return: The last step of the loc. +        :rtype: NodeStep +        """         return self.steps()[-1]      def peek_index(self) -> int: +        """Get the index of the last step. + +        :return: The index of the last node in the location. +        :rtype: int +        """         step = self.steps()[-1]          if isinstance(step, str):
+ :rtype: list[Loc] + """ steps = self.steps() return [Loc.from_steps(steps[: i + 1]) for i in range(len(steps))] diff --git a/tierkreis/tierkreis/controller/data/models.py b/tierkreis/tierkreis/controller/data/models.py index 50877a13a..d8b7783b6 100644 --- a/tierkreis/tierkreis/controller/data/models.py +++ b/tierkreis/tierkreis/controller/data/models.py @@ -1,3 +1,5 @@ +"""Models for type structures used in the graphbuilder.""" + from dataclasses import dataclass from inspect import isclass from itertools import chain @@ -13,7 +15,9 @@ overload, runtime_checkable, ) + from typing_extensions import TypeIs + from tierkreis.controller.data.core import ( NodeIndex, PortID, @@ -56,7 +60,8 @@ def value_ref(self) -> ValueRef: class TNamedModel(RestrictedNamedTuple[TKR[PType] | None], Protocol): """A struct whose members are restricted to being references to PTypes. - E.g. in graph builder code these are outputs of tasks.""" + E.g. in graph builder code these are outputs of tasks. + """ TModel = TNamedModel | TKR @@ -84,7 +89,7 @@ def is_portmapping( return hasattr(o, TKR_PORTMAPPING_FLAG) -def is_tnamedmodel(o) -> TypeIs[type[TNamedModel]]: +def is_tnamedmodel(o) -> TypeIs[type[TNamedModel]]: # noqa: ANN001 inherited from get_origin origin = get_origin(o) if origin is not None: return is_tnamedmodel(origin) @@ -115,10 +120,10 @@ def dict_from_tmodel(tmodel: TModel) -> dict[PortID, ValueRef]: def model_fields(model: type[PModel] | type[TModel]) -> list[str]: if is_portmapping(model): - return getattr(model, "_fields") + return model._fields if is_tnamedmodel(model): - return getattr(model, "_fields") + return model._fields return ["value"] @@ -134,7 +139,7 @@ def init_tmodel[T: TModel](tmodel: type[T], refs: list[ValueRef]) -> T: if get_origin(param) == Union: param = next(x for x in get_args(param) if x) args.append(param(ref[0], ref[1])) - return cast(T, model(*args)) + return cast("T", model(*args)) return tmodel(refs[0][0], refs[0][1]) diff --git 
a/tierkreis/tierkreis/controller/data/types.py b/tierkreis/tierkreis/controller/data/types.py index a5a170c40..7d3b8eff8 100644 --- a/tierkreis/tierkreis/controller/data/types.py +++ b/tierkreis/tierkreis/controller/data/types.py @@ -1,19 +1,21 @@ -from collections import defaultdict +"""Valid Python types for annotating worker functions and their serialisation.""" + +# ruff: noqa: ANN001 ANN003 ANN401 due to serialization and inheritance from json +import collections.abc +import json import logging +import pickle from base64 import b64decode, b64encode -import collections.abc +from collections import defaultdict +from collections.abc import Mapping, Sequence from inspect import Parameter, _empty, isclass from itertools import chain -import json -import pickle from types import NoneType, UnionType from typing import ( Annotated, Any, - Mapping, Protocol, Self, - Sequence, TypeVar, Union, assert_never, @@ -25,6 +27,8 @@ from pydantic import BaseModel, ValidationError from pydantic._internal._generics import get_args as pydantic_get_args +from typing_extensions import TypeIs + from tierkreis.controller.data.core import ( RestrictedNamedTuple, SerializationFormat, @@ -32,7 +36,6 @@ get_serializer, ) from tierkreis.exceptions import TierkreisError -from typing_extensions import TypeIs @runtime_checkable @@ -40,25 +43,50 @@ class NdarraySurrogate(Protocol): """A protocol to enable use of numpy.ndarray. By default the serialisation will be done using dumps - and the deserialisation using `pickle.loads`.""" + and the deserialisation using `pickle.loads`. + + The semantics are left to the implementor. + """ + + def dumps(self) -> bytes: + """Dump self to bytes.""" + ... - def dumps(self) -> bytes: ... - def tobytes(self) -> bytes: ... - def tolist(self) -> list: ... + def tobytes(self) -> bytes: + """Transform self to bytes.""" + ... + + def tolist(self) -> list: + """Convert self to a list.""" + ... 
@runtime_checkable class DictConvertible(Protocol): - def to_dict(self) -> dict: ... + """A protocol for types that can be converted to and from dicts.""" + + def to_dict(self) -> dict: + """Convert self to a dict.""" + ... + @classmethod - def from_dict(cls, arg: dict, /) -> "Self": ... + def from_dict(cls, arg: dict, /) -> "Self": + """Construct self from a dict.""" + ... @runtime_checkable class ListConvertible(Protocol): - def to_list(self) -> list: ... + """A protocol for types that can be converted to and from lists.""" + + def to_list(self) -> list: + """Convert self to a list.""" + ... + @classmethod - def from_list(cls, arg: list, /) -> "Self": ... + def from_list(cls, arg: list, /) -> "Self": + """Construct self from a list.""" + ... type Container[T] = ( @@ -87,7 +115,8 @@ def from_list(cls, arg: list, /) -> "Self": ... @runtime_checkable -class Struct(RestrictedNamedTuple[JsonType], Protocol): ... +class Struct(RestrictedNamedTuple[JsonType], Protocol): + """Supertype for structs, which are named tuples with JSON-serialisable fields.""" _StructPType = JsonType | Struct @@ -99,7 +128,14 @@ class Struct(RestrictedNamedTuple[JsonType], Protocol): ... class TierkreisEncoder(json.JSONEncoder): """Encode bytes also.""" - def default(self, o): + def default(self, o) -> dict[str, Any] | dict[str, list[float]] | Any: + """Call the default tierkreis serializer. + + :param o: The object to serialize. + :type o: _type_ + :return: The serialized object. 
+ :rtype: dict[str, Any] | dict[str, list[float]] | Any + """ if isinstance(o, bytes): return {"__tkr_bytes__": True, "bytes": b64encode(o).decode()} @@ -112,11 +148,11 @@ def default(self, o): class TierkreisDecoder(json.JSONDecoder): """Decode bytes also.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: kwargs.setdefault("object_hook", self._object_hook) super().__init__(**kwargs) - def _object_hook(self, d): + def _object_hook(self, d) -> bytes | complex | Any: """Try to decode an object containing bytes.""" if "__tkr_bytes__" in d and "bytes" in d: return b64decode(d["bytes"]) @@ -129,25 +165,18 @@ def _object_hook(self, d): def _is_union(o: object) -> bool: return ( - get_origin(o) == UnionType - or get_origin(o) == Union - or o == Union - or o == UnionType + o in (Union, UnionType) or get_origin(o) == UnionType or get_origin(o) == Union ) def is_optional(t: type) -> bool: - """Check that the origin of `t` is a Union - and the args of `t` contains NoneType.""" + """Check that the origin of `t` is a Union and the args of `t` contains NoneType.""" origin = get_origin(t) if origin is None: return False is_union = origin == Union or (isclass(origin) and issubclass(origin, UnionType)) - if is_union and NoneType in get_args(t): - return True - - return False + return bool(is_union and NoneType in get_args(t)) def _is_generic(o) -> TypeIs[type[TypeVar]]: @@ -167,6 +196,13 @@ def _is_tuple(o: object) -> TypeIs[type[tuple[Any, ...]]]: def is_ptype(annotation: Any) -> TypeIs[type[PType]]: + """Check if a type annotation is a PType. + + :param annotation: The annotation to check. + :type annotation: Any + :return: The according TypeIs if the annotation is a PType, otherwise False. 
+ :rtype: TypeIs[type[PType]] + """ if get_origin(annotation) is Annotated: return is_ptype(get_args(annotation)[0]) @@ -181,24 +217,34 @@ def is_ptype(annotation: Any) -> TypeIs[type[PType]]: ): return all(is_ptype(x) for x in get_args(annotation)) - elif isclass(annotation) and issubclass( - annotation, - (DictConvertible, ListConvertible, NdarraySurrogate, BaseModel, Struct), - ): - return True - - elif annotation in get_args(ElementaryType.__value__): + if ( + isclass(annotation) + and issubclass( + annotation, + (DictConvertible, ListConvertible, NdarraySurrogate, BaseModel, Struct), + ) + ) or annotation in get_args(ElementaryType.__value__): return True origin = get_origin(annotation) if origin is not None: return is_ptype(origin) and all(is_ptype(x) for x in get_args(annotation)) - else: - return False + return False + + +def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> JsonType: + """Get the json serializable type of a ptype value. + Potentially uses a custom serializer if the annotation has one. -def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> Any: + :param ptype: The type to serialize. + :type ptype: PType + :param annotation: The annotation of the ptype, if available. + :type annotation: type[PType] | None + :return: The serialized ptype. + :rtype: JsonType + """ if sr := get_serializer(annotation): return sr.serializer(ptype) @@ -232,6 +278,15 @@ def ser_from_ptype(ptype: PType, annotation: type[PType] | None) -> Any: def bytes_from_ptype(ptype: PType, annotation: type[PType] | None = None) -> bytes: + """Get the bytes representation of a ptype value. + + :param ptype: The ptype value to convert to bytes. + :type ptype: PType + :param annotation: The annotation of the ptype, if available. + :type annotation: type[PType] | None, optional + :return: The bytes representation of the ptype value. 
+ :rtype: bytes + """ ser = ser_from_ptype(ptype, annotation) match ser: case bytes(): @@ -241,6 +296,18 @@ def bytes_from_ptype(ptype: PType, annotation: type[PType] | None = None) -> byt def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: + """Find the value of type T from a serialized form. + + Uses the annotation to find the correct deserialization method, if available. + + :param ser: The value to coerce. + :type ser: Any + :param annotation: The annotation to coerce to, if available. + :type annotation: type[T] | None, optional + :raises TierkreisError: If the value cannot be coerced to the annotation. + :return: The coerced value. + :rtype: T + """ if annotation is None: return ser @@ -255,8 +322,9 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: try: return coerce_from_annotation(ser, t) except (AssertionError, ValidationError): - logger.debug(f"Tried deserialising as {t}") - raise TierkreisError(f"Could not deserialise {ser} as {annotation}") + logger.debug("Tried deserialising as %s", t) + msg = f"Could not deserialise {ser} as {annotation}" + raise TierkreisError(msg) origin = get_origin(annotation) if origin is None: @@ -276,18 +344,24 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: return ser if issubclass(origin, DictConvertible): - assert issubclass(annotation, origin) + if not issubclass(annotation, origin): + msg = "Invalid subclass relation encountered." + raise TypeError(msg) return annotation.from_dict(ser) if issubclass(origin, ListConvertible): - assert issubclass(annotation, origin) + if not issubclass(annotation, origin): + msg = "Invalid subclass relation encountered." 
+ raise TypeError(msg) return annotation.from_list(ser) if issubclass(origin, NdarraySurrogate): return pickle.loads(ser) if issubclass(origin, BaseModel): - assert issubclass(annotation, origin) + if not issubclass(annotation, origin): + msg = "Invalid subclass relation encountered." + raise TypeError(msg) return annotation(**ser) if issubclass(origin, Struct): @@ -295,21 +369,24 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: k: coerce_from_annotation(ser[k], v) for k, v in origin.__annotations__.items() } - return cast(T, origin(**d)) + return cast("T", origin(**d)) if issubclass(origin, collections.abc.Sequence): args = get_args(annotation) if len(args) == 0: return ser - return cast(T, [coerce_from_annotation(x, args[0]) for x in ser]) + return cast("T", [coerce_from_annotation(x, args[0]) for x in ser]) if issubclass(origin, collections.abc.Mapping): args = get_args(annotation) if len(args) == 0: return ser - return cast(T, {k: coerce_from_annotation(v, args[1]) for k, v in ser.items()}) + return cast( + "T", + {k: coerce_from_annotation(v, args[1]) for k, v in ser.items()}, + ) assert_never(ser) @@ -317,6 +394,15 @@ def coerce_from_annotation[T: PType](ser: Any, annotation: type[T] | None) -> T: def get_serialization_format[T: PType]( hint: type[T] | None = None, ) -> SerializationFormat: + """Find the serializaiton format to a type hint. + + Returns 'unknown' for None. + + :param hint: The type hint to find the serialization format for. + :type hint: type[T] | None, optional + :return: The serialization format for the given type hint. + :rtype: SerializationFormat + """ if hint is None: return "unknown" @@ -331,6 +417,15 @@ def get_serialization_format[T: PType]( def ptype_from_bytes[T: PType](bs: bytes, annotation: type[T] | None = None) -> T: + """Get the value with the correct type from its bytes. + + :param bs: The bytes to deserialize. 
+ :type bs: bytes + :param annotation: The annotation to use for deserialization, if available. + :type annotation: type[T] | None, optional + :return: The deserialized value of type T. + :rtype: T + """ method = get_serialization_format(annotation) match method: case "bytes": @@ -343,12 +438,19 @@ def ptype_from_bytes[T: PType](bs: bytes, annotation: type[T] | None = None) -> j = json.loads(bs, cls=TierkreisDecoder) return coerce_from_annotation(j, annotation) except (json.JSONDecodeError, UnicodeDecodeError): - return cast(T, bs) + return cast("T", bs) case _: assert_never(method) def generics_in_ptype(ptype: type[PType]) -> set[str]: + """Get the generics in a type annotation. + + :param ptype: The ptype to extract generics from. + :type ptype: type[PType] + :return: The set of generic names in the ptype. + :rtype: set[str] + """ if _is_generic(ptype): return {str(ptype)} @@ -366,10 +468,17 @@ def generics_in_ptype(ptype: type[PType]) -> set[str]: return set() if issubclass(ptype, BaseModel): - return set((str(x) for x in pydantic_get_args(ptype))) + return {str(x) for x in pydantic_get_args(ptype)} assert_never(ptype) def has_default(t: Parameter) -> bool: + """Check if a parameter has a default value. + + :param t: The parameter to check. + :type t: Parameter + :return: True if the parameter has a default value, False otherwise. 
+ :rtype: bool + """ return not (isclass(t.default) and issubclass(t.default, _empty)) diff --git a/tierkreis/tierkreis/controller/executor/__init__.py b/tierkreis/tierkreis/controller/executor/__init__.py index e69de29bb..a8511eb76 100644 --- a/tierkreis/tierkreis/controller/executor/__init__.py +++ b/tierkreis/tierkreis/controller/executor/__init__.py @@ -0,0 +1 @@ +"""Tierkreis executors to launch worker tasks.""" diff --git a/tierkreis/tierkreis/controller/executor/check_launcher.py b/tierkreis/tierkreis/controller/executor/check_launcher.py index 5e35610de..89b2af503 100644 --- a/tierkreis/tierkreis/controller/executor/check_launcher.py +++ b/tierkreis/tierkreis/controller/executor/check_launcher.py @@ -1,3 +1,5 @@ +"""Utilities to find the correct executable for a worker.""" + import logging from pathlib import Path from typing import Literal @@ -8,42 +10,70 @@ def check_and_set_launcher( - launcher_path: Path, launcher_name: str, suffix: Literal[".sh", ".py"] + launcher_path: Path, + launcher_name: str, + suffix: Literal[".sh", ".py"], ) -> Path: + """Find the correct executable for a worker. + + Given the directory and a worker name searches for + 1. main.py (.sh) + 2. src/main.py (.sh) + + :param launcher_path: The directory to search. + :type launcher_path: Path + :param launcher_name: The name of the worker to find. + :type launcher_name: str + :param suffix: External or internal worker (.py or .sh). + :type suffix: Literal['.sh', '.py'] + :raises TierkreisError: If neither of the expected paths exist. + :return: The full path to the worker executable. + :rtype: Path + """ try: path = _exists(launcher_path, launcher_name, suffix) - logger.warning( - "Placing the launcher in the root directory is deprecated.\n Please move it to a 'src' subdirectory." 
- ) - return path except TierkreisError as e: try: return _exists(launcher_path, launcher_name, suffix, add_src=True) except TierkreisError as ef: + msg = ( + f"Launcher '{launcher_name}' not found in" + f" '{launcher_path}' or '{launcher_path}/src'." + ) raise ExceptionGroup( - f"Launcher '{launcher_name}' not found in '{launcher_path}' or '{launcher_path}/src'.", + msg, [e, ef], ) from ef + else: + logger.warning( + "Placing the launcher in the root directory is deprecated.\n" + "Please move it to a 'src' subdirectory.", + ) + return path def _exists( launcher_path: Path, launcher_name: str, suffix: Literal[".sh", ".py"], + *, add_src: bool = False, ) -> Path: launcher_path = launcher_path / launcher_name if add_src: launcher_path = launcher_path / "src" if not launcher_path.exists(): - raise TierkreisError(f"Launcher not found: {launcher_name}.") + msg = f"Launcher not found: {launcher_name}." + raise TierkreisError(msg) if launcher_path.is_dir() and not (launcher_path / f"main{suffix}").exists(): - raise TierkreisError(f"Expected launcher file. Got {launcher_path}.") + msg = f"Expected launcher file. Got {launcher_path}." + raise TierkreisError(msg) if launcher_path.is_dir() and not (launcher_path / f"main{suffix}").is_file(): + msg = f"Expected launcher file. Got {launcher_path}/main{suffix}" raise TierkreisError( - f"Expected launcher file. 
Got {launcher_path}/main{suffix}" + msg, ) if launcher_path.is_dir() and (launcher_path / f"main{suffix}").is_file(): launcher_path = launcher_path / f"main{suffix}" diff --git a/tierkreis/tierkreis/controller/executor/commands.py b/tierkreis/tierkreis/controller/executor/commands.py index 11b9fab09..707600979 100644 --- a/tierkreis/tierkreis/controller/executor/commands.py +++ b/tierkreis/tierkreis/controller/executor/commands.py @@ -1,11 +1,14 @@ +"""Utilities to generagete command strings for executors.""" + from pathlib import Path def add_std_handlers(workflow_logs: Path, node_logs: Path, command: str) -> str: """Pipe stdout and stderr to `workflow_logs` and `node_logs`. - If the `command` returns with a non-zero exit code then touch the appropriate _error file. - """ + If the `command` returns with a non-zero exit code, + then touch the appropriate _error file. + """ _error_path = node_logs.parent / "_error" - tee_str = f">(tee -a {str(node_logs)} {str(workflow_logs)} >/dev/null)" + tee_str = f">(tee -a {node_logs!s} {workflow_logs!s} >/dev/null)" return f"({command} > {tee_str} 2> {tee_str} || touch {_error_path})" diff --git a/tierkreis/tierkreis/controller/executor/hpc/__init__.py b/tierkreis/tierkreis/controller/executor/hpc/__init__.py index e69de29bb..79c77e347 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/__init__.py +++ b/tierkreis/tierkreis/controller/executor/hpc/__init__.py @@ -0,0 +1 @@ +"""Collection of HPC executors.""" diff --git a/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py b/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py index f99167e40..f400a32af 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py +++ b/tierkreis/tierkreis/controller/executor/hpc/hpc_executor.py @@ -1,19 +1,35 @@ +"""Interface implementation for HPCExecutors.""" + import logging import subprocess +from collections.abc import Callable from pathlib import Path from tempfile import NamedTemporaryFile -from typing 
import Callable, Protocol +from typing import Protocol from tierkreis.consts import TKR_DIR_KEY from tierkreis.controller.executor.commands import add_std_handlers from tierkreis.controller.executor.hpc.job_spec import JobSpec from tierkreis.exceptions import TierkreisError - logger = logging.getLogger(__name__) class HPCExecutor(Protocol): + """Generic protocol for an HPC executor. + + :fields: + launchers_path (Path | None): The locations to search for workers. + This will change the location from where the command is invoked + by appending "cd launchers_path && " + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + spec (JobSpec): A definition of the job specification. + script_fn (Callable[[JobSpec], str]): A template function to generate the + submission script from. + command (str): The base command to use. + """ + launchers_path: Path | None logs_path: Path errors_path: Path @@ -23,15 +39,42 @@ class HPCExecutor(Protocol): def generate_script( - template_fn: Callable[[JobSpec], str], spec: JobSpec, path: Path + template_fn: Callable[[JobSpec], str], + spec: JobSpec, + path: Path, ) -> None: - with open(path, "w+", encoding="utf-8") as fh: + """Generate a scheduler script by calling a template function. + + :param template_fn: The template function to call. + :type template_fn: Callable[[JobSpec], str] + :param spec: The job definition to generate the script for. + :type spec: JobSpec + :param path: The path to save the script to. + :type path: Path + """ + with Path.open(path, "w+", encoding="utf-8") as fh: fh.write(template_fn(spec)) def run_hpc_executor( - executor: HPCExecutor, launcher_name: str, worker_call_args_path: Path + executor: HPCExecutor, + launcher_name: str, + worker_call_args_path: Path, ) -> None: + """Run a worker function on an HPC executor. + + This is a generic function to run with with an HPC executor. 
+ Similar to the :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + run function. + + :param executor: The executor to use for running + :type executor: HPCExecutor + :param launcher_name: Module description for the worker to run + :type launcher_name: str + :param worker_call_args_path: Location of the worker call args. + :type worker_call_args_path: Path + :raises TierkreisError: When job submission fails. + """ logger.info("START %s %s", launcher_name, worker_call_args_path) spec = executor.spec.model_copy() @@ -40,7 +83,9 @@ def run_hpc_executor( spec.command += " " + str(worker_call_args_path) spec.command = add_std_handlers( - executor.logs_path, executor.errors_path, spec.command + executor.logs_path, + executor.errors_path, + spec.command, ) submission_cmd = [executor.command] @@ -63,21 +108,20 @@ def run_hpc_executor( submission_cmd, start_new_session=True, capture_output=True, - universal_newlines=True, + text=True, + check=False, ) - with open(executor.logs_path, "a+") as fh: + with Path.open(executor.logs_path, "a+") as fh: fh.write(process.stdout) - with open(executor.errors_path, "a+") as fh: + with Path.open(executor.errors_path, "a+") as fh: fh.write(process.stdout) if process.returncode != 0: - with open(executor.errors_path, "a") as efh: + with Path.open(executor.errors_path, "a") as efh: efh.write("Error from script") efh.write(process.stderr) - print(process.stderr) - print("\n\npjsub script\n\n") - print(executor.script_fn(spec)) - raise TierkreisError(f"Executor failed with return code {process.returncode}") + msg = f"Executor failed with return code {process.returncode}" + raise TierkreisError(msg) diff --git a/tierkreis/tierkreis/controller/executor/hpc/job_spec.py b/tierkreis/tierkreis/controller/executor/hpc/job_spec.py index fd4335b46..5073a3ce9 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/job_spec.py +++ b/tierkreis/tierkreis/controller/executor/hpc/job_spec.py @@ -1,14 +1,41 @@ -from pathlib import Path 
+"""Definition of HPC resource classes. + +These are used to map resource requirements to the respective +settings in resource management systems (schedulers). +Using the value `None` typically will unset the flag. +It is not guaranteed that all schedulers can realize all the configurations. +""" + import platform +from pathlib import Path + from pydantic import BaseModel, Field class MpiSpec(BaseModel): + """MPI configuration. + + :fields: + max_proc_per_node (int | None): Number of MPI processes per compute node, + defaults to 1. + proc (int | None): Number of MPI processes (ranks), defaults to None (unset). + + """ + max_proc_per_node: int | None = 1 proc: int | None = None class ResourceSpec(BaseModel): + """General resource definitions. + + :fields: + nodes (int): Number of compute nodes, defaults to 1. + cores_per_node (int | None): Number of cores to ues per node, defaults to 1. + memory_gb (int | None): Memory per node in GB, defaults to 4. + gpus_per_nod (int | None): Physical GPUs to reserve on the node, defaults to 0. + """ + nodes: int = 1 cores_per_node: int | None = 1 memory_gb: int | None = 4 @@ -16,10 +43,31 @@ class ResourceSpec(BaseModel): class UserSpec(BaseModel): + """User specific configuration. + + :fields: + mail (str | None): User email to send job updates. + """ + mail: str | None = None # some clusters require this class ContainerSpec(BaseModel): + """Configuration for the use of container images in HPC. + + :warning: + Not fully supported yet. + + :fields: + images (str): URL to the container image. + engine (str): which engine to use. + name (str | None): Explicit image name, defaults to None. + extra_args (dict[str, str | None]): Environment args to pass to the container, + defaults to {}. + env_file (str | None): Path to a file with variable export definitions, + defaults to None. + """ + image: str engine: str # e.g. singularity, docker, enroot? 
name: str | None = None @@ -28,11 +76,44 @@ class ContainerSpec(BaseModel): class JobSpec(BaseModel): + """Resource definition for an HPC job. + + This is used to generate the job script for the scheduler. + + :fields: + job_name (str): Reference name for the job. + command (str): The command to execute on hpc. + E.g. "mpi run ..." + resource (ResourceSpec): Resource specification for the job. + account: (str | None): Account or group used to submit this job, + defaults to None. + mpi: (MpiSpec | None): The MPI specification. If this is set, will prepend + "mpirun" to the command string, defaults to None. + container: (ContainerSpec | None): The container specification for the job, + defaults to None. + walltime (str): Maximum walltime of the job in HH:MM:SS format, + defaults to "01:00:00". + queue: (str | None): Named queue to submit to, HPC center specific, + defaults to None. + output_path: (Path | None): Explicit job output, if not tkr output will be used, + defaults to None. + error_path: (Path | None): Explicit error output, if not tkr output will be used + defaults to None. + extra_scheduler_args (dict[str, str | None]): Configure additional flags and + options that are not provided in the spec. Flags are set as + extra_scheduler_args["flag_name"] = None, options set as + extra_scheduler_args["option_name"] = "option_value". + Defaults to {} + environment: (dict[str, str]): Provide additional environment variables to the + job, defaults to {}. + include_no_check_directory_flag: (bool): Set "--no-check-directory", + defaults to false. + """ + job_name: str command: str # used instead of popen.input resource: ResourceSpec account: str | None = None - """Account or group used to submit this job.""" mpi: MpiSpec | None = None user: UserSpec | None = None container: ContainerSpec | None = None @@ -46,12 +127,17 @@ class JobSpec(BaseModel): def pjsub_large_spec() -> JobSpec: + """Generate an example large job specification for FUGAKU. 
+ + :return: A job spec running uv on FUGAKU. + :rtype: JobSpec + """ arch = platform.machine() uv_path = Path.home() / ".local" / f"bin_{arch}" / "uv" return JobSpec( job_name="pjsub_large", account="hp240496", - command=f"{str(uv_path)} run main.py", + command=f"{uv_path!s} run main.py", queue="q-QTM-M", resource=ResourceSpec(nodes=32), environment={ @@ -69,12 +155,17 @@ def pjsub_large_spec() -> JobSpec: def pjsub_small_spec() -> JobSpec: + """Generate an example small job specification for FUGAKU. + + :return: A job spec running uv on FUGAKU. + :rtype: JobSpec + """ arch = platform.machine() uv_path = Path.home() / ".local" / f"bin_{arch}" / "uv" return JobSpec( job_name="pjsub_small", account="hp240496", - command=f"{str(uv_path)} run main.py", + command=f"{uv_path!s} run main.py", resource=ResourceSpec(nodes=1), environment={ "VIRTUAL_ENVIRONMENT": "", diff --git a/tierkreis/tierkreis/controller/executor/hpc/pbs.py b/tierkreis/tierkreis/controller/executor/hpc/pbs.py index bf77b5728..1869929ef 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/pbs.py +++ b/tierkreis/tierkreis/controller/executor/hpc/pbs.py @@ -1,21 +1,33 @@ +"""Template and Executor for PBS.""" + +# ruff: noqa: ERA001 from pathlib import Path # from typing import Callable # from tierkreis.controller.executor.hpc.hpc_executor import run_hpc_executor from tierkreis.controller.executor.hpc.job_spec import JobSpec - _COMMAND_PREFIX = "#PBS" -def generate_pbs_script(spec: JobSpec) -> str: +def generate_pbs_script(spec: JobSpec) -> str: # noqa: C901, PLR0912 complexity to cover options + """Generate a job submission script according to PBS. + + This uses the "PBS"/qsub syntax and represents a mapping from JobSpec + to the native flags. + + :param spec: The job to generate a script for. + :type spec: JobSpec + :return: A job script for the PBS scheduler. + :rtype: str + """ # 1. 
Shebang and file header lines = [ """#!/bin/bash # # PBS Job Script generated by TIERKREIS -# --- Core Job Specifications ---""" +# --- Core Job Specifications ---""", ] # 2. Name lines.append(f"{_COMMAND_PREFIX} -N {spec.job_name}") @@ -40,10 +52,9 @@ def generate_pbs_script(spec: JobSpec) -> str: lines.append("\n# --- User Details ---") if spec.account is not None: lines.append(f"{_COMMAND_PREFIX} -A {spec.account}") - if spec.user is not None: - if spec.user.mail is not None: - lines.append(f"{_COMMAND_PREFIX} -m e") # end only - lines.append(f"{_COMMAND_PREFIX} -M {spec.user.mail}") + if spec.user is not None and spec.user.mail is not None: + lines.append(f"{_COMMAND_PREFIX} -m e") # end only + lines.append(f"{_COMMAND_PREFIX} -M {spec.user.mail}") # 5. Output and Error handling lines.append("\n# --- Output and Error Handling ---") @@ -69,8 +80,7 @@ def generate_pbs_script(spec: JobSpec) -> str: lines.append("\n# --- Environment Setup ---") if spec.environment != {}: env = ",".join( - f"{key}={value if value else '""'}" - for key, value in spec.environment.items() + f"{key}={value or '""'}" for key, value in spec.environment.items() ) lines.append(f"-v w{env}") # 9. Container logic # taken from nscc docs for enroot: @@ -82,7 +92,8 @@ def generate_pbs_script(spec: JobSpec) -> str: lines.append(f"{_COMMAND_PREFIX} -l {key}={value}") if spec.container.env_file is not None: lines.append( - f"{_COMMAND_PREFIX} -l {spec.container.engine}_env_file={spec.container.env_file}" + f"{_COMMAND_PREFIX} -l {spec.container.engine}_env_file" + f"={spec.container.env_file}", ) # check if this makes sense for others beside enroot # 10. 
User Command, (prologue), command, (epilogue) @@ -93,7 +104,7 @@ def generate_pbs_script(spec: JobSpec) -> str: # Disabled for now, needs testing with a PBS system, will be re-enabled later -# See: Issue #182 +# See: TODO@philipp-seitz: Issue #182 # class PBSExecutor: # def __init__( # self, diff --git a/tierkreis/tierkreis/controller/executor/hpc/pjsub.py b/tierkreis/tierkreis/controller/executor/hpc/pjsub.py index 96ee92988..c38406410 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/pjsub.py +++ b/tierkreis/tierkreis/controller/executor/hpc/pjsub.py @@ -1,7 +1,8 @@ -# from functools import partial +"""Template and Executor for PJSUB(FUGAKU).""" + from functools import partial from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING from uuid import uuid4 from tierkreis.controller.executor.hpc.hpc_executor import run_hpc_executor @@ -11,18 +12,30 @@ pjsub_small_spec, ) +if TYPE_CHECKING: + from collections.abc import Callable _COMMAND_PREFIX = "#PJM" -def generate_pjsub_script(spec: JobSpec) -> str: +def generate_pjsub_script(spec: JobSpec) -> str: # noqa: C901 complexity to cover options + """Generate a job submission script according to PJSUB. + + This uses the "PJM"/pjsub syntax and represents a mapping from JobSpec + to the native flags. + + :param spec: The job to generate a script for. + :type spec: JobSpec + :return: A job script for the PJSUB scheduler. + :rtype: str + """ # 1. Shebang and file header lines = [ """#!/bin/bash # # PJSUB Job Script generated by TIERKREIS -# --- Core Job Specifications ---""" +# --- Core Job Specifications ---""", ] # 2. 
Name lines.append(f"{_COMMAND_PREFIX} -N {spec.job_name}") @@ -40,12 +53,12 @@ def generate_pjsub_script(spec: JobSpec) -> str: lines.append("\n# --- User Details ---") if spec.account is not None: lines.append(f"{_COMMAND_PREFIX} -g {spec.account}") - if spec.user is not None: - if spec.user.mail is not None: - lines.append(f"{_COMMAND_PREFIX} -m e") # end only - lines.append(f"{_COMMAND_PREFIX} --mail-list {spec.user.mail}") + if spec.user is not None and spec.user.mail is not None: + lines.append(f"{_COMMAND_PREFIX} -m e") # end only + lines.append(f"{_COMMAND_PREFIX} --mail-list {spec.user.mail}") - # 5. Output and Error handling uses Bash because pjsub always overwrites instead of appends. + # 5. Output and Error handling + # uses bash because pjsub always overwrites instead of appends. # So redirect to temporary files. lines.append("\n# --- Output and Error Handling ---") lines.append(f"{_COMMAND_PREFIX} -j") @@ -58,7 +71,8 @@ def generate_pjsub_script(spec: JobSpec) -> str: lines.append(f'{_COMMAND_PREFIX} --mpi "proc={spec.mpi.proc}"') if spec.mpi.max_proc_per_node is not None: lines.append( - f'{_COMMAND_PREFIX} --mpi "max-proc-per-node={spec.mpi.max_proc_per_node}"' + f'{_COMMAND_PREFIX} --mpi "max-proc-per-node' + f'={spec.mpi.max_proc_per_node}"', ) # 7. User specific @@ -77,13 +91,19 @@ def generate_pjsub_script(spec: JobSpec) -> str: lines.append("\n# --- User Command ---") lines.append(spec.command) - with open("./script", "w+") as fh: + with Path.open(Path("./script"), "w+") as fh: fh.write("\n".join(lines)) return "\n".join(lines) class PJSUBExecutor: + """An executor for the PJSUB submission system. 
+ + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + Implements: :py:class:`tierkreis.controller.executor.hpc.hpc_executor.HPCExecutor` + """ + def __init__( self, registry_path: Path | None, @@ -103,6 +123,13 @@ def run( launcher_name: str, worker_call_args_path: Path, ) -> None: + """Run the node according to ControllerExecutor protocol. + + :param launcher_name: module description of worker to run. + :type launcher_name: str + :param worker_call_args_path: Location of the worker call args. + :type worker_call_args_path: Path + """ self.errors_path = ( self.logs_path.parent.parent / worker_call_args_path.parent / "errors" ) diff --git a/tierkreis/tierkreis/controller/executor/hpc/slurm.py b/tierkreis/tierkreis/controller/executor/hpc/slurm.py index a6d7f5d28..7cd002505 100644 --- a/tierkreis/tierkreis/controller/executor/hpc/slurm.py +++ b/tierkreis/tierkreis/controller/executor/hpc/slurm.py @@ -1,20 +1,35 @@ +"""Template and executor for SLURM.""" + from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING + from tierkreis.controller.executor.hpc.hpc_executor import run_hpc_executor from tierkreis.controller.executor.hpc.job_spec import JobSpec +if TYPE_CHECKING: + from collections.abc import Callable _COMMAND_PREFIX = "#SBATCH" -def generate_slurm_script(spec: JobSpec) -> str: +def generate_slurm_script(spec: JobSpec) -> str: # noqa: C901, PLR0912 complexity to cover options + """Generate a job submission script according to SLURM. + + This uses the "sbatch" syntax and represents a mapping from JobSpec + to the native flags. + + :param spec: The job to generate a script for. + :type spec: JobSpec + :return: A job script for the SLURM scheduler. + :rtype: str + """ # 1. Shebang and file header lines = [ """#!/bin/bash # # SLURM Job Script generated by TIERKREIS -# --- Core Job Specifications ---""" +# --- Core Job Specifications ---""", ] # 2. 
Name lines.append(f"{_COMMAND_PREFIX} --job-name={spec.job_name}") @@ -23,7 +38,7 @@ def generate_slurm_script(spec: JobSpec) -> str: lines.append(f"{_COMMAND_PREFIX} --nodes={spec.resource.nodes}") if spec.resource.cores_per_node is not None: lines.append( - f"{_COMMAND_PREFIX} --cpus-per-task={spec.resource.cores_per_node}" + f"{_COMMAND_PREFIX} --cpus-per-task={spec.resource.cores_per_node}", ) if spec.resource.memory_gb is not None: lines.append(f"{_COMMAND_PREFIX} --mem={spec.resource.memory_gb}G") @@ -36,10 +51,9 @@ def generate_slurm_script(spec: JobSpec) -> str: lines.append("\n# --- User Details ---") if spec.account is not None: lines.append(f"{_COMMAND_PREFIX} --account={spec.account}") - if spec.user is not None: - if spec.user.mail is not None: - lines.append(f"{_COMMAND_PREFIX} --mail-type=END") # end only - lines.append(f"{_COMMAND_PREFIX} --mail-user={spec.user.mail}") + if spec.user is not None and spec.user.mail is not None: + lines.append(f"{_COMMAND_PREFIX} --mail-type=END") # end only + lines.append(f"{_COMMAND_PREFIX} --mail-user={spec.user.mail}") # 5. Output and Error handling lines.append("\n# --- Output and Error Handling ---") @@ -48,14 +62,14 @@ def generate_slurm_script(spec: JobSpec) -> str: if spec.output_path is not None: lines.append(f"{_COMMAND_PREFIX} --output={spec.output_path}") - # 6. MPI, #TODO check if this makes sense + # 6. MPI, #TODO@philipp-seitz: check if this makes sense if spec.mpi is not None: lines.append("\n# --- MPI ---") if spec.mpi.proc is not None: lines.append(f"{_COMMAND_PREFIX} --ntasks={spec.mpi.proc}") if spec.mpi.max_proc_per_node is not None: lines.append( - f"{_COMMAND_PREFIX} --ntasks-per-node={spec.mpi.max_proc_per_node}" + f"{_COMMAND_PREFIX} --ntasks-per-node={spec.mpi.max_proc_per_node}", ) # 7. 
User specific @@ -67,8 +81,7 @@ def generate_slurm_script(spec: JobSpec) -> str: lines.append("\n# --- Environment Setup ---") if spec.environment != {}: env = ",".join( - f"{key}={value if value else '""'}" - for key, value in spec.environment.items() + f"{key}={value or '""'}" for key, value in spec.environment.items() ) lines.append(f"--export={env}") # 9. Container logic @@ -79,7 +92,8 @@ def generate_slurm_script(spec: JobSpec) -> str: if spec.mpi.max_proc_per_node is None: spec.mpi.max_proc_per_node = 1 lines.append( - f"mpirun -n {spec.resource.nodes * spec.mpi.max_proc_per_node} {spec.command}" + f"mpirun -n {spec.resource.nodes * spec.mpi.max_proc_per_node}" + f" {spec.command}", ) else: lines.append(spec.command) @@ -88,6 +102,12 @@ def generate_slurm_script(spec: JobSpec) -> str: class SLURMExecutor: + """An executor for the SLURM submission system. + + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + Implements: :py:class:`tierkreis.controller.executor.hpc.hpc_executor.HPCExecutor` + """ + def __init__( self, registry_path: Path | None, @@ -107,6 +127,13 @@ def run( launcher_name: str, worker_call_args_path: Path, ) -> None: + """Run the node according to ControllerExecutor protocol. + + :param launcher_name: module description of worker to run. + :type launcher_name: str + :param worker_call_args_path: Location of the worker call args. 
+    :type worker_call_args_path: Path
+    """
         self.errors_path = (
             self.logs_path.parent.parent / worker_call_args_path.parent / "errors"
         )
diff --git a/tierkreis/tierkreis/controller/executor/in_memory_executor.py b/tierkreis/tierkreis/controller/executor/in_memory_executor.py
index a071a85dd..c6170e1aa 100644
--- a/tierkreis/tierkreis/controller/executor/in_memory_executor.py
+++ b/tierkreis/tierkreis/controller/executor/in_memory_executor.py
@@ -1,22 +1,31 @@
+"""In memory implementation."""
+
+# ruff: noqa: D102 (class methods inherited from ControllerExecutor)
+import importlib.util
 import json
 import logging
-import importlib.util
 from pathlib import Path
 
 from tierkreis.controller.data.location import WorkerCallArgs
 from tierkreis.controller.executor.check_launcher import check_and_set_launcher
 from tierkreis.controller.storage.in_memory import ControllerInMemoryStorage
-from tierkreis.worker.storage.in_memory import InMemoryWorkerStorage
 from tierkreis.exceptions import TierkreisError
-
+from tierkreis.worker.storage.in_memory import InMemoryWorkerStorage
 
 logger = logging.getLogger(__name__)
 
 
 class InMemoryExecutor:
-    """Executes workers in the same process as the controller.
+    """Execute workers in the same process as the controller.
 
+    Loads the worker as python module if possible.
+    Can only run python workers in conjunction with ControllerInMemoryStorage.
     Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor`
+
+    :fields:
+        registry_path (Path): The locations to search for worker modules.
+        storage (ControllerInMemoryStorage):
+            Storage reference to access in memory values.
""" def __init__(self, registry_path: Path, storage: ControllerInMemoryStorage) -> None: @@ -30,14 +39,17 @@ def run( ) -> None: logger.info("START %s %s", launcher_name, worker_call_args_path) call_args = WorkerCallArgs( - **json.loads(self.storage.read(worker_call_args_path)) + **json.loads(self.storage.read(worker_call_args_path)), ) launcher_path = check_and_set_launcher(self.registry_path, launcher_name, ".py") spec = importlib.util.spec_from_file_location("in_memory", launcher_path) if spec is None or spec.loader is None: - raise TierkreisError( + msg = ( f"Couldn't load module main.py in {self.registry_path / launcher_name}" ) + raise TierkreisError( + msg, + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) worker_storage = InMemoryWorkerStorage(self.storage) diff --git a/tierkreis/tierkreis/controller/executor/multiple.py b/tierkreis/tierkreis/controller/executor/multiple.py index 4dbd7c6ef..f686ea16e 100644 --- a/tierkreis/tierkreis/controller/executor/multiple.py +++ b/tierkreis/tierkreis/controller/executor/multiple.py @@ -1,3 +1,6 @@ +"""A meta executor consisting of multiple single executors.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) from pathlib import Path from tierkreis.controller.executor.protocol import ControllerExecutor @@ -7,7 +10,15 @@ class MultipleExecutor: """Composes multiple executors into a single object. + Will execute all worker tasks on the assigned executor or default. Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + default (ControllerExecutor): The default executor to use for all unspecified + tasks. + executors (dict[str, ControllerExecutor]): A mapping of name -> executor. 
+        assignments (dict[str, str]): A mapping of worker to executor name.
+
     """
 
     def __init__(
@@ -24,7 +35,6 @@ def run(
         self,
         launcher_name: str,
         worker_call_args_path: Path,
-        enable_logging: bool = True,
     ) -> None:
         executor_name = self.assignments.get(launcher_name, None)
         # If there is no assignment for the worker, use the default.
@@ -32,8 +42,12 @@ def run(
             return self.default.run(launcher_name, worker_call_args_path)
         executor = self.executors.get(executor_name)
         if executor is None:
+            msg = (
+                f"{launcher_name} is assigned to non-existent"
+                f" executor name: {executor_name}."
+            )
             raise TierkreisError(
-                f"{launcher_name} is assigned to non-existent executor name: {executor_name}."
+                msg,
             )
         return executor.run(launcher_name, worker_call_args_path)
 
diff --git a/tierkreis/tierkreis/controller/executor/protocol.py b/tierkreis/tierkreis/controller/executor/protocol.py
index 7d3a11cf4..872eccadb 100644
--- a/tierkreis/tierkreis/controller/executor/protocol.py
+++ b/tierkreis/tierkreis/controller/executor/protocol.py
@@ -1,3 +1,5 @@
+"""The base executor protocol."""
+
 from pathlib import Path
 from typing import Protocol
 
@@ -14,16 +16,20 @@ def run(
         launcher_name: str,
         worker_call_args_path: Path,
     ) -> None:
-        """Run the node defined by the node_definition path.
+        """Run the node defined by the worker_call_args_path path.
 
         Specifies the worker to run by its launcher name.
-        For example the function "builtins.iadd" will call the builtins worker's iadd function.
-        The call arguments for the function call are retrieved retrieved from its location.
+        For example the function "builtins.iadd" will call the builtins worker's
+        iadd function. The call arguments for the function call are retrieved
+        from its location.
+
+        The executor ensures workers are progressed correctly; this includes:
+        - setting up error and log files and making them available
+        - checking progress (e.g. _done file)
+        - enabling path resolution between tkr paths and worker inputs
 
-        :param launcher_name: module description of launcher to run.
+        :param launcher_name: module description of worker to run.
         :type launcher_name: str
         :param worker_call_args_path: Location of the worker call args.
         :type worker_call_args_path: Path
         """
-
-        ...
diff --git a/tierkreis/tierkreis/controller/executor/registries.py b/tierkreis/tierkreis/controller/executor/registries.py
index 45efdb9b1..fa7607fe0 100644
--- a/tierkreis/tierkreis/controller/executor/registries.py
+++ b/tierkreis/tierkreis/controller/executor/registries.py
@@ -1,15 +1,26 @@
+"""Utility for executors to resolve worker paths."""
+
 from pathlib import Path
 
 from tierkreis.exceptions import TierkreisError
 
 
 def find_registry_for_worker(
-    worker_name: str, registry_paths: Path | list[Path]
+    worker_name: str,
+    registry_paths: Path | list[Path],
 ) -> Path:
     """Return the first registry path containing a worker named `worker_name`.
 
-    Assumes that the worker is a directory."""
+    Assumes that the worker is a directory.
 
+    :param worker_name: The worker to search for.
+    :type worker_name: str
+    :param registry_paths: List of all possible paths.
+    :type registry_paths: Path | list[Path]
+    :raises TierkreisError: If the worker is not found.
+ :return: First match to the worker + :rtype: Path + """ if isinstance(registry_paths, Path): registry_paths = [registry_paths] @@ -19,4 +30,6 @@ def find_registry_for_worker( if worker_name in subdirs: return registry if worker_path is None: - raise TierkreisError(f"{worker_name} not in registries {registry_paths}") + msg = f"{worker_name} not in registries {registry_paths}" + raise TierkreisError(msg) + return None diff --git a/tierkreis/tierkreis/controller/executor/shell_executor.py b/tierkreis/tierkreis/controller/executor/shell_executor.py index 502ebd924..a2221ba26 100644 --- a/tierkreis/tierkreis/controller/executor/shell_executor.py +++ b/tierkreis/tierkreis/controller/executor/shell_executor.py @@ -1,3 +1,6 @@ +"""Default executor for arbitrary scripts.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import json import os import subprocess @@ -11,7 +14,19 @@ class ShellExecutor: """Executes workers in an unix shell. + Simply runs any shell script as a worker, if certain conditions on input/output + conditions are met, namely the paths/values are provided through the process + environment and the script is responsible for reading/writing them. + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + launchers_path (Path): The locations to search for external workers. + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + workflow_dir (Path): The workflow dir to resolve relative paths. + timeout (int): Timeout for the process communication, defaults to 10 seconds. + env: (dict[str,str]): Additional environments to hand to the spawned subprocess. 
""" def __init__( @@ -20,6 +35,7 @@ def __init__( workflow_dir: Path, timeout: int = 10, env: dict[str, str] | None = None, + *, export_values: bool = False, ) -> None: self.launchers_path = registry_path @@ -35,41 +51,56 @@ def run( launcher_name: str, worker_call_args_path: Path, ) -> None: - launcher_path = self.launchers_path / launcher_name - + self.errors_path = worker_call_args_path.parent / "logs" launcher_path = check_and_set_launcher( - self.launchers_path, launcher_name, ".sh" + self.launchers_path, + launcher_name, + ".sh", ) - with open(self.workflow_dir.parent / worker_call_args_path) as fh: + with Path.open(self.workflow_dir.parent / worker_call_args_path) as fh: call_args = WorkerCallArgs(**json.load(fh)) env = os.environ.copy() | self.env.copy() env.update( - self._create_env(call_args, self.workflow_dir.parent, self.export_values) + self._create_env( + call_args, + self.workflow_dir.parent, + export_values=self.export_values, + ), ) env["worker_call_args_file"] = str( - self.workflow_dir.parent / worker_call_args_path + self.workflow_dir.parent / worker_call_args_path, ) done_path = self.workflow_dir.parent / call_args.done_path _error_path = done_path.parent / "_error" if TKR_DIR_KEY not in env: env[TKR_DIR_KEY] = str(self.logs_path.parent.parent) - tee_str = f">(tee -a {str(self.errors_path)} {str(self.logs_path)} >/dev/null)" + tee_str = f">(tee -a {self.errors_path!s} {self.logs_path!s} >/dev/null)" proc = subprocess.Popen( - ["bash"], + ["/bin/bash"], start_new_session=True, stdin=subprocess.PIPE, env=env, ) proc.communicate( - f"({launcher_path} {worker_call_args_path} > {tee_str} 2> {tee_str} && touch {done_path}|| touch {_error_path})&".encode(), + f"({launcher_path} {worker_call_args_path} > {tee_str} 2> {tee_str} " + f"&& touch {done_path}|| touch {_error_path})&".encode(), timeout=self.timeout, ) def _create_env( - self, call_args: WorkerCallArgs, base_dir: Path, export_values: bool + self, + call_args: WorkerCallArgs, + base_dir: Path, 
+ *, + export_values: bool, ) -> dict[str, str]: + """Set up an environment as interface between controller and worker function. + + If export_values is set, will also write the values of ports to the env. + This is useful if you don't want / can't read the files directly. + """ env = { "checkpoints_directory": str(base_dir), "function_name": str(base_dir / call_args.function_name), @@ -91,6 +122,6 @@ def _create_env( return env values = {} for k, v in call_args.inputs.items(): - with open(v) as fh: + with Path.open(v) as fh: values[f"input_{k}_value"] = fh.read() return env diff --git a/tierkreis/tierkreis/controller/executor/stdinout.py b/tierkreis/tierkreis/controller/executor/stdinout.py index e95f4cf6c..d120d71df 100644 --- a/tierkreis/tierkreis/controller/executor/stdinout.py +++ b/tierkreis/tierkreis/controller/executor/stdinout.py @@ -1,3 +1,6 @@ +"""Special case implementation for external workers.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import json import shutil import subprocess @@ -10,7 +13,17 @@ class StdInOut: """Executes workers in an unix shell. + Assumes the worker takes a single input from stdin and will produce a single output + to stdout. + Will pipe other outputs to errors / logs. + Works by creating a subprocess Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + launchers_path (Path): The locations to search for external workers. + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + workflow_dir (Path): The workflow dir to resolve relative paths. 
""" def __init__(self, registry_path: Path, workflow_dir: Path) -> None: @@ -27,25 +40,28 @@ def run( launcher_path = _check_bin(launcher_name) if launcher_path is None: launcher_path = check_and_set_launcher( - self.launchers_path, launcher_name, ".sh" + self.launchers_path, + launcher_name, + ".sh", ) - with open(self.workflow_dir.parent / worker_call_args_path) as fh: + with Path.open(self.workflow_dir.parent / worker_call_args_path) as fh: call_args = WorkerCallArgs(**json.load(fh)) - input_file = self.workflow_dir.parent / list(call_args.inputs.values())[0] - output_file = self.workflow_dir.parent / list(call_args.outputs.values())[0] + input_file = self.workflow_dir.parent / next(iter(call_args.inputs.values())) + output_file = self.workflow_dir.parent / next(iter(call_args.outputs.values())) done_path = self.workflow_dir.parent / call_args.done_path - tee_str = f">(tee -a {str(self.errors_path)} {str(self.logs_path)} >/dev/null)" + tee_str = f">(tee -a {self.errors_path!s} {self.logs_path!s} >/dev/null)" _error_path = done_path.parent / "_error" proc = subprocess.Popen( - ["bash"], + ["/bin/bash"], start_new_session=True, stdin=subprocess.PIPE, ) proc.communicate( - f"({launcher_path} <{input_file} >{output_file} 2> {tee_str} && touch {done_path} || touch {_error_path})&".encode(), + f"({launcher_path} <{input_file} > {output_file} 2> {tee_str}" + f" && touch {done_path}|| touch {_error_path})&".encode(), timeout=10, ) diff --git a/tierkreis/tierkreis/controller/executor/task_executor.py b/tierkreis/tierkreis/controller/executor/task_executor.py index 61e9e4060..a46775036 100644 --- a/tierkreis/tierkreis/controller/executor/task_executor.py +++ b/tierkreis/tierkreis/controller/executor/task_executor.py @@ -1,6 +1,9 @@ +"""A meta executor consisting assigning executors to tasks.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import json +from fnmatch import filter # noqa: A004 from pathlib import Path -from fnmatch import filter 
from tierkreis.controller.data.location import WorkerCallArgs from tierkreis.controller.executor.protocol import ControllerExecutor @@ -9,19 +12,23 @@ class TaskExecutor: - """A Tierkreis executor that routes tasks to other executors based on the fully qualified task name. + """A Tierkreis executor that routes tasks to other executors. + Routing is based on the fully qualified task name. The fully qualified task name is of the form . . - Glob syntax can be used to route multiple tasks to the same executor.""" + Glob syntax can be used to route multiple tasks to the same executor. + """ def __init__( - self, assignments: dict[str, ControllerExecutor], storage: ControllerStorage + self, + assignments: dict[str, ControllerExecutor], + storage: ControllerStorage, ) -> None: self.assignments = assignments self.workflow_dir = storage.workflow_dir def run(self, launcher_name: str, worker_call_args_path: Path) -> None: - with open(self.workflow_dir.parent / worker_call_args_path) as fh: + with Path.open(self.workflow_dir.parent / worker_call_args_path) as fh: call_args = WorkerCallArgs(**json.load(fh)) qualified_task = f"{launcher_name}.{call_args.function_name}" @@ -31,4 +38,5 @@ def run(self, launcher_name: str, worker_call_args_path: Path) -> None: executor.run(launcher_name, worker_call_args_path) return - raise TierkreisError(f"No assigned executor for task {qualified_task}") + msg = f"No assigned executor for task {qualified_task}" + raise TierkreisError(msg) diff --git a/tierkreis/tierkreis/controller/executor/uv_executor.py b/tierkreis/tierkreis/controller/executor/uv_executor.py index 474eb0d9b..c4ae7285b 100644 --- a/tierkreis/tierkreis/controller/executor/uv_executor.py +++ b/tierkreis/tierkreis/controller/executor/uv_executor.py @@ -1,3 +1,6 @@ +"""Default python executor based on uv.""" + +# ruff: noqa: D102 (class methods inherited from ControllerExecutor) import logging import os import shutil @@ -15,7 +18,19 @@ class UvExecutor: """Executes workers in an UV 
python environment. + Depends on uv to run, hence the worker needs a pyproject.toml / a respective script. + Works out of the box with the cli worker definitions. + The env field can be used to provide additional variables; for example + controlling the python / uv version through $VIRTUAL_ENVIRONMENT. + Also to resolve paths, the $TKR_DIR will be set to the workflow directory. + Implements: :py:class:`tierkreis.controller.executor.protocol.ControllerExecutor` + + :fields: + launchers_path (Path): The locations to search for python workers. + logs_path (Path): The controller log file. + errors_path (Path): The controller error file for the function node. + env: (dict[str,str]): Additional environments to hand to the spawned subprocess. """ def __init__( @@ -38,14 +53,15 @@ def run( self.errors_path = ( self.logs_path.parent.parent / worker_call_args_path.parent - / "logs" # made we should change this + / "logs" # maybe we should change this ) logger.info("START %s %s", launcher_name, worker_call_args_path) if uv_path is None: uv_path = shutil.which("uv") if uv_path is None: - raise TierkreisError("uv is required to use the uv_executor") + msg = "uv is required to use the uv_executor" + raise TierkreisError(msg) registry_path = find_registry_for_worker(launcher_name, self.registries) worker_path = check_and_set_launcher(registry_path, launcher_name, ".py").parent @@ -55,15 +71,16 @@ def run( if TKR_DIR_KEY not in env: env[TKR_DIR_KEY] = str(self.logs_path.parent.parent) _error_path = self.errors_path.parent / "_error" - tee_str = f">(tee -a {str(self.errors_path)} {str(self.logs_path)} >/dev/null)" + tee_str = f">(tee -a {self.errors_path!s} {self.logs_path!s} >/dev/null)" proc = subprocess.Popen( - ["bash"], + ["/bin/bash"], start_new_session=True, stdin=subprocess.PIPE, cwd=worker_path, env=env, ) proc.communicate( - f"({uv_path} run main.py {worker_call_args_path} > {tee_str} 2> {tee_str} || touch {_error_path}) &".encode(), + f"({uv_path} run main.py 
{worker_call_args_path} > {tee_str} 2> {tee_str}" + f" || touch {_error_path}) &".encode(), timeout=10, ) diff --git a/tierkreis/tierkreis/controller/start.py b/tierkreis/tierkreis/controller/start.py index 383e0d5ed..9850f6022 100644 --- a/tierkreis/tierkreis/controller/start.py +++ b/tierkreis/tierkreis/controller/start.py @@ -1,10 +1,11 @@ +"""Main functionality to start nodes in a graph.""" + import logging import subprocess import sys from dataclasses import dataclass from pathlib import Path - -from typing_extensions import assert_never +from typing import assert_never from tierkreis.consts import PACKAGE_PATH from tierkreis.controller.data.core import PortID @@ -24,6 +25,14 @@ @dataclass class NodeRunData: + """Data required to run a node. + + :fields: + node_location (Loc): The location of the node to run. + node (NodeDef): The node definition to run. + output_list (list[PortID]): The list of output port ids for the node. + """ + node_location: Loc node: NodeDef output_list: list[PortID] @@ -34,6 +43,15 @@ def start_nodes( executor: ControllerExecutor, node_run_data: list[NodeRunData], ) -> None: + """Start multiple nodes at once. + + :param storage: The storage backend for the controller. + :type storage: ControllerStorage + :param executor: The executor backend for the controller. + :type executor: ControllerExecutor + :param node_run_data: The list of nodes to start (by their data). + :type node_run_data: list[NodeRunData] + """ started_locs: set[Loc] = set() for node_run_datum in node_run_data: if node_run_datum.node_location in started_locs: @@ -42,11 +60,20 @@ def start_nodes( started_locs.add(node_run_datum.node_location) -def run_builtin(def_path: Path, logs_path: Path) -> None: - logger.info("START builtin %s", def_path) - with open(logs_path, "a") as fh: +def run_builtin(call_args_path: Path, logs_path: Path) -> None: + """Run a builtin task. + + This is run directly by the controller. 
+
+    :param call_args_path: The path to the call arguments file.
+    :type call_args_path: Path
+    :param logs_path: The main controller log.
+    :type logs_path: Path
+    """
+    logger.info("START builtin %s", call_args_path)
+    with Path.open(logs_path, "a") as fh:
         subprocess.Popen(
-            [sys.executable, "main.py", def_path],
+            [sys.executable, "main.py", call_args_path],
             start_new_session=True,
             cwd=PACKAGE_PATH / "tierkreis" / "builtins",
             stderr=fh,
@@ -58,8 +85,25 @@ def start(
     storage: ControllerStorage,
     executor: ControllerExecutor,
     node_run_data: NodeRunData,
-    enable_logging: bool = True,
 ) -> None:
+    """Start the execution of a node.
+
+    Identifies the node type and starts it accordingly.
+    - For function nodes, it uses the executor to run the worker.
+    - Recursively starts higher order nodes (eval, loop, map)
+    - Routes the inputs and outputs for the nodes to the correct locations in storage.
+
+    To start a node it must have its inputs available.
+    Inputs can be provided by the parent node (in the case of higher order nodes).
+
+    :param storage: The storage backend for the controller.
+    :type storage: ControllerStorage
+    :param executor: The executor backend for the controller.
+    :type executor: ControllerExecutor
+    :param node_run_data: The data required to run a node.
+    :type node_run_data: NodeRunData
+    :raises TierkreisError: If the node is an orphan.
+    """
     node_location = node_run_data.node_location
     node = node_run_data.node
     output_list = node_run_data.output_list
@@ -68,22 +112,27 @@
     parent = node_location.parent()
 
     if parent is None:
-        raise TierkreisError(f"{node.type} node must have parent Loc.")
+        msg = f"{node.type} node must have parent Loc."
+ raise TierkreisError(msg) ins = {k: (parent.N(idx), p) for k, (idx, p) in node.inputs.items()} - logger.debug(f"start {node_location} {node} {ins} {output_list}") + logger.debug("start %s %s %s %s", node_location, node, ins, output_list) if node.type == "function": name = node.function_name launcher_name = ".".join(name.split(".")[:-1]) name = name.split(".")[-1] call_args_path = storage.write_worker_call_args( - node_location, name, ins, output_list + node_location, + name, + ins, + output_list, ) - logger.debug(f"Executing {(str(node_location), name, ins, output_list)}") + logger.debug("Executing %s", (str(node_location), name, ins, output_list)) if isinstance(storage, ControllerInMemoryStorage) and isinstance( - executor, InMemoryExecutor + executor, + InMemoryExecutor, ): executor.run(launcher_name, call_args_path) elif launcher_name == "builtins": @@ -99,7 +148,7 @@ def start( elif node.type == "output": storage.mark_node_finished(node_location) - pipe_inputs_to_output_location(storage, parent, ins) + _pipe_inputs_to_output_location(storage, parent, ins) storage.mark_node_finished(parent) elif node.type == "const": @@ -113,14 +162,12 @@ def start( ins["body"] = (parent.N(node.graph[0]), node.graph[1]) ins.update(g.fixed_inputs) - pipe_inputs_to_output_location(storage, node_location.N(-1), ins) + _pipe_inputs_to_output_location(storage, node_location.N(-1), ins) elif node.type == "loop": ins["body"] = (parent.N(node.body[0]), node.body[1]) - pipe_inputs_to_output_location(storage, node_location.N(-1), ins) - if ( - node.name is not None - ): # should we do this only in debug mode? 
-> need to think through how this would work + _pipe_inputs_to_output_location(storage, node_location.N(-1), ins) + if node.name is not None: storage.write_debug_data(node.name, node_location) start( storage, @@ -138,10 +185,10 @@ def start( elif node.type == "map": first_ref = next(x for x in ins.values() if x[1] == "*") - map_eles = outputs_iter(storage, first_ref[0]) - if not map_eles: + map_elements = outputs_iter(storage, first_ref[0]) + if not map_elements: storage.mark_node_finished(node_location) - for idx, p in map_eles: + for idx, p in map_elements: eval_inputs: dict[PortID, tuple[Loc, PortID]] = {} eval_inputs["body"] = (parent.N(node.body[0]), node.body[1]) for k, (i, port) in ins.items(): @@ -149,8 +196,10 @@ def start( eval_inputs[k] = (i, p) else: eval_inputs[k] = (i, port) - pipe_inputs_to_output_location( - storage, node_location.M(idx).N(-1), eval_inputs + _pipe_inputs_to_output_location( + storage, + node_location.M(idx).N(-1), + eval_inputs, ) # Necessary in the node visualization storage.write_node_def( @@ -158,16 +207,13 @@ def start( Eval((-1, "body"), node.inputs, outputs=node.outputs), ) - elif node.type == "ifelse": - pass - - elif node.type == "eifelse": + elif node.type in {"ifelse", "eifelse"}: pass else: assert_never(node) -def pipe_inputs_to_output_location( +def _pipe_inputs_to_output_location( storage: ControllerStorage, output_loc: Loc, inputs: dict[PortID, OutputLoc], diff --git a/tierkreis/tierkreis/controller/storage/__init__.py b/tierkreis/tierkreis/controller/storage/__init__.py index e69de29bb..085e534b4 100644 --- a/tierkreis/tierkreis/controller/storage/__init__.py +++ b/tierkreis/tierkreis/controller/storage/__init__.py @@ -0,0 +1 @@ +"""Storage definitions for the controller.""" diff --git a/tierkreis/tierkreis/controller/storage/adjacency.py b/tierkreis/tierkreis/controller/storage/adjacency.py index 564130b22..03677dc99 100644 --- a/tierkreis/tierkreis/controller/storage/adjacency.py +++ 
b/tierkreis/tierkreis/controller/storage/adjacency.py @@ -1,3 +1,5 @@ +"""Graph information based on adjacency.""" + import logging from tierkreis.controller.data.core import PortID, ValueRef @@ -9,14 +11,41 @@ def unfinished_inputs( - storage: ControllerStorage, loc: Loc, node: NodeDef + storage: ControllerStorage, + loc: Loc, + node: NodeDef, ) -> list[ValueRef]: + """Find the unfinished inputs of a node. + + :param storage: The storage to write from. + :type storage: ControllerStorage + :param loc: The node location to check for. + :type loc: Loc + :param node: The node definition containing the output names. + :type node: NodeDef + :return: A list of references to node inputs. + :rtype: list[ValueRef] + """ # ifelse is lazy: only wait for pred before starting ins = [node.pred] if node.type == "ifelse" else in_edges(node).values() - ins = [x for x in ins if x[0] >= 0] # inputs at -1 already finished + ins = [x for x in ins if x[0] >= 0] # inputs at -1 already finished they're linked return [x for x in ins if not storage.is_node_finished(loc.N(x[0]))] def outputs_iter(storage: ControllerStorage, loc: Loc) -> list[tuple[int, PortID]]: + """Find all the outputs of a node and provide them with their index as map elements. + + This is only used in map nodes to go from the * port to the values + of actual map elements. + This can be from an unfold where we get (index, index) + or map (index, "eval_output_name-index") + + :param storage: The storage to read from. + :type storage: ControllerStorage + :param loc: The location to get the outputs from. 
+ :type loc: Loc + :return: A tuple of (index, portname) of + :rtype: list[tuple[int, PortID]] + """ eles = storage.read_output_ports(loc) return [(int(x.split("-")[-1]), x) for x in eles] diff --git a/tierkreis/tierkreis/controller/storage/exceptions.py b/tierkreis/tierkreis/controller/storage/exceptions.py index cafb2be11..6080f0e3f 100644 --- a/tierkreis/tierkreis/controller/storage/exceptions.py +++ b/tierkreis/tierkreis/controller/storage/exceptions.py @@ -1,4 +1,7 @@ +"""Tierkreis Errors.""" + from pathlib import Path + from tierkreis.exceptions import TierkreisError @@ -6,7 +9,7 @@ class TierkreisStorageError(TierkreisError): """An error with the chosen Tierkreis storage layer.""" -class EntryNotFound(TierkreisStorageError): +class EntryNotFoundError(TierkreisStorageError): """Storage entry not found.""" path: Path @@ -15,4 +18,5 @@ def __init__(self, path: Path) -> None: self.path = path def __str__(self) -> str: + """Print string of self.""" return str(self.path) diff --git a/tierkreis/tierkreis/controller/storage/filestorage.py b/tierkreis/tierkreis/controller/storage/filestorage.py index 6188a61e6..db53191f7 100644 --- a/tierkreis/tierkreis/controller/storage/filestorage.py +++ b/tierkreis/tierkreis/controller/storage/filestorage.py @@ -1,22 +1,34 @@ +"""Default file storage implementation.""" + import os import shutil from pathlib import Path from time import time_ns +from typing import override from uuid import UUID -from tierkreis.controller.storage.exceptions import EntryNotFound +from tierkreis.controller.storage.exceptions import EntryNotFoundError from tierkreis.controller.storage.protocol import ( ControllerStorage, StorageEntryMetadata, ) +DEFAULT_DIRECTORY = Path.home() / ".tierkreis" / "checkpoints" + class ControllerFileStorage(ControllerStorage): + """Storage backend using the filesystem. + + This storage implementation operates by relegating calls to the os filesystem. + Calling with `do_cleanup` will ensure that previous runs are deleted. 
+ """ + def __init__( self, workflow_id: UUID, name: str | None = None, - tierkreis_directory: Path = Path.home() / ".tierkreis" / "checkpoints", + tierkreis_directory: Path = DEFAULT_DIRECTORY, + *, do_cleanup: bool = False, ) -> None: self.tkr_dir = tierkreis_directory @@ -25,6 +37,7 @@ def __init__( if do_cleanup: self.delete(self.workflow_dir) + @override def delete(self, path: Path) -> None: uid = os.getuid() tmp_dir = Path(f"/tmp/{uid}/tierkreis/archive/{self.workflow_id}/{time_ns()}") @@ -32,12 +45,15 @@ def delete(self, path: Path) -> None: if self.exists(path): shutil.move(path, tmp_dir) + @override def exists(self, path: Path) -> bool: return path.exists() + @override def list_subpaths(self, path: Path) -> list[Path]: - return [sub_path for sub_path in path.iterdir()] + return list(path.iterdir()) + @override def link(self, src: Path, dst: Path) -> None: dst.parent.mkdir(parents=True, exist_ok=True) if dst.exists() and dst.resolve() == src: @@ -46,19 +62,22 @@ def link(self, src: Path, dst: Path) -> None: try: os.link(src, dst) except (FileNotFoundError, FileExistsError) as exc: - raise EntryNotFound(src) from exc + raise EntryNotFoundError(src) from exc + @override def mkdir(self, path: Path) -> None: return path.mkdir(parents=True, exist_ok=True) + @override def read(self, path: Path) -> bytes: try: - with open(path, "rb") as fh: + with Path.open(path, "rb") as fh: return fh.read() except FileNotFoundError as exc: - raise EntryNotFound(path) from exc + raise EntryNotFoundError(path) from exc - def touch(self, path: Path, is_dir: bool = False) -> None: + @override + def touch(self, path: Path, *, is_dir: bool = False) -> None: if is_dir: path.mkdir(parents=True, exist_ok=True) return @@ -66,10 +85,12 @@ def touch(self, path: Path, is_dir: bool = False) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.touch() + @override def stat(self, path: Path) -> StorageEntryMetadata: return StorageEntryMetadata(path.stat().st_mtime) + @override def 
write(self, path: Path, value: bytes) -> None: path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "wb+") as fh: + with Path.open(path, "wb+") as fh: fh.write(value) diff --git a/tierkreis/tierkreis/controller/storage/graphdata.py b/tierkreis/tierkreis/controller/storage/graphdata.py index 994a3b0c5..8ed9ebb04 100644 --- a/tierkreis/tierkreis/controller/storage/graphdata.py +++ b/tierkreis/tierkreis/controller/storage/graphdata.py @@ -1,10 +1,12 @@ +"""Virtual GraphStorage for visualization.""" + import json from pathlib import Path +from typing import Any, override from uuid import UUID -from typing import Any - from pydantic import BaseModel, Field + from tierkreis.controller.data.core import PortID from tierkreis.controller.data.graph import ( Eval, @@ -14,8 +16,8 @@ ) from tierkreis.controller.data.location import Loc, OutputLoc, WorkerCallArgs from tierkreis.controller.storage.protocol import ( - StorageEntryMetadata, ControllerStorage, + StorageEntryMetadata, ) from tierkreis.exceptions import TierkreisError @@ -35,6 +37,13 @@ class NodeData(BaseModel): class GraphDataStorage(ControllerStorage): + """Storage backend using in-memory GraphData for workflow execution. + + This storage implementation operates read-only on a GraphData object without + writing to disk. + Used for visualization without running the workflow. + """ + def __init__( self, workflow_id: UUID, @@ -47,36 +56,57 @@ def __init__( self.graph = graph self.tkr_dir = Path.home() / ".tierkreis" + @override def delete(self, path: Path) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def exists(self, path: Path) -> bool: - raise NotImplementedError("GraphDataStorage is only for graph construction.") + msg = "GraphDataStorage is only for graph construction." 
+ raise NotImplementedError(msg) + @override def list_subpaths(self, path: Path) -> list[Path]: - raise NotImplementedError("GraphDataStorage uses GraphData not paths.") + msg = "GraphDataStorage uses GraphData not paths." + raise NotImplementedError(msg) + @override def link(self, src: Path, dst: Path) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def mkdir(self, path: Path) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def read(self, path: Path) -> bytes: - raise NotImplementedError("GraphDataStorage uses GraphData not paths.") + msg = "GraphDataStorage uses GraphData not paths." + raise NotImplementedError(msg) + @override def stat(self, path: Path) -> StorageEntryMetadata: - raise NotImplementedError("GraphDataStorage is only for graph construction.") + msg = "GraphDataStorage is only for graph construction." + raise NotImplementedError(msg) - def touch(self, path: Path, is_dir: bool = False) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + @override + def touch(self, path: Path) -> None: + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def write(self, path: Path, value: bytes) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def write_node_def(self, node_location: Loc, node: NodeDef) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." 
+ raise NotImplementedError(msg) + @override def read_node_def(self, node_location: Loc) -> NodeDef: try: if node_location.pop_last()[0][0] in ["M", "L"]: @@ -86,6 +116,7 @@ def read_node_def(self, node_location: Loc) -> NodeDef: node, _ = graph_node_from_loc(node_location, self.graph) return node + @override def write_worker_call_args( self, node_location: Loc, @@ -93,28 +124,39 @@ def write_worker_call_args( inputs: dict[PortID, OutputLoc], output_list: list[PortID], ) -> Path: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def read_worker_call_args(self, node_location: Loc) -> WorkerCallArgs: + msg = f"Node location {node_location} doesn't have a associate call args." raise TierkreisError( - f"Node location {node_location} doesn't have a associate call args." + msg, ) - def read_errors(self, node_location: Loc = Loc()) -> str: + @override + def read_errors(self, node_location: Loc | None = None) -> str: return "" + @override def node_has_error(self, node_location: Loc) -> bool: return False + @override def write_node_errors(self, node_location: Loc, error_logs: str) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def mark_node_finished(self, node_location: Loc) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def is_node_finished(self, node_location: Loc) -> bool: return False + @override def link_outputs( self, new_location: Loc, @@ -122,16 +164,23 @@ def link_outputs( old_location: Loc, old_port: PortID, ) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." 
+ raise NotImplementedError(msg) + @override def write_output( - self, node_location: Loc, output_name: PortID, value: bytes + self, + node_location: Loc, + output_name: PortID, + value: bytes, ) -> Path: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." + raise NotImplementedError(msg) + @override def read_output(self, node_location: Loc, output_name: PortID) -> bytes: node, graph = graph_node_from_loc(node_location, self.graph) - if -1 == node_location.peek_index() and output_name == "body": + if node_location.peek_index() == -1 and output_name == "body": return graph.model_dump_json().encode() outputs = _build_node_outputs(node) @@ -139,25 +188,33 @@ def read_output(self, node_location: Loc, output_name: PortID) -> bytes: if output := outputs[output_name]: return output return b"null" - raise TierkreisError(f"No output named {output_name} in node {node_location}") + msg = f"No output named {output_name} in node {node_location}" + raise TierkreisError(msg) + @override def read_output_ports(self, node_location: Loc) -> list[PortID]: node, _ = graph_node_from_loc(node_location, self.graph) outputs = _build_node_outputs(node) return list(filter(lambda k: k != "*", outputs.keys())) + @override def is_node_started(self, node_location: Loc) -> bool: return False + @override def read_metadata(self, node_location: Loc) -> dict[str, Any]: return self.nodes[node_location].metadata + @override def write_metadata(self, node_location: Loc) -> None: - raise NotImplementedError("GraphDataStorage is read only storage.") + msg = "GraphDataStorage is read only storage." 
+ raise NotImplementedError(msg) + @override def read_started_time(self, node_location: Loc) -> str | None: return None + @override def read_finished_time(self, node_location: Loc) -> str | None: return None @@ -167,13 +224,11 @@ def _build_node_outputs(node: NodeDef) -> dict[PortID, None | bytes]: if isinstance(node.value, dict): if "nodes" not in node.value: return {"value": json.dumps(node.value).encode()} - else: - return {"value": b"Graph"} - elif isinstance(node.value, GraphData): return {"value": b"Graph"} - else: - return {"value": json.dumps(node.value).encode()} - outputs: dict[PortID, None | bytes] = {val: None for val in node.outputs} + if isinstance(node.value, GraphData): + return {"value": b"Graph"} + return {"value": json.dumps(node.value).encode()} + outputs: dict[PortID, None | bytes] = dict.fromkeys(node.outputs) if "*" in outputs: outputs["0"] = None return outputs diff --git a/tierkreis/tierkreis/controller/storage/in_memory.py b/tierkreis/tierkreis/controller/storage/in_memory.py index 9bf441ee7..3871fb1a5 100644 --- a/tierkreis/tierkreis/controller/storage/in_memory.py +++ b/tierkreis/tierkreis/controller/storage/in_memory.py @@ -1,15 +1,25 @@ +"""In memory implementation of a storage layer.""" + from pathlib import Path -from uuid import UUID from time import time +from typing import override +from uuid import UUID -from tierkreis.controller.storage.exceptions import EntryNotFound +from tierkreis.controller.storage.exceptions import EntryNotFoundError from tierkreis.controller.storage.protocol import ( - StorageEntryMetadata, ControllerStorage, + StorageEntryMetadata, ) class InMemoryFileData: + """Class to emulate the file system behaviour in memory. + + :fields: + value (bytes): The content of a file, typically used for outputs or empty. + stats (StorageEntryMetadata): A metadata entry. 
+ """ + value: bytes stats: StorageEntryMetadata @@ -19,6 +29,13 @@ def __init__(self, value: bytes) -> None: class ControllerInMemoryStorage(ControllerStorage): + """In-memory implementation of ControllerStorage. + + Stores workflow files in memory using a dictionary instead of the filesystem. + Uses a mapping Path -> FileData to emulate the required filesystem structure. + Useful when debugging applications in conjunction with the InMemoryExecutor. + """ + def __init__( self, workflow_id: UUID, @@ -31,44 +48,51 @@ def __init__( self.files: dict[Path, InMemoryFileData] = {} + @override def delete(self, path: Path) -> None: self.files = {} + @override def exists(self, path: Path) -> bool: return path in list(self.files.keys()) + @override def list_subpaths(self, path: Path) -> list[Path]: if path == self.workflow_dir: - nodes = set( - [ - Path("/".join(str(x).split("/")[:2])) - for x in self.files.keys() - if str(x).startswith(str(path) + "/") - ] - ) + nodes = { + Path("/".join(str(x).split("/")[:2])) + for x in self.files + if str(x).startswith(str(path) + "/") + } return list(nodes) - return [x for x in self.files.keys() if str(x).startswith(str(path) + "/")] + return [x for x in self.files if str(x).startswith(str(path) + "/")] + @override def link(self, src: Path, dst: Path) -> None: try: self.files[dst] = self.files[src] except KeyError as exc: - raise EntryNotFound(src) from exc + raise EntryNotFoundError(src) from exc + @override def mkdir(self, path: Path) -> None: return + @override def read(self, path: Path) -> bytes: try: return self.files[path].value except KeyError as exc: - raise EntryNotFound(path) from exc + raise EntryNotFoundError(path) from exc - def touch(self, path: Path, is_dir: bool = False) -> None: + @override + def touch(self, path: Path) -> None: self.files[path] = InMemoryFileData(b"") + @override def stat(self, path: Path) -> StorageEntryMetadata: return self.files[path].stats + @override def write(self, path: Path, value: bytes) -> 
None: self.files[path] = InMemoryFileData(value) diff --git a/tierkreis/tierkreis/controller/storage/protocol.py b/tierkreis/tierkreis/controller/storage/protocol.py index 45cc6d438..efea1d0e4 100644 --- a/tierkreis/tierkreis/controller/storage/protocol.py +++ b/tierkreis/tierkreis/controller/storage/protocol.py @@ -1,8 +1,10 @@ +"""The storage interface.""" + import json import logging from abc import ABC, abstractmethod from dataclasses import asdict, dataclass -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path from typing import Any, assert_never from uuid import UUID @@ -10,7 +12,7 @@ from tierkreis.controller.data.core import PortID from tierkreis.controller.data.graph import NodeDef, NodeDefModel from tierkreis.controller.data.location import Loc, OutputLoc, WorkerCallArgs -from tierkreis.controller.storage.exceptions import EntryNotFound +from tierkreis.controller.storage.exceptions import EntryNotFoundError from tierkreis.exceptions import TierkreisError logger = logging.getLogger(__name__) @@ -20,79 +22,162 @@ class StorageEntryMetadata: """Collection of commonly found metadata. - Storage implementations should decide which are applicable.""" + Storage implementations should decide which are applicable. + :fields: + st_mtime (float | None): The start time of a node, defaults to None. + """ st_mtime: float | None = None @dataclass class StorageDebugData: - """Collection of commonly found debugdata. + """Collection of commonly found debug data. Currently only used for loop_nodes - Storage implementations should decide which are applicable.""" + Storage implementations should decide which are applicable. + :fields: + loop_loc (str | None): The Loc of loop node, defaults to None. + Can only be known after the node has run. + """ loop_loc: str | None = None class ControllerStorage(ABC): + """Storage interface for the tierkreis controller. 
+ + :abstract: + Conceptually, the storage layer represents the current state of the operation. + It includes information such as: + - Checkpoints + - Node definitions + - Outputs and their values + - Metadata + - Debug Data + + This interface primarily targets filesystems as underlying storage method. + For storages based on other methods, you manually need to map paths + representing node locations to an internal address. + This interface already handles translations from Locs to Paths + + :fields: + tkr_dir (Path): The base of the storage. + workflow_id (UUID): The unique workflow id. + name (str | None): Optional name for a workflow, defaults to None. + + :properties: + workflow_dir (Path): The workflow storage location. + logs_path (Path): Location of the workflow logs. + debug_path (Path): Location of the workflow debug information. + """ + tkr_dir: Path workflow_id: UUID name: str | None @abstractmethod def delete(self, path: Path) -> None: - """Delete the storage entry at the specified path. + r"""Delete the storage entry at the specified path. - Also delete any related data of the form \"{path}/**/*\".""" + Also delete any related data of the form \"{path}/**/*\". + Only necessary for persistent storage types. + + :param path: The storage location to delete. + :type path: Path + """ @abstractmethod def exists(self, path: Path) -> bool: - """Is there an entry in the storage at the specified path?""" + """Check whether there is an entry in the storage at the specified path. + + :param path: The storage location to check. + :type path: Path + :return: True if it exists. + :rtype: bool + """ @abstractmethod def link(self, src: Path, dst: Path) -> None: - """The storage entry at `dst` should have the same value as the entry at `src`.""" + """Link storage entry at `dst` to have the same value as the entry at `src`. + + :param src: The source entry. + :type src: Path + :param dst: The destination entry. 
+ :type dst: Path + """ @abstractmethod def list_subpaths(self, path: Path) -> list[Path]: - """List all the paths starting with the specified path. + """List all the paths starting with the specified path in the storage. This is used when the number of entries can only be determined at runtime. - For example in a map node.""" + For example in a map node. + + :param path: The location to list children in. + :type path: Path + :return: A list of child entries. + :rtype: list[Path] + """ @abstractmethod def mkdir(self, path: Path) -> None: - """Create an empty directory (and parents) at this path. + """Create an empty directory (and parents) at this storage location. - Probably only required for file-based storage.""" + Probably only required for file-based storage. + + :param path: The location where to create the directory. + :type path: Path + """ @abstractmethod def read(self, path: Path) -> bytes: - """Read the storage entry at the specified path.""" + """Read the storage entry at the specified path. + + :param path: The location to read from. + :type path: Path + :return: The value at the storage location. + :rtype: bytes + """ @abstractmethod def stat(self, path: Path) -> StorageEntryMetadata: - """Get applicable stats for storage entry.""" + """Get applicable stats for storage entry. + + :param path: The location to get stats for. + :type path: Path + :return: The stats of this entry. + :rtype: StorageEntryMetadata + """ @abstractmethod def touch(self, path: Path) -> None: - """Create empty storage entry at the specified path.""" + """Create empty storage entry at the specified location. + + :param path: The location where to generate the entry. + :type path: Path + """ @abstractmethod def write(self, path: Path, value: bytes) -> None: - """Write the given bytes to the storage entry at the specified path.""" + """Write the given bytes to the storage entry at the specified path. + + :param path: The location to write to. 
+ :type path: Path + :param value: The value to write. + :type value: bytes + """ @property - def workflow_dir(self) -> Path: + def workflow_dir(self) -> Path: # noqa: D102 documented in class return self.tkr_dir / str(self.workflow_id) @property - def logs_path(self) -> Path: + def logs_path(self) -> Path: # noqa: D102 documented in class return self.workflow_dir / "logs" @property - def debug_path(self) -> Path: + def debug_path(self) -> Path: # noqa: D102 documented in class return self.workflow_dir / "debug" def _nodedef_path(self, node_location: Loc) -> Path: @@ -120,13 +205,28 @@ def _worker_logs_path(self, node_location: Loc) -> Path: return self.workflow_dir / str(node_location) / "logs" def clean_graph_files(self) -> None: + """Delete the workflow directory of a graph.""" self.delete(self.workflow_dir) - def write_node_def(self, node_location: Loc, node: NodeDef): + def write_node_def(self, node_location: Loc, node: NodeDef) -> None: + """Write a node definition to storage. + + :param node_location: The location to write to. + :type node_location: Loc + :param node: The node definition to write. + :type node: NodeDef + """ bs = NodeDefModel(root=node).model_dump_json().encode() self.write(self._nodedef_path(node_location), bs) def read_node_def(self, node_location: Loc) -> NodeDef: + """Read the definition of a node in storage. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The retrieved node definition. + :rtype: NodeDef + """ bs = self.read(self._nodedef_path(node_location)) return NodeDefModel(**json.loads(bs)).root @@ -137,6 +237,19 @@ def write_worker_call_args( inputs: dict[PortID, OutputLoc], output_list: list[PortID], ) -> Path: + """Write the call arguments for a worker task to storage. + + :param node_location: The location to write to. + :type node_location: Loc + :param function_name: The task name. + :type function_name: str + :param inputs: The inputs to the task (outputs from previous nodes). 
+ :type inputs: dict[PortID, OutputLoc] + :param output_list: The list of output ports of the task. + :type output_list: list[PortID] + :return: The respective path in storage. + :rtype: Path + """ call_args_path = self._worker_call_args_path(node_location) node_definition = WorkerCallArgs( function_name=function_name, @@ -162,6 +275,13 @@ def write_worker_call_args( return call_args_path.relative_to(self.tkr_dir) def read_worker_call_args(self, node_location: Loc) -> WorkerCallArgs: + """Read the worker call arguments from storage. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The call arguments. + :rtype: WorkerCallArgs + """ node_definition_path = self._worker_call_args_path(node_location) return WorkerCallArgs(**json.loads(self.read(node_definition_path))) @@ -172,30 +292,81 @@ def link_outputs( old_location: Loc, old_port: PortID, ) -> None: + """Link an output from one node to another. + + Linking ensures that the values at both locations are the same. + + :param new_location: The new location to write to. + :type new_location: Loc + :param new_port: The port to link to. + :type new_port: PortID + :param old_location: The old location to read from. + :type old_location: Loc + :param old_port: The old location to link from. + :type old_port: PortID + :raises TierkreisError: If linking is not possible. + """ new_dir = self._output_path(new_location, new_port) try: self.link(self._output_path(old_location, old_port), new_dir) - except EntryNotFound: - logger.info( - f"Could not find {old_location}. " - "Tasks using this location will try to use a default value if specified." + except EntryNotFoundError: + logger.warning( + "Could not find %s. Tasks using this" + " location will try to use a default value if specified.", + old_location, ) except OSError as e: + msg = "Workflow already exists." + "Try running with a different ID or do_cleanup." raise TierkreisError( - "Workflow already exists. 
Try running with a different ID or do_cleanup." + msg, ) from e def write_output( - self, node_location: Loc, output_name: PortID, value: bytes + self, + node_location: Loc, + output_name: PortID, + value: bytes, ) -> Path: + """Write the value of an output to storage. + + :param node_location: The location to write to. + :type node_location: Loc + :param output_name: The port for which to write the value. + :type output_name: PortID + :param value: The value to write. + :type value: bytes + :return: The respective path in storage. + :rtype: Path + """ output_path = self._output_path(node_location, output_name) self.write(output_path, bytes(value)) return output_path def read_output(self, node_location: Loc, output_name: PortID) -> bytes: + """Read the named output at the node location. + + :param node_location: The location to read from. + :type node_location: Loc + :param output_name: The port identifying the output. + :type output_name: PortID + :return: The value at the port. + :rtype: bytes + """ return self.read(self._output_path(node_location, output_name)) - def read_errors(self, node_location: Loc = Loc()) -> str: + def read_errors(self, node_location: Loc | None = None) -> str: + """Read the errors that occurred at the node location. + + Only valid for function nodes (tasks) and the top level node ("-"). + + :param node_location: The location to read from, defaults to None + :type node_location: Loc | None, optional + :return: The error message that occurred. "" if nothing was logged. + :rtype: str + """ + if node_location is None: + node_location = Loc() if self.exists(self._worker_logs_path(node_location)): return self.read(self._worker_logs_path(node_location)).decode() if self.exists(self._error_path(node_location)): @@ -203,63 +374,166 @@ def read_errors(self, node_location: Loc = Loc()) -> str: return "" def write_node_errors(self, node_location: Loc, error_logs: str) -> None: + """Write the errors of a node to storage. 
+ + Only valid for function nodes (tasks) and the top level node ("-"). + + :param node_location: The location to write to. + :type node_location: Loc + :param error_logs: The error message to write. + :type error_logs: str + """ + if node_location == Loc(): + self.write(self._error_path(node_location), error_logs.encode()) + return self.write(self._worker_logs_path(node_location), error_logs.encode()) def read_output_ports(self, node_location: Loc) -> list[PortID]: + """Read the list of named outputs of a node in storage. + + :param node_location: The location to read from. + :type node_location: Loc + :return: A list of output names. + :rtype: list[PortID] + """ dir_list = self.list_subpaths(self._outputs_dir(node_location)) dir_list.sort() return [x.name for x in dir_list] def is_node_started(self, node_location: Loc) -> bool: + """Check whether a node is started. + + A node is started <==> the controller has written its definition. + + :param node_location: The location to check. + :type node_location: Loc + :return: True if the node is started. + :rtype: bool + """ return self.exists(Path(self._nodedef_path(node_location))) def is_node_finished(self, node_location: Loc) -> bool: + """Check whether a node is finished. + + A node is finished <==> A _done file/marker is set. + + :param node_location: The location to check. + :type node_location: Loc + :return: True if the node is finished. + :rtype: bool + """ return self.exists(self._done_path(node_location)) def latest_loop_iteration(self, loc: Loc) -> Loc: + """Find the latest iteration location of a loop node. + + :param loc: The location to check. + :type loc: Loc + :return: A location representing the latest iteration. + :rtype: Loc + """ i = 0 while self.is_node_started(loc.L(i + 1)): i += 1 return loc.L(i) def node_has_error(self, node_location: Loc) -> bool: + """Check whether a node has encountered an error. + + Only valid for function nodes (tasks) and the top level node ("-"). 
+ A node is errored <==> A _error file/marker is set. + + :param node_location: The location to check. + :type node_location: Loc + :return: True if the node has an error. + :rtype: bool + """ return self.exists(self._error_path(node_location)) def mark_node_finished(self, node_location: Loc) -> None: + """Mark a node as successfully finished. + + :param node_location: The location to mark. + :type node_location: Loc + """ self.touch(self._done_path(node_location)) if (parent := node_location.parent()) is not None: self.touch(self._metadata_path(parent)) def write_metadata(self, node_location: Loc) -> None: - j = json.dumps({"name": self.name, "start_time": datetime.now().isoformat()}) - self.write(self._metadata_path(node_location), j.encode()) + """Write the metadata for a node. + + Currently metadata contains name and start time. + + :param node_location: The location to write to. + :type node_location: Loc + """ + json_string = json.dumps( + {"name": self.name, "start_time": datetime.now(UTC).isoformat()}, + ) + self.write(self._metadata_path(node_location), json_string.encode()) def read_metadata(self, node_location: Loc) -> dict[str, Any]: + """Read the metadata of a node. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The metadata stored at the location. + :rtype: dict[str, Any] + """ return json.loads(self.read(self._metadata_path(node_location))) def read_started_time(self, node_location: Loc) -> str | None: + """Read the start time of a node. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The time as is isoformat if the node has started. 
+ :rtype: str | None + """ node_def = Path(self._nodedef_path(node_location)) if not self.exists(node_def): return None since_epoch = self.stat(node_def).st_mtime if since_epoch is None: return None - return datetime.fromtimestamp(since_epoch).isoformat() + return datetime.fromtimestamp(since_epoch, UTC).isoformat() def read_finished_time(self, node_location: Loc) -> str | None: + """Read the finish time of a node. + + :param node_location: The location to read from. + :type node_location: Loc + :return: The time as is isoformat if the node has finished. + :rtype: str | None + """ done = Path(self._done_path(node_location)) if not self.exists(done): return None since_epoch = self.stat(done).st_mtime if since_epoch is None: return None - return datetime.fromtimestamp(since_epoch).isoformat() + return datetime.fromtimestamp(since_epoch, UTC).isoformat() def read_loop_trace(self, node_location: Loc, output_name: PortID) -> list[bytes]: + """Read the trace of a loop node for a given output. + + A trace is a list of values v, where v[i] represents the value of the output at + iteration i. + + :param node_location: The location to read from. + :type node_location: Loc + :param output_name: The output to trace. + :type output_name: PortID + :raises TierkreisError: If the node at the location is not a loop node. + :return: A list of values considered the trace. + :rtype: list[bytes] + """ definition = self.read_node_def(node_location) if definition.type != "loop": - raise TierkreisError("Can only read traces from loop nodes.") + msg = "Can only read traces from loop nodes." + raise TierkreisError(msg) result = [] i = 0 @@ -269,58 +543,105 @@ def read_loop_trace(self, node_location: Loc, output_name: PortID) -> list[bytes return result def loc_from_node_name(self, node_name: str) -> Loc | None: + """Find the location of a node for a given name. + + Currently only loop nodes can be named. + Loop names are stored as debug data. 
+ This can only be invoked after running the workflow. + + :param node_name: The name to search for. + :type node_name: str + :return: Returns the location if found in storage. + :rtype: Loc | None + """ debug_data = StorageDebugData(**self.read_debug_data(node_name)) if debug_data.loop_loc is not None: return Loc(debug_data.loop_loc) + return None + + def write_debug_data(self, name: str, node_location: Loc) -> None: + """Write the debug data of a node. + + Currently name is derived from a named loop. - def write_debug_data(self, name: str, loc: Loc) -> None: + :param name: The name to write to. + :type name: str + :param node_location: The location to write as debug information. + :type node_location: Loc + """ self.mkdir(self.debug_path) - data = StorageDebugData(loop_loc=loc) + data = StorageDebugData(loop_loc=node_location) self.write(self.debug_path / name, json.dumps(asdict(data)).encode()) def read_debug_data(self, name: str) -> dict[str, Any]: + """Read the debug data for a given name. + + Currently name is derived from a named loop. + + :param name: The name to read from. + :type name: str + :return: The data available for the name. + :rtype: dict[str, Any] + """ return json.loads(self.read(self.debug_path / name)) - def dependents(self, loc: Loc) -> set[Loc]: - """Nodes that are fully invalidated if the node at the given loc is invalidated. + def dependents(self, node_location: Loc) -> set[Loc]: + """Get the dependents (successors) of a node. - This does not include the direct parent Loc, which is only partially invalidated. + Dependents are nodes that are fully invalidated if + the node at the given loc is invalidated. + This does not include the direct parent Loc, + which is only partially invalidated. + + :param loc: The location to get the dependents for. + :type loc: Loc + :return: A set of dependent nodes. 
+ :rtype: set[Loc] """ - descs: set[Loc] = set() - step, parent = loc.pop_last() + descendants: set[Loc] = set() + step, parent = node_location.pop_last() match step: case "-": pass case ("N", _): - nodedef = self.read_node_def(loc) + nodedef = self.read_node_def(node_location) if nodedef.type == "output": - descs.update(self.dependents(parent)) + descendants.update(self.dependents(parent)) for output_set in nodedef.outputs.values(): for output in output_set: - descs.add(parent.N(output)) - descs.update(self.dependents(parent.N(output))) + descendants.add(parent.N(output)) + descendants.update(self.dependents(parent.N(output))) case ("M", _): - descs.update(self.dependents(parent)) + descendants.update(self.dependents(parent)) case ("L", idx): latest_idx = self.latest_loop_iteration(parent).peek_index() - [descs.add(parent.L(i)) for i in range(idx + 1, latest_idx + 1)] - descs.update(self.dependents(parent)) + [descendants.add(parent.L(i)) for i in range(idx + 1, latest_idx + 1)] + descendants.update(self.dependents(parent)) case _: assert_never(step) - return descs + return descendants def restart_task(self, loc: Loc) -> list[Loc]: """Restart the task/function node at the given loc. - Fully dependent nodes will be removed from the storage. + A node is restarted by removing its (and its dependents) started flag. + The controller will then pick it up as not started. + Fully dependent nodes will be completely removed from the storage. The parent locs will be partially invalidated. - Returns the invalidated nodes.""" + Returns the invalidated nodes. + :param loc: _description_ + :type loc: Loc + :raises TierkreisError: If the node at the location is not a task. + :return: _description_ + :rtype: list[Loc] + """ nodedef = self.read_node_def(loc) if nodedef.type != "function": - raise TierkreisError("Can only restart task/function nodes.") + msg = "Can only restart task/function nodes." + raise TierkreisError(msg) # Remove fully invalidated nodes. 
deps = self.dependents(loc) @@ -331,7 +652,8 @@ def restart_task(self, loc: Loc) -> list[Loc]: [self.delete(self._done_path(x)) for x in partials] [self.delete(self.workflow_dir / a / "outputs") for a in partials] - # Mark given Loc as not started, so that the controller picks it up on the next tick. + # Mark given Loc as not started + # so that the controller picks it up on the next tick. self.delete(self._nodedef_path(loc)) return list(deps) diff --git a/tierkreis/tierkreis/controller/storage/walk.py b/tierkreis/tierkreis/controller/storage/walk.py index 0adf8809f..dfe9f4ae1 100644 --- a/tierkreis/tierkreis/controller/storage/walk.py +++ b/tierkreis/tierkreis/controller/storage/walk.py @@ -1,3 +1,11 @@ +"""Functions to walk a (computational) workflow graph. + +In conjunction with `start()` this is one of the primary functions of the +tierkreis controller. +By walking the graph we update nodes with new inputs from finished nodes until +they can be started. +""" + from dataclasses import dataclass, field from logging import getLogger from typing import assert_never @@ -24,11 +32,27 @@ @dataclass class WalkResult: + """Dataclass to keep track of the nodes we encounter during the walk. + + :fields: + inputs_ready (list[NodeRunData]): A list of nodes that now have all inputs ready + and therefore can be started. + started (list[Loc]): A list of locations that have been started (on this walk). + errored (list[Loc]): A list of locations that have encountered an error. + """ + inputs_ready: list[NodeRunData] started: list[Loc] errored: list[Loc] = field(default_factory=list[Loc]) def extend(self, walk_result: "WalkResult") -> None: + """Extend a current walk result with an existing one. + + Simply extends all three list fields accordingly. + + :param walk_result: The walk_result to update self with. 
+ :type walk_result: WalkResult + """ self.inputs_ready.extend(walk_result.inputs_ready) self.started.extend(walk_result.started) self.errored.extend(walk_result.errored) @@ -41,19 +65,66 @@ def unfinished_results( node: NodeDef, graph: GraphData, ) -> int: + """Find and walk all the unfinished results. + + Finds all "blocking" nodes of the current nodes and marks them to be started. + Blocking nodes are inputs that are not done yet. + We walk recursively into the nodes that are not finished yet to progress them, + by marking them ready for starting or done. + + :param result: The walk result, where we add all unfinished nodes. + Used to bubble up the nodes from the recursive walk call. + :type result: WalkResult + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + e.g, the eval node containing the current node. + :type parent: Loc + :param node: The current node for which we check inputs. + :type node: NodeDef + :param graph: The graph to walk. + :type graph: GraphData + :return: The number of nodes that have unfinished inputs. + :rtype: int + """ unfinished = unfinished_inputs(storage, parent, node) [result.extend(walk_node(storage, parent, x[0], graph)) for x in unfinished] return len(unfinished) def walk_node( - storage: ControllerStorage, parent: Loc, idx: NodeIndex, graph: GraphData + storage: ControllerStorage, + parent: Loc, + idx: NodeIndex, + graph: GraphData, ) -> WalkResult: - """Should only be called when a node has not finished.""" + """Walk a graph node. + + Should only be called when a node has not finished. + This is one of the core functions of the controller. + It checks for a current node how to proceed: + - Continue if its already done + - Mark it for starting if its inputs are now ready and its not started. + - Mark the respective next nodes to start, e.g. in case of an eval start + the first nodes inside (that now have their inputs ready). 
+ + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + e.g, the eval node containing the current node. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param graph: The graph to walk. + :type graph: GraphData + :return: A summary list of finished, errored, and ready nodes. + :rtype: WalkResult + """ loc = parent.N(idx) if storage.node_has_error(loc): - logger.error(f"Node {loc} has encountered an error.") - logger.debug(f"\n\n{storage.read_errors(loc)}\n\n") + # we immediately stop if a node has an error and bubble the error up + logger.error("Node %s has encountered an error.", loc) + logger.debug("\n\n%s\n\n", storage.read_errors(loc)) return WalkResult([], [], [loc]) node = graph.nodes[idx] @@ -61,13 +132,20 @@ def walk_node( result = WalkResult([], []) if unfinished_results(result, storage, parent, node, graph): + # cannot start, don't have all inputs yet return result if not storage.is_node_started(loc): + # have all inputs, start current node return WalkResult([node_run_data], []) + # Handle cases where we have nested graphs. + # Basically we have to forward the now available outputs from outer scope + # into the scope of the nested graph, so we check inside the nested + # for new candidates. 
match node.type: case "eval": + # step inside the nested graph for walking message = storage.read_output(parent.N(node.graph[0]), node.graph[1]) g = ptype_from_bytes(message, GraphData) return walk_node(storage, loc, g.output_idx(), g) @@ -85,6 +163,7 @@ def walk_node( return walk_map(storage, parent, idx, node) case "ifelse": + # walk the next node only after we have the value of "pred" pred = storage.read_output(parent.N(node.pred[0]), node.pred[1]) next_node = node.if_true if pred == b"true" else node.if_false next_loc = parent.N(next_node[0]) @@ -92,13 +171,13 @@ def walk_node( storage.link_outputs(loc, Labels.VALUE, next_loc, next_node[1]) storage.mark_node_finished(loc) return WalkResult([], []) - else: - return walk_node(storage, parent, next_node[0], graph) + return walk_node(storage, parent, next_node[0], graph) case "eifelse": return walk_eifelse(storage, parent, idx, node) case "function": + # Current node can start, done will be marked by executor. return WalkResult([], [loc]) case "input": @@ -108,33 +187,66 @@ def walk_node( def walk_loop( - storage: ControllerStorage, parent: Loc, idx: NodeIndex, loop: Loop + storage: ControllerStorage, + parent: Loc, + idx: NodeIndex, + loop: Loop, ) -> WalkResult: + """Walk a loop node. + + The controller walks a loop by: + - checking the current iteration + - checking the `should_continue` port + - mapping outputs to inputs between iterations + - and finally producing the outputs. + + Each iteration is evaluated by inserting a virtual eval node + at the location PARENT_LOC.L.N-1 that then gets picked up by walk + and the next start step. + + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + E.g. the "eval" containing this statement. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param loop: The loop node being walked. 
+ :type loop: Loop + :return: A summary list of finished, errored, and ready nodes: + Either empty (loop is done), or next loop iteration. + :rtype: WalkResult + """ loc = parent.N(idx) if storage.is_node_finished(loc): - return WalkResult([], [], []) + return WalkResult([], []) + # find the last iteration new_location = storage.latest_loop_iteration(loc) - + # and read the graph definition message = storage.read_output(loc.N(-1), BODY_PORT) g = ptype_from_bytes(message, GraphData) loop_outputs = g.nodes[g.output_idx()].inputs - + # if the iteration is not done finished, walk its nested graph (body) if not storage.is_node_finished(new_location): return walk_node(storage, new_location, g.output_idx(), g) # Latest iteration is finished. Do we BREAK or CONTINUE? should_continue = ptype_from_bytes( - storage.read_output(new_location, loop.continue_port), bool + storage.read_output(new_location, loop.continue_port), + bool, ) if should_continue is False: + # were done here, set outputs to parent scope for k in loop_outputs: storage.link_outputs(loc, k, new_location, k) storage.mark_node_finished(loc) return WalkResult([], []) - ins = {k: (-1, k) for k in loop.inputs.keys()} + # continue looping, provide the inputs for the next iter from the current + ins = {k: (-1, k) for k in loop.inputs} ins.update(loop_outputs) + # Mark the next node as ready node_run_data = NodeRunData( loc.L(new_location.peek_index() + 1), Eval((-1, BODY_PORT), ins, outputs=loop.outputs), @@ -144,26 +256,56 @@ def walk_loop( def walk_map( - storage: ControllerStorage, parent: Loc, idx: NodeIndex, map: Map + storage: ControllerStorage, + parent: Loc, + idx: NodeIndex, + map_node: Map, ) -> WalkResult: + """Walk a map node. + + We refer to the evaluation of the graph with a set of inputs as map elements. + I. e. one map consist map fun [a,b,c] will have the element fun a, fun b, fun c. 
+ Each of these elements are treated as an virtual eval + at the location PARENT_LOC.M.N-1 that then gets picked up by walk. + In contrast to loop, as all map elements can be immediately started this is set + in the start function, which makes the virtual evals optional; + They are treated differently than evals/loops. + + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + E.g. the "eval" containing this statement. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param map_node: The map node being walked. + :type map_node: Map + :return: A summary list of finished, errored, and ready nodes: + Either empty (map is done), or all intermediate nodes in the map elements. + :rtype: WalkResult + """ loc = parent.N(idx) result = WalkResult([], []) if storage.is_node_finished(loc): return result - first_ref = next(x for x in map.inputs.values() if x[1] == "*") + # find all values to map over + first_ref = next(x for x in map_node.inputs.values() if x[1] == "*") map_eles = outputs_iter(storage, parent.N(first_ref[0])) + # find all map elements that are not done unfinished = [i for i, _ in map_eles if not storage.is_node_finished(loc.M(i))] + # Read the graph def message = storage.read_output(loc.M(0).N(-1), BODY_PORT) g = ptype_from_bytes(message, GraphData) + # Walk all map elements simultaneously [result.extend(walk_node(storage, loc.M(p), g.output_idx(), g)) for p in unfinished] if len(unfinished) > 0: return result - + # All map elements are done, mark the entire map done map_outputs = g.nodes[g.output_idx()].inputs for i, j in map_eles: - for output in map_outputs.keys(): + for output in map_outputs: storage.link_outputs(loc, f"{output}-{j}", loc.M(i), output) storage.mark_node_finished(loc) @@ -176,6 +318,24 @@ def walk_eifelse( idx: NodeIndex, node: EagerIfElse, ) -> WalkResult: + """Walk an eager if else node. 
+ + In an eager if else node we have already evaluated all its + inputs (pred, if_true, if_false). + Therefore we just need to move the correct inputs to its outputs. + + :param storage: The storage to write to. + :type storage: ControllerStorage + :param parent: The parent node of the current node. + E.g. the "eval" containing this statement. + :type parent: Loc + :param idx: The index (respective to the parent) of the current node. + :type idx: NodeIndex + :param node: The eager if else node being walked. + :type node: EagerIfElse + :return: An empty walk result since we here nothing else is to do. + :rtype: WalkResult + """ loc = parent.N(idx) pred = storage.read_output(parent.N(node.pred[0]), node.pred[1]) next_node = node.if_true if pred == b"true" else node.if_false diff --git a/tierkreis/tierkreis/exceptions.py b/tierkreis/tierkreis/exceptions.py index 2e1c8357d..f98859773 100644 --- a/tierkreis/tierkreis/exceptions.py +++ b/tierkreis/tierkreis/exceptions.py @@ -1,2 +1,5 @@ +"""Tierkreis exception definitions.""" + + class TierkreisError(Exception): """An error thrown in the Tierkreis library.""" diff --git a/tierkreis/tierkreis/executor.py b/tierkreis/tierkreis/executor.py index 286c236ef..5cdc6d49e 100644 --- a/tierkreis/tierkreis/executor.py +++ b/tierkreis/tierkreis/executor.py @@ -1,13 +1,15 @@ -from tierkreis.controller.executor.shell_executor import ShellExecutor -from tierkreis.controller.executor.uv_executor import UvExecutor -from tierkreis.controller.executor.multiple import MultipleExecutor +"""Tierkreis executors definitions.""" + from tierkreis.controller.executor.hpc.pjsub import PJSUBExecutor +from tierkreis.controller.executor.multiple import MultipleExecutor +from tierkreis.controller.executor.shell_executor import ShellExecutor from tierkreis.controller.executor.task_executor import TaskExecutor +from tierkreis.controller.executor.uv_executor import UvExecutor __all__ = [ - "ShellExecutor", - "UvExecutor", "MultipleExecutor", 
"PJSUBExecutor", + "ShellExecutor", "TaskExecutor", + "UvExecutor", ] diff --git a/tierkreis/tierkreis/graphs/__init__.py b/tierkreis/tierkreis/graphs/__init__.py index e69de29bb..0b23ab3ca 100644 --- a/tierkreis/tierkreis/graphs/__init__.py +++ b/tierkreis/tierkreis/graphs/__init__.py @@ -0,0 +1 @@ +"""Preconstructed graphs for reuse.""" diff --git a/tierkreis/tierkreis/graphs/fold.py b/tierkreis/tierkreis/graphs/fold.py index 274cabb98..4e1018600 100644 --- a/tierkreis/tierkreis/graphs/fold.py +++ b/tierkreis/tierkreis/graphs/fold.py @@ -1,4 +1,7 @@ -from typing import Generic, NamedTuple, TypeVar +"""Preconstructed graph for folding operations.""" + +from typing import NamedTuple, TypeVar + from tierkreis.builder import GraphBuilder, TypedGraphRef from tierkreis.builtins.stubs import head, igt, tkr_len from tierkreis.controller.data.graph import GraphData @@ -6,25 +9,28 @@ from tierkreis.controller.data.types import PType -class FoldGraphOuterInputs[A: PType, B: PType](NamedTuple): +class _FoldGraphOuterInputs[A: PType, B: PType](NamedTuple): func: TKR[GraphData] accum: TKR[B] values: TKR[list[A]] -class FoldGraphOuterOutputs[A: PType, B: PType](NamedTuple): +class _FoldGraphOuterOutputs[A: PType, B: PType](NamedTuple): accum: TKR[B] values: TKR[list[A]] should_continue: TKR[bool] -class InnerFuncInput[A: PType, B: PType](NamedTuple): +class _InnerFuncInput[A: PType, B: PType](NamedTuple): accum: TKR[B] value: TKR[A] -def _fold_graph_outer[A: PType, B: PType](): - g = GraphBuilder(FoldGraphOuterInputs[A, B], FoldGraphOuterOutputs[A, B]) +def _fold_graph_outer[A: PType, B: PType]() -> GraphBuilder[ + _FoldGraphOuterInputs[A, B], + _FoldGraphOuterOutputs[A, B], +]: + g = GraphBuilder(_FoldGraphOuterInputs[A, B], _FoldGraphOuterOutputs[A, B]) func = g.inputs.func accum = g.inputs.accum @@ -38,41 +44,67 @@ def _fold_graph_outer[A: PType, B: PType](): headed = g.task(head(values)) # Apply the function if we were able to pop off a value. 
- tgd = TypedGraphRef[InnerFuncInput, TKR[B]]( - func.value_ref(), TKR[B], InnerFuncInput + tgd = TypedGraphRef[_InnerFuncInput, TKR[B]]( + func.value_ref(), + TKR[B], + _InnerFuncInput, ) - applied_next = g.eval(tgd, InnerFuncInput(accum, headed.head)) + applied_next = g.eval(tgd, _InnerFuncInput(accum, headed.head)) next_accum = g.ifelse(non_empty, applied_next, accum) next_values = g.ifelse(non_empty, headed.rest, values) - g.outputs(FoldGraphOuterOutputs(next_accum, next_values, non_empty)) + g.outputs(_FoldGraphOuterOutputs(next_accum, next_values, non_empty)) return g -A = TypeVar("A", bound=PType, covariant=True) -B = TypeVar("B", bound=PType, covariant=True) +A_co = TypeVar("A_co", bound=PType, covariant=True) +B_co = TypeVar("B_co", bound=PType, covariant=True) + +class FoldGraphInputs[A: PType, B: PType](NamedTuple): + """Inputs to a fold graph. + + :fields: + initial (B): The initial value. + values (list[A]): The list of values to fold over. + """ -class FoldGraphInputs(NamedTuple, Generic[A, B]): initial: TKR[B] values: TKR[list[A]] -class FoldFunctionInput(NamedTuple, Generic[A, B]): +class FoldFunctionInput[A: PType, B: PType](NamedTuple): + """Input type of a fold function. + + :fields: + accum (B): The accumulator. + value (A): The current value. + """ + accum: TKR[B] value: TKR[A] -# fold : {func: (b -> a -> b)} -> {initial: b} -> {values: list[a]} -> {value: b} -# fold : { A x B -> B } -> { list[A] x B -> B } -def fold_graph( - func: GraphBuilder[FoldFunctionInput[A, B], TKR[B]], -) -> GraphBuilder[FoldGraphInputs[A, B], TKR[B]]: - g = GraphBuilder(FoldGraphInputs[A, B], TKR[B]) - foldfunc = g._graph_const(func) - # TODO: include the computation inside the fold - ins = FoldGraphOuterInputs( - TKR(*foldfunc.graph_ref), g.inputs.initial, g.inputs.values +def fold_graph[A_co: PType, B_co: PType]( + func: GraphBuilder[FoldFunctionInput[A_co, B_co], TKR[B_co]], +) -> GraphBuilder[FoldGraphInputs[A_co, B_co], TKR[B_co]]: + """Construct a fold graph. 
+ + fold : {func: (b -> a -> b)} -> {initial: b} -> {values: list[a]} -> {value: b} + fold : { A x B -> B } -> { list[A] x B -> B } + + :param func: The function to fold over. + :type func: GraphBuilder[FoldFunctionInput[A_co, B_co], TKR[B_co]] + :return: A graph implementing the fold function. + :rtype: GraphBuilder[FoldGraphInputs[A_co, B_co], TKR[B_co]] + """ + g = GraphBuilder(FoldGraphInputs[A_co, B_co], TKR[B_co]) + foldfunc = g._graph_const(func) # noqa: SLF001 + # TODO @mwpb: include the computation inside the fold + ins = _FoldGraphOuterInputs( + TKR(*foldfunc.graph_ref), + g.inputs.initial, + g.inputs.values, ) loop = g.loop(_fold_graph_outer(), ins) g.outputs(loop.accum) diff --git a/tierkreis/tierkreis/graphs/nexus/__init__.py b/tierkreis/tierkreis/graphs/nexus/__init__.py new file mode 100644 index 000000000..0a94d4f45 --- /dev/null +++ b/tierkreis/tierkreis/graphs/nexus/__init__.py @@ -0,0 +1 @@ +"""Nexus graphs.""" diff --git a/tierkreis/tierkreis/graphs/nexus/submit_poll.py b/tierkreis/tierkreis/graphs/nexus/submit_poll.py index 622b228b6..0d779bed0 100644 --- a/tierkreis/tierkreis/graphs/nexus/submit_poll.py +++ b/tierkreis/tierkreis/graphs/nexus/submit_poll.py @@ -1,27 +1,46 @@ +"""Sample graphs to interact with nexus using the Nexus Worker.""" + # ruff: noqa: F821 from typing import NamedTuple + from tierkreis.builder import GraphBuilder from tierkreis.builtins.stubs import tkr_sleep from tierkreis.controller.data.models import TKR, OpaqueType from tierkreis.nexus_worker import ( - upload_circuit, - start_execute_job, - is_running, get_results, + is_running, + start_execute_job, + upload_circuit, ) -type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] +type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] # noqa: SLF001 type BackendResult = OpaqueType["pytket.backends.backendresult.BackendResult"] type ExecuteJobRef = OpaqueType["qnexus.models.references.ExecuteJobRef"] -type ExecutionProgram = 
OpaqueType["qnexus.models.references.ExecuteJobRef"] class UploadCircuitInputs(NamedTuple): + """The inputs to upload a circuit. + + :fields: + project_name (str): The name of the project to upload to. + circuit (Circuit): The tket circuit to upload. + """ + project_name: TKR[str] circuit: TKR[Circuit] class JobInputs(NamedTuple): + """The inputs to a nexus job. + + :fields: + project_name (str): The name of the project to upload to. + job_name (str): The name of the job. + circuit (list[Circuit]): The list of circuits part of the job. + n_shots (int): The number of shots (repetitions) of each circuit. + backend_config (BackendConfig): The qnexus configuration of the backend. + """ + project_name: TKR[str] job_name: TKR[str] circuits: TKR[list[Circuit]] @@ -29,33 +48,54 @@ class JobInputs(NamedTuple): backend_config: TKR[OpaqueType["qnexus.BackendConfig"]] -class LoopOutputs(NamedTuple): +class _LoopOutputs(NamedTuple): results: TKR[list[BackendResult]] should_continue: TKR[bool] -def upload_circuit_graph(): - g = GraphBuilder(UploadCircuitInputs, TKR[ExecutionProgram]) +def upload_circuit_graph() -> GraphBuilder[UploadCircuitInputs, TKR[ExecuteJobRef]]: + """Construct a graph to upload a circuit to nexus. + + :return: A uploading graph. 
+ :rtype: GraphBuilder[UploadCircuitInputs, TKR[ExecuteJobRef]] + """ + g = GraphBuilder(UploadCircuitInputs, TKR[ExecuteJobRef]) programme = g.task(upload_circuit(g.inputs.project_name, g.inputs.circuit)) - g.outputs(programme) # type: ignore + g.outputs(programme) # type: ignore[arg-type] return g -def polling_loop_body(polling_interval: float): - g = GraphBuilder(TKR[ExecuteJobRef], LoopOutputs) +def _polling_loop_body( + polling_interval: float, +) -> GraphBuilder[TKR[ExecuteJobRef], _LoopOutputs]: + g = GraphBuilder(TKR[ExecuteJobRef], _LoopOutputs) pred = g.task(is_running(g.inputs)) - wait = g.ifelse(pred, g.task(tkr_sleep(g.const(polling_interval))), g.const(False)) + wait = g.ifelse( + pred, + g.task(tkr_sleep(g.const(polling_interval))), + g.const(value=False), + ) results = g.ifelse(pred, g.const([]), g.task(get_results(g.inputs))) - g.outputs(LoopOutputs(results=results, should_continue=wait)) + g.outputs(_LoopOutputs(results=results, should_continue=wait)) return g -def nexus_submit_and_poll(polling_interval: float = 30.0): +def nexus_submit_and_poll( + polling_interval: float = 30.0, +) -> GraphBuilder[JobInputs, TKR[list[BackendResult]]]: + """Construct a graph submitting and polling a nexus job. + + :param polling_interval: The polling interval in seconds, defaults to 30.0 + :type polling_interval: float, optional + :return: A graph performing submission and polling. 
+ :rtype: GraphBuilder[JobInputs, TKR[list[BackendResult]]] + """ g = GraphBuilder(JobInputs, TKR[list[BackendResult]]) upload_inputs = g.map( - lambda x: UploadCircuitInputs(g.inputs.project_name, x), g.inputs.circuits + lambda x: UploadCircuitInputs(g.inputs.project_name, x), + g.inputs.circuits, ) programmes = g.map(upload_circuit_graph(), upload_inputs) @@ -63,12 +103,12 @@ def nexus_submit_and_poll(polling_interval: float = 30.0): start_execute_job( g.inputs.project_name, g.inputs.job_name, - programmes, # type: ignore + programmes, # type: ignore[arg-type] g.inputs.n_shots, - g.inputs.backend_config, # type: ignore - ) + g.inputs.backend_config, # type: ignore[arg-type] + ), ) - res = g.loop(polling_loop_body(polling_interval), ref) + res = g.loop(_polling_loop_body(polling_interval), ref) g.outputs(res.results) return g diff --git a/tierkreis/tierkreis/graphs/simulate/__init__.py b/tierkreis/tierkreis/graphs/simulate/__init__.py index e69de29bb..9647a1a74 100644 --- a/tierkreis/tierkreis/graphs/simulate/__init__.py +++ b/tierkreis/tierkreis/graphs/simulate/__init__.py @@ -0,0 +1 @@ +"""Simulation Graphs.""" diff --git a/tierkreis/tierkreis/graphs/simulate/compile_simulate.py b/tierkreis/tierkreis/graphs/simulate/compile_simulate.py index 77e9d003d..bddf0ff36 100644 --- a/tierkreis/tierkreis/graphs/simulate/compile_simulate.py +++ b/tierkreis/tierkreis/graphs/simulate/compile_simulate.py @@ -1,23 +1,38 @@ +"""Sample graphs to simulate quantum circuits on different backends.""" + # ruff: noqa: F821 from typing import Literal, NamedTuple -from tierkreis.builder import GraphBuilder -from tierkreis.controller.data.models import TKR, OpaqueType -from tierkreis.builtins.stubs import tkr_zip, untuple + from tierkreis.aer_worker import ( get_compiled_circuit as aer_compile, +) +from tierkreis.aer_worker import ( run_circuit as aer_run, ) +from tierkreis.builder import GraphBuilder +from tierkreis.builtins.stubs import str_eq, tkr_zip, untuple +from 
tierkreis.controller.data.models import TKR, OpaqueType from tierkreis.qulacs_worker import ( get_compiled_circuit as qulacs_compile, +) +from tierkreis.qulacs_worker import ( run_circuit as qulacs_run, ) -from tierkreis.builtins.stubs import str_eq type BackendResult = OpaqueType["pytket.backends.backendresult.BackendResult"] -type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] +type Circuit = OpaqueType["pytket._tket.circuit.Circuit"] # noqa: SLF001 class SimulateJobInputs(NamedTuple): + """Input to simulate multiple quantum circuits on a local backend. + + :fields: + simulator_name (Literal): either 'aer' or 'qulacs'. + circuits (list[Ciruit]): The list of circuits to run. + n_shots (int): The number of shots (repetitions) of each circuit. + compilation_optimisation_level (int): tket optimisation level in [0,1,2,3]. + """ + simulator_name: TKR[Literal["aer", "qulacs"]] circuits: TKR[list[Circuit]] n_shots: TKR[list[int]] @@ -25,12 +40,27 @@ class SimulateJobInputs(NamedTuple): class SimulateJobInputsSingle(NamedTuple): + """Input to simulate multiple quantum circuits on a local backend. + + :fields: + simulator_name (Literal): either 'aer' or 'qulacs'. + circuit_shots (tuple[Ciruit, int]): The circuits to run and the number of shots. + compilation_optimisation_level (int): tket optimisation level in [0,1,2,3]. + """ + simulator_name: TKR[Literal["aer", "qulacs"]] circuit_shots: TKR[tuple[Circuit, int]] compilation_optimisation_level: TKR[int] -def aer_simulate_single(): +def aer_simulate_single() -> GraphBuilder[SimulateJobInputsSingle, TKR[BackendResult]]: + """Construct a graph to simulate a single circuit using qiskit aer. + + This ignores the simulator_name field. + + :return: The graph for the simulation. 
+ :rtype: GraphBuilder[SimulateJobInputsSingle, TKR[BackendResult]] + """ g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult]) circuit_shots = g.task(untuple(g.inputs.circuit_shots)) @@ -38,14 +68,24 @@ def aer_simulate_single(): aer_compile( circuit=circuit_shots.a, optimisation_level=g.inputs.compilation_optimisation_level, - ) + ), ) res = g.task(aer_run(compiled_circuit, circuit_shots.b)) g.outputs(res) return g -def qulacs_simulate_single(): +def qulacs_simulate_single() -> GraphBuilder[ + SimulateJobInputsSingle, + TKR[BackendResult], +]: + """Construct a graph to simulate a single circuit using qulacs. + + This ignores the simulator_name field. + + :return: The graph for the simulation. + :rtype: GraphBuilder[SimulateJobInputsSingle, TKR[BackendResult]] + """ g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult]) circuit_shots = g.task(untuple(g.inputs.circuit_shots)) @@ -53,27 +93,42 @@ def qulacs_simulate_single(): qulacs_compile( circuit=circuit_shots.a, optimisation_level=g.inputs.compilation_optimisation_level, - ) + ), ) res = g.task(qulacs_run(compiled_circuit, circuit_shots.b)) g.outputs(res) return g -def compile_simulate_single(): +def compile_simulate_single() -> GraphBuilder[ + SimulateJobInputsSingle, + TKR[BackendResult], +]: + """CConstruct a graph to simulate a single job on either aer or qulacs. + + :return: The graph for the simulation. 
+ :rtype: GraphBuilder[ SimulateJobInputsSingle, TKR[BackendResult], ] + """ g = GraphBuilder(SimulateJobInputsSingle, TKR[BackendResult]) aer_res = g.eval(aer_simulate_single(), g.inputs) qulacs_res = g.eval(qulacs_simulate_single(), g.inputs) res = g.ifelse( - g.task(str_eq(g.inputs.simulator_name, g.const("aer"))), aer_res, qulacs_res + g.task(str_eq(g.inputs.simulator_name, g.const("aer"))), + aer_res, + qulacs_res, ) g.outputs(res) return g -def compile_simulate(): +def compile_simulate() -> GraphBuilder[SimulateJobInputs, TKR[list[BackendResult]]]: + """Construct a graph to simulate multiple jobs on either aer or qulacs. + + :return: The graph for the simulation. + :rtype: GraphBuilder[SimulateJobInputs, TKR[list[BackendResult]]] + """ g = GraphBuilder(SimulateJobInputs, TKR[list[BackendResult]]) circuits_shots = g.task(tkr_zip(g.inputs.circuits, g.inputs.n_shots)) diff --git a/tierkreis/tierkreis/hpc.py b/tierkreis/tierkreis/hpc.py index dbb60488d..4839a951b 100644 --- a/tierkreis/tierkreis/hpc.py +++ b/tierkreis/tierkreis/hpc.py @@ -1,3 +1,5 @@ -from tierkreis.controller.executor.hpc.job_spec import JobSpec, ResourceSpec, MpiSpec +"""Tierkreis HPC utilities.""" -__all__ = ["JobSpec", "ResourceSpec", "MpiSpec"] +from tierkreis.controller.executor.hpc.job_spec import JobSpec, MpiSpec, ResourceSpec + +__all__ = ["JobSpec", "MpiSpec", "ResourceSpec"] diff --git a/tierkreis/tierkreis/idl/__init__.py b/tierkreis/tierkreis/idl/__init__.py new file mode 100644 index 000000000..728c8171a --- /dev/null +++ b/tierkreis/tierkreis/idl/__init__.py @@ -0,0 +1 @@ +"""Parsing for the Tierkreis typespec for external workers.""" diff --git a/tierkreis/tierkreis/idl/models.py b/tierkreis/tierkreis/idl/models.py index 8bce83804..66e167e7a 100644 --- a/tierkreis/tierkreis/idl/models.py +++ b/tierkreis/tierkreis/idl/models.py @@ -1,31 +1,43 @@ +"""Tierkreis IDL models representation used for TSP parsing.""" + +from collections.abc import Mapping, Sequence from dataclasses import 
dataclass from types import NoneType -from typing import Annotated, Mapping, Self, Sequence, get_args, get_origin +from typing import Annotated, Self, get_args, get_origin from tierkreis.controller.data.core import RestrictedNamedTuple from tierkreis.controller.data.types import _is_generic - type ElementaryType = ( - type[int] - | type[float] - | type[bytes] - | type[str] - | type[bool] - | type[NoneType] - | type[Mapping] - | type[Sequence] + type[int | float | bytes | str | bool | NoneType | Mapping | Sequence] | str # Custom type e.g. forward reference ) @dataclass class GenericType: + """A Tierkreis worker generic type. + + Represents a single (composed) type in worker definitions. + + :fields: + origin (ElementaryType): The base type, e.g., str in list[str]. + args: (Sequence[GenericType | str]) The nested types. + e.g., list[str] in set[list[str]] + """ + origin: ElementaryType args: "Sequence[GenericType | str]" @classmethod def from_type(cls, t: type) -> "Self": + """Construct a generic type from a python type. + + :param t: The python type. + :type t: type + :return: The Tierkreis type. + :rtype: Self + """ if get_origin(t) is Annotated: return cls.from_type(get_args(t)[0]) @@ -44,19 +56,49 @@ def _included_structs(cls, t: "GenericType") -> "set[GenericType]": return outs def included_structs(self) -> "set[GenericType]": + """Find the included structs of this type. + + A struct is an instance of RestrictedNamedTuple or opaque strings. + :return: The list of structs + :rtype: set[GenericType] + """ return GenericType._included_structs(self) def __hash__(self) -> int: + """Produce a hash of the generic type. + + :return: The hash. + :rtype: int + """ return hash(self.origin) def __eq__(self, value: object) -> bool: + """Check the equality of self with an object. + + self == value <==> self.origin == value.origin + + + :param other: The object to compare to. + :type other: object + :return: If bothe object have the same origin. 
+ :rtype: bool + """ if not hasattr(value, "origin"): return False - return self.origin == getattr(value, "origin") + return self.origin == value.origin @dataclass class TypedArg: + """A Tierkreis worker method argument. + + Represents a single argument to a tasks in a worker. + :fields: + name (str): The argument name. + t (GenericType): The argument type. + has_default(bool): Whether the argument has a default value. + """ + name: str t: GenericType has_default: bool = False @@ -64,6 +106,17 @@ class TypedArg: @dataclass class Method: + """A Tierkreis worker method. + + Represents a tasks in a worker. + + :fields: + name (str): The method name. + args (list[TypedArg]): The list of method arguments. + return_type (GenericType): The method return type. + return_type_is_portmapping (bool): Whether the return_type is a portmapping. + """ + name: GenericType args: list[TypedArg] return_type: GenericType @@ -72,18 +125,51 @@ class Method: @dataclass class Interface: + """A Tierkreis worker interface. + + Represents a list of tasks contained in the worker. + + :fields: + name (str): The worker name. + methods (list[Method]): The available tasks in the worker. + """ + name: str methods: list[Method] @dataclass class Model: + """A Tierkreis worker model. + + Represents a type in a worker. + + :fields: + is_portmapping (bool): Whether the model is a portmapping. + t (GenericType): The type of the model. + decl (list[TypedArg]) The list of its typed arguments. + """ + is_portmapping: bool t: GenericType decls: list[TypedArg] def __hash__(self) -> int: + """Produce a hash of the model. + + :return: The hash. + :rtype: int + """ return hash(self.t.origin) - def __lt__(self, other: "Model"): + def __lt__(self, other: "Model") -> bool: + """Check order of two models. + + Uses lexicographical ordering of the origin of the models (generic) types. + + :param other: The model to compare to. + :type other: Model + :return: If self comes before other. 
+ :rtype: bool + """ return str(self.t.origin) < str(other.t.origin) diff --git a/tierkreis/tierkreis/idl/parser.py b/tierkreis/tierkreis/idl/parser.py index f6e76ee02..3e5276629 100644 --- a/tierkreis/tierkreis/idl/parser.py +++ b/tierkreis/tierkreis/idl/parser.py @@ -5,31 +5,51 @@ doesn't type check things correctly. """ +from __future__ import annotations + +import contextlib import re -from typing import Callable, Never, overload +from typing import TYPE_CHECKING, Any, Never, overload +if TYPE_CHECKING: + from collections.abc import Callable from tierkreis.exceptions import TierkreisError -class ParserError(TierkreisError): ... +class ParserError(TierkreisError): + """An error raised when parsing fails in Tierkreis.""" class Parser[T]: + """A parser for an arbitrary type in tierkreis. + + :fields: + fn: The parsing function. + """ + fn: Callable[[str], tuple[T, str]] - def __init__(self, fn: Callable[[str], tuple[T, str]]): + def __init__(self, fn: Callable[[str], tuple[T, str]]) -> None: self.fn = fn def __call__(self, ins: str) -> tuple[T, str]: + """Call the parses on a string. + + :param ins: The string to parse. + :type ins: str + :return: The parsed string and its type. 
+ :rtype: tuple[T, str] + """ ins = ins.strip() return self.fn(ins) def __or__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[T|S]": + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[T | S]: """Try the left parser and only if it fails try the right parser.""" - def f(ins: str): + def f(ins: str) -> tuple[T, str] | tuple[S, str]: try: return self(ins) except ParserError: @@ -38,11 +58,12 @@ def f(ins: str): return Parser(f) def __and__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[tuple[T,S]]": + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[tuple[T, S]]: """Use the left parser and then use the right parser on the remaining input.""" - def f(ins: str): + def f(ins: str) -> tuple[tuple[T, S], str]: s, remaining = self(ins) t, remaining = other(remaining) return (s, t), remaining @@ -50,11 +71,16 @@ def f(ins: str): return Parser(f) def __lshift__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[T]": - """Use the left parser and then the right parser but discard the result of the right parser.""" + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[T]: + """Leftshift parsers. + + Use the left parser and then the right parser + but discard the result of the right parser. + """ - def f(ins: str): + def f(ins: str) -> tuple[T, str]: t, remaining = self(ins) _, remaining = other(remaining) return t, remaining @@ -62,21 +88,26 @@ def f(ins: str): return Parser(f) def __rshift__[S]( - self, other: "Parser[S]" | Callable[[str], tuple[S, str]] - ) -> "Parser[S]": - """Use the left parser and then the right parser but discard the result of the left parser.""" + self, + other: Parser[S] | Callable[[str], tuple[S, str]], + ) -> Parser[S]: + """Rightshift parsers. - def f(ins: str): + Use the left parser and then the right parser + but discard the result of the left parser. 
+ """ + + def f(ins: str) -> tuple[S, str]: _, remaining = self(ins) s, remaining = other(remaining) return s, remaining return Parser(f) - def opt(self) -> "Parser[T|None]": + def opt(self) -> Parser[T | None]: """Make the parser optional; if it fails then return None and carry on.""" - def f(ins: str): + def f(ins: str) -> tuple[T, str] | tuple[None, str]: try: return self(ins) except ParserError: @@ -84,41 +115,41 @@ def f(ins: str): return Parser(f) - def map[A](self, fn: Callable[[T], A]) -> "Parser[A]": + def map[A](self, fn: Callable[[T], A]) -> Parser[A]: """Apply `fn` to transform the output of the parser.""" - def f(ins: str): + def f(ins: str) -> tuple[A, str]: t, remaining = self(ins) return fn(t), remaining return Parser(f) - def coerce[A](self, a: A) -> "Parser[A]": + def coerce[A](self, a: A) -> Parser[A]: """Shorthand for maps that don't need an argument. - Not strictly speaking required.""" + Not strictly speaking required. + """ - def f(ins: str): - t, remaining = self(ins) + def f(ins: str) -> tuple[A, str]: + _t, remaining = self(ins) return a, remaining return Parser(f) - def rep(self, sep: "Parser[str] | None" = None) -> "Parser[list[T]]": + def rep(self, sep: Parser[str] | None = None) -> Parser[list[T]]: """Repeatedly apply a parser with an optional separator. - The results of the separator parser are discarded.""" + The results of the separator parser are discarded. + """ - def f(ins: str): + def f(ins: str) -> tuple[list[T], str]: outs: list[T] = [] while True: try: t, ins = self(ins) if sep: - try: + with contextlib.suppress(ParserError): _, ins = sep(ins) - except ParserError: - pass outs.append(t) except ParserError: break @@ -126,14 +157,16 @@ def f(ins: str): return Parser(f) - def fail(self, entity: str) -> "Parser[Never]": + def fail(self, entity: str) -> Parser[Never]: """Fail early if we find something we don't support. - Not strictly speaking required.""" + Not strictly speaking required. 
+ """ - def f(ins: str): + def f(ins: str) -> Never: self(ins) - raise TierkreisError(f"{entity} not supported.") + msg = f"{entity} not supported." + raise TierkreisError(msg) return Parser(f) @@ -153,10 +186,12 @@ def seq[A, B, C, D, E]( *args: *tuple[Parser[A], Parser[B], Parser[C], Parser[D], Parser[E]], ) -> Parser[tuple[A, B, C, D, E]]: ... def seq(*args: Parser) -> Parser[tuple]: - """Run a sequence of parsers one after the other - and collect their outputs in a tuple.""" + """Run a sequence of parsers. - def f(ins: str): + Runs parsers one after the other and collect their outputs in a tuple. + """ + + def f(ins: str) -> tuple[tuple[Any, ...], str]: outs = [] for arg in args: s, ins = arg(ins) @@ -167,31 +202,38 @@ def f(ins: str): def lit(*args: str) -> Parser[str]: - """If the input starts with one of the strings in `args` - then take the string off the stream and return it.""" + """Match literal strings at the start of stream and remove them. + + If the input starts with one of the strings in `args` + then take the string off the stream and return it. + """ - def f(ins: str): + def f(ins: str) -> tuple[str, str]: for a in args: if ins.startswith(a): return a, ins[len(a) :] - raise ParserError(f"lit: expected {args} found '{ins[:20]}'") + msg = f"lit: expected {args} found '{ins[:20]}'" + raise ParserError(msg) return Parser(f) def reg(regex: str) -> Parser[str]: - """If start of the input matches the `regex` - then take the matching text off the stream and return it. + """Match a regex against the start of stream and remove it. - Please don't pass match groups within the regex; they will be taken care of.""" + If start of the input matches the `regex` then take the matching text off + the stream and return it. + Please don't pass match groups within the regex; they will be taken care of. 
+ """ - def f(ins: str): + def f(ins: str) -> tuple[str, str]: r = re.compile("^(" + regex + ")") if a := r.match(ins): return a.group(0), ins[a.end() :] - raise ParserError(f"reg: expected regex {regex} found '{ins[:20]}'") + msg = f"reg: expected regex {regex} found '{ins[:20]}'" + raise ParserError(msg) return Parser(f) diff --git a/tierkreis/tierkreis/idl/spec.py b/tierkreis/tierkreis/idl/spec.py index 9faee5b9c..c9f7c5f22 100644 --- a/tierkreis/tierkreis/idl/spec.py +++ b/tierkreis/tierkreis/idl/spec.py @@ -7,11 +7,9 @@ """ from tierkreis.idl.models import Interface, Method, Model, TypedArg - from tierkreis.idl.parser import lit, seq from tierkreis.idl.type_symbols import generic_t, ident, type_symbol - type_decl = ((ident << lit(":")) & type_symbol).map(lambda x: TypedArg(*x)) model = seq( lit("@portmapping").opt().map(lambda x: x is not None) << lit("model"), diff --git a/tierkreis/tierkreis/idl/type_symbols.py b/tierkreis/tierkreis/idl/type_symbols.py index ace5f3669..ca1c4a477 100644 --- a/tierkreis/tierkreis/idl/type_symbols.py +++ b/tierkreis/tierkreis/idl/type_symbols.py @@ -5,12 +5,10 @@ from types import NoneType from typing import ForwardRef + from tierkreis.idl.models import GenericType from tierkreis.idl.parser import Parser, lit, reg, seq -type _TypeT = type | ForwardRef - - signed_int = lit("integer", "int64", "int32", "int16", "int8", "safeint") unsigned_int = lit("uint64", "uint32", "uint16", "uint8") integer_t = (signed_int | unsigned_int).coerce(GenericType(int, [])) @@ -30,19 +28,40 @@ def array_t(ins: str) -> tuple[GenericType, str]: + """Parse a array generic type. + + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return (lit("Array<") >> type_symbol << lit(">")).map( - lambda x: GenericType(list, [x]) + lambda x: GenericType(list, [x]), )(ins) def record_t(ins: str) -> tuple[GenericType, str]: + """Parse a record generic type. 
+ + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return (lit("Record<") >> type_symbol << lit(">")).map( - lambda x: GenericType(dict, [GenericType(str, []), x]) + lambda x: GenericType(dict, [GenericType(str, []), x]), )(ins) @Parser def generic_t(ins: str) -> tuple[GenericType, str]: + """Parse a generic type. + + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return seq( ident, (lit("<") >> ident.rep(lit(",")) << lit(">")).opt().map(lambda x: x or []), @@ -51,6 +70,15 @@ def generic_t(ins: str) -> tuple[GenericType, str]: @Parser def type_symbol(ins: str) -> tuple[GenericType, str]: + """Parse a regular type symbol. + + E.g. int, float, ... + + :param ins: The string to parse. + :type ins: str + :return: The parsed type and its string representation. + :rtype: tuple[GenericType, str] + """ return ( integer_t | float_t diff --git a/tierkreis/tierkreis/labels.py b/tierkreis/tierkreis/labels.py index 0a6cf0474..c93513a35 100644 --- a/tierkreis/tierkreis/labels.py +++ b/tierkreis/tierkreis/labels.py @@ -4,8 +4,9 @@ class Labels: """Special port labels used by builtin functions.""" - def __init__(self): - raise RuntimeError("Do not instantiate") + def __init__(self) -> None: + msg = "Do not instantiate" + raise RuntimeError(msg) THUNK = "thunk" VALUE = "value" diff --git a/tierkreis/tierkreis/logger_setup.py b/tierkreis/tierkreis/logger_setup.py index 39c851902..228800c58 100644 --- a/tierkreis/tierkreis/logger_setup.py +++ b/tierkreis/tierkreis/logger_setup.py @@ -1,7 +1,9 @@ +"""Sets up the Tierkreis logger.""" + import logging +import sys from os import getenv from pathlib import Path -import sys from tierkreis.consts import TKR_DATE_FMT_KEY, TKR_LOG_FMT_KEY, TKR_LOG_LEVEL_KEY @@ -12,6 +14,15 @@ def set_tkr_logger( file_name: Path, level: int | str = logging.INFO, 
) -> None: + """Set up the 'tierkreis' logger. + + Adds a filehandler for use in the controller. + + :param file_name: The file to use for the logging. + :type file_name: Path + :param level: The log level, defaults to logging.INFO + :type level: int | str, optional + """ logger = logging.getLogger(LOGGER_NAME) if logger.hasHandlers(): [logger.removeHandler(h) for h in logger.handlers] @@ -23,10 +34,23 @@ def set_tkr_logger( logger.addHandler(handler) except FileNotFoundError: - logging.warning("Could not log to file, logging to std out instead.") + root_logger = logging.getLogger() + root_logger.warning("Could not log to file, logging to std out instead.") def add_handler_from_environment(logger: logging.Logger) -> logging.Handler: + """Add a handler to a logger from TKR env variables. + + Adds a stream handler on stderr with log level, format and date format + taken from the environment variables $TKR_LOG_LEVEL, $TKR_LOG_FMT and + $TKR_DATE_FORMAT. + Returns the created handler so it can be removed later if needed. + + :param logger: The logger to add the handler to. + :type logger: logging.Logger + :return: The created handler. 
+ :rtype: logging.Handler + """ log_level = getenv(TKR_LOG_LEVEL_KEY, logging.INFO) if log_level is not None: logger.setLevel(log_level) diff --git a/tierkreis/tierkreis/models.py b/tierkreis/tierkreis/models.py index 1405f4468..326fbb76c 100644 --- a/tierkreis/tierkreis/models.py +++ b/tierkreis/tierkreis/models.py @@ -1,5 +1,7 @@ +"""Tierkreis models for graph builder definitions.""" + from tierkreis.controller.data.core import EmptyModel from tierkreis.controller.data.models import TKR, portmapping from tierkreis.controller.data.types import Struct -__all__ = ["EmptyModel", "TKR", "portmapping", "Struct"] +__all__ = ["TKR", "EmptyModel", "Struct", "portmapping"] diff --git a/tierkreis/tierkreis/namespace.py b/tierkreis/tierkreis/namespace.py index b966572bd..550106686 100644 --- a/tierkreis/tierkreis/namespace.py +++ b/tierkreis/tierkreis/namespace.py @@ -1,17 +1,20 @@ +"""Namespace for a tierkreis worker.""" + +import shutil +import subprocess +from collections.abc import Callable from dataclasses import dataclass, field from inspect import Signature, signature from logging import getLogger from pathlib import Path -import shutil -import subprocess -from typing import Callable, Self +from typing import Self + from tierkreis.codegen import format_method, format_model from tierkreis.controller.data.models import PModel, is_portmapping from tierkreis.controller.data.types import Struct, has_default, is_ptype from tierkreis.exceptions import TierkreisError -from tierkreis.idl.spec import spec from tierkreis.idl.models import GenericType, Interface, Method, Model, TypedArg - +from tierkreis.idl.spec import spec logger = getLogger(__name__) WorkerFunction = Callable[..., PModel] @@ -19,21 +22,36 @@ @dataclass class Namespace: + """The namespace of a worker. + + attr name: The name of the namespace. + attr methods: The methods in the namespace. + attr models: The models in the namespace. 
+ """ + name: str - methods: list[Method] = field(default_factory=lambda: []) - models: set[Model] = field(default_factory=lambda: set()) + methods: list[Method] = field(default_factory=list) + models: set[Model] = field(default_factory=set) + + def add_struct(self, generic_type: GenericType) -> None: + """Add a struct to the namespace. - def add_struct(self, gt: GenericType) -> None: - if not isinstance(gt.origin, Struct) or Model(False, gt, []) in self.models: + :param generic_type: The generic type to add. + :type generic_type: GenericType + """ + if ( + not isinstance(generic_type.origin, Struct) + or Model(is_portmapping=False, t=generic_type, decls=[]) in self.models + ): return - annotations = gt.origin.__annotations__ + annotations = generic_type.origin.__annotations__ decls = [TypedArg(k, GenericType.from_type(x)) for k, x in annotations.items()] for decl in decls: [self.add_struct(g) for g in decl.t.included_structs()] - portmapping_flag = True if is_portmapping(gt.origin) else False - model = Model(portmapping_flag, gt, decls) + portmapping_flag = bool(is_portmapping(generic_type.origin)) + model = Model(portmapping_flag, generic_type, decls) self.models.add(model) @staticmethod @@ -41,15 +59,22 @@ def _validate_signature(func: WorkerFunction) -> Signature: sig = signature(func) for param in sig.parameters.values(): if not is_ptype(param.annotation): - raise TierkreisError(f"Expected PType got {param.annotation}") + msg = f"Expected PType got {param.annotation}" + raise TierkreisError(msg) out = sig.return_annotation if not is_portmapping(out) and not is_ptype(out) and out is not None: - raise TierkreisError(f"Expected PModel found {out}") + msg = f"Expected PModel found {out}" + raise TierkreisError(msg) return sig def add_function(self, func: WorkerFunction) -> None: + """Add a function to the namespace. + + :param func: The function to add. 
+ :type func: WorkerFunction + """ sig = self._validate_signature(func) method = Method( @@ -63,12 +88,22 @@ def add_function(self, func: WorkerFunction) -> None: ) self.methods.append(method) - for t in func.__annotations__.values(): - [self.add_struct(x) for x in GenericType.from_type(t).included_structs()] + for annotation_type in func.__annotations__.values(): + [ + self.add_struct(struct) + for struct in GenericType.from_type(annotation_type).included_structs() + ] @classmethod def from_spec_file(cls, path: Path) -> "Namespace": - with open(path) as fh: + """Generate a Namespace from a tsp spec file. + + :param path: The path to the spec file. + :type path: Path + :return: The generated namespace. + :rtype: Namespace + """ + with Path.open(path) as fh: namespace_spec = spec(fh.read()) return cls._from_spec(namespace_spec[0]) @@ -77,18 +112,23 @@ def _from_spec(cls, args: tuple[list[Model], Interface]) -> "Self": models = args[0] interface = args[1] namespace = cls(interface.name, models=set(models)) - for f in interface.methods: - model = next((x for x in models if x.t == f.return_type), None) + for method in interface.methods: + model = next((x for x in models if x.t == method.return_type), None) if model is not None: - f.return_type_is_portmapping = model.is_portmapping - namespace.methods.append(f) + method.return_type_is_portmapping = model.is_portmapping + namespace.methods.append(method) return namespace def stubs(self) -> str: - functions = [format_method(self.name, f) for f in self.methods] + """Generate type stubs strings for the namespace. + + :return: The generated stubs as string. + :rtype: str + """ + functions = [format_method(self.name, method) for method in self.methods] functions_str = "\n\n".join(functions) - models_str = "\n\n".join([format_model(x) for x in sorted(list(self.models))]) + models_str = "\n\n".join([format_model(model) for model in sorted(self.models)]) return f'''"""Code generated from {self.name} namespace. 
Please do not edit.""" @@ -103,17 +143,28 @@ def stubs(self) -> str: ''' def write_stubs(self, stubs_path: Path) -> None: - """Writes the type stubs to stubs_path. + """Write the type stubs to stubs_path. :param stubs_path: The location to write to. :type stubs_path: Path """ - with open(stubs_path, "w+") as fh: + with Path.open(stubs_path, "w+") as fh: fh.write(self.stubs()) ruff_binary = shutil.which("ruff") if ruff_binary: - subprocess.run([ruff_binary, "format", stubs_path]) - subprocess.run([ruff_binary, "check", "--fix", stubs_path]) + subprocess.run([ruff_binary, "format", stubs_path], check=False) + subprocess.run( + [ + ruff_binary, + "check", + "--fix", + "--ignore", + "D,N801,UP007", + "--unsafe-fixes", + stubs_path, + ], + check=False, + ) else: logger.warning("No ruff binary found. Stubs will contain raw codegen.") diff --git a/tierkreis/tierkreis/pkg/__init__.py b/tierkreis/tierkreis/pkg/__init__.py index 6173137d9..69cda6e35 100644 --- a/tierkreis/tierkreis/pkg/__init__.py +++ b/tierkreis/tierkreis/pkg/__init__.py @@ -1,15 +1,16 @@ """Manage dependencies for a Tierkreis project in the standard directory layout. Usually the target directory is the `workers_external` directory of the project. -Each `TKRDependency` adds a folder to the target directory containing the worker code/executable. -(Alternatively on systems where copying/moving large numbers of inodes is slow the TKRDependency can add a symlink.) -Each `TKRDependency` has at its disposal a directory `worker_cache / TKRDependency.type` for any caching it needs. +Each `TKRDependency` adds a folder to the target directory containing the worker + code/executable. (Alternatively on systems where copying/moving large numbers of inodes + is slow the TKRDependency can add a symlink.) Each `TKRDependency` has at its disposal + a directory `worker_cache / TKRDependency.type` for any caching it needs. 
""" import logging -from os import unlink from pathlib import Path from shutil import rmtree + from tierkreis.consts import WORKER_CACHE from tierkreis.pkg.base import TKRDependency @@ -17,30 +18,33 @@ def install_dependencies( - deps: dict[str, TKRDependency], target_dir: Path, worker_cache: Path = WORKER_CACHE -): + deps: dict[str, TKRDependency], + target_dir: Path, + worker_cache: Path = WORKER_CACHE, +) -> None: """Install the dependencies in `deps` into `target_dir`.""" for worker_name, dep in deps.items(): dep.install(worker_name, target_dir, worker_cache) -def remove_dependencies(deps: list[str], target_dir: Path): +def remove_dependencies(deps: list[str], target_dir: Path) -> None: """Remove the worker directories listed in `deps` from the `target_dir`. - Assumes that the worker directory is a symlink or a folder.""" - + Assumes that the worker directory is a symlink or a folder. + """ for dep in deps: worker_dir = target_dir / dep if worker_dir.is_symlink(): - unlink(worker_dir) + Path.unlink(worker_dir) elif worker_dir.is_dir(): rmtree(worker_dir) else: logger.warning( - f"Expected {worker_dir} to be a symlink or directory. Taking no action." + "Expected %s to be a symlink or directory. 
Taking no action.", + worker_dir, ) -def clear_cache(worker_cache: Path = WORKER_CACHE): - """Remove all cached files used to install external worker depenencies.""" +def clear_cache(worker_cache: Path = WORKER_CACHE) -> None: + """Remove all cached files used to install external worker dependencies.""" rmtree(worker_cache) diff --git a/tierkreis/tierkreis/pkg/base.py b/tierkreis/tierkreis/pkg/base.py index 2eebf2755..7ac913676 100644 --- a/tierkreis/tierkreis/pkg/base.py +++ b/tierkreis/tierkreis/pkg/base.py @@ -1,18 +1,39 @@ +"""Base class for dependencies of Tierkreis workers.""" + from abc import ABC, abstractmethod from pathlib import Path from pydantic import BaseModel + from tierkreis.consts import WORKER_CACHE class TKRDependency(ABC, BaseModel): + """A worker dependency for a Tierkreis project. + + :fields: + type (str): The type of the dependency. Used for the cache. + """ + type: str @abstractmethod def install( - self, worker_name: str, target_dir: Path, worker_cache: Path = WORKER_CACHE - ) -> None: ... + self, + worker_name: str, + target_dir: Path, + worker_cache: Path = WORKER_CACHE, + ) -> None: + """Install a worker called `worker_name` into the `target_dir`. + + The `TKRDependency` has at its disposal a directory + `worker_cache / TKRDependency.type` for any caching it needs. - """Install a worker called `worker_name` into the `target_dir`. - - The `TKRDependency` has at its disposal a directory `worker_cache / TKRDependency.type` for any caching it needs.""" + :param worker_name: The name of the worker. + :type worker_name: str + :param target_dir: The target directory to install the worker into. + :type target_dir: Path + :param worker_cache: The worker cache directory, defaults to WORKER_CACHE + :type worker_cache: Path, optional + """ + ... 
diff --git a/tierkreis/tierkreis/pkg/github.py b/tierkreis/tierkreis/pkg/github.py index 7932a9076..b5fe40867 100644 --- a/tierkreis/tierkreis/pkg/github.py +++ b/tierkreis/tierkreis/pkg/github.py @@ -1,40 +1,71 @@ -from os import symlink -from pathlib import Path +"""Worker dependency from github.""" + +import shutil import subprocess +from pathlib import Path +from typing import override + from tierkreis.consts import WORKER_CACHE +from tierkreis.exceptions import TierkreisError from tierkreis.pkg.base import TKRDependency class GitHubDependency(TKRDependency): + """A TKRDependency that installs a worker from a github repository. + + :fields: + account (str): The github account to clone from. + repo (str): The github repository to clone from. + subdirectory (str): The subdirectory in the repository to use as the worker. + branch (str): The branch to clone from. Defaults to "main". + """ + type: str = "github" account: str repo: str subdirectory: str - branch: str = "main" - def cache_subdir(self, worker_cache: Path): + def cache_subdir(self, worker_cache: Path) -> Path: + """Get the worker cache subdirectory. + + :param worker_cache: The base directory for worker caches. + :type worker_cache: Path + :return: The subdirectory path within the worker cache. + :rtype: Path + """ return worker_cache / "github" / self.account / self.repo / self.branch + @override def install( - self, worker_name: str, target_dir: Path, worker_cache: Path = WORKER_CACHE - ): + self, + worker_name: str, + target_dir: Path, + worker_cache: Path = WORKER_CACHE, + ) -> None: cache_dir = self.cache_subdir(worker_cache) cache_dir.mkdir(exist_ok=True, parents=True) - + git_path = shutil.which("git") + if git_path is None: + msg = "git is required to use github dependencies." 
+ raise TierkreisError(msg) git_dir = cache_dir / ".git" git_url = f"https://github.com/{self.account}/{self.repo}" if not git_dir.exists(): - subprocess.run(["git", "clone", git_url, "."], cwd=cache_dir) + subprocess.run([git_path, "clone", git_url, "."], cwd=cache_dir, check=True) - subprocess.run(["git", "restore", "."], cwd=cache_dir) - subprocess.run(["git", "clean", "-f"], cwd=cache_dir) - subprocess.run(["git", "checkout", self.branch], cwd=cache_dir) - subprocess.run(["git", "pull", "--rebase"], cwd=cache_dir) + subprocess.run([git_path, "restore", "."], cwd=cache_dir, check=True) + subprocess.run([git_path, "clean", "-f"], cwd=cache_dir, check=True) + subprocess.run([git_path, "checkout", self.branch], cwd=cache_dir, check=True) + subprocess.run([git_path, "pull", "--rebase"], cwd=cache_dir, check=True) worker_dir = target_dir / worker_name if worker_dir.exists(): worker_dir.unlink() - symlink(cache_dir / self.subdirectory, worker_dir, target_is_directory=True) + Path.symlink_to( + cache_dir / self.subdirectory, + worker_dir, + target_is_directory=True, + ) diff --git a/tierkreis/tierkreis/storage.py b/tierkreis/tierkreis/storage.py index 2c7df3867..68c367ed3 100644 --- a/tierkreis/tierkreis/storage.py +++ b/tierkreis/tierkreis/storage.py @@ -1,50 +1,65 @@ +"""Implementation to access node storage data.""" + from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import GraphData from tierkreis.controller.data.location import Loc -from tierkreis.controller.data.types import PType, ptype_from_bytes -from tierkreis.controller.storage.exceptions import EntryNotFound -from tierkreis.controller.storage.protocol import ControllerStorage +from tierkreis.controller.data.models import TModel +from tierkreis.controller.data.types import PType, is_optional, ptype_from_bytes +from tierkreis.controller.storage.exceptions import EntryNotFoundError from tierkreis.controller.storage.filestorage import ( ControllerFileStorage as FileStorage, ) from 
tierkreis.controller.storage.in_memory import ( ControllerInMemoryStorage as InMemoryStorage, ) +from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.exceptions import TierkreisError -from tierkreis.controller.data.models import TModel -from tierkreis.controller.data.types import is_optional __all__ = ["FileStorage", "InMemoryStorage"] def _read_output( - storage: ControllerStorage, port_name: str, annotation: type | None + storage: ControllerStorage, + port_name: str, + annotation: type | None, ) -> PType: - """Tries to get the output `port_name` from the root graph. - If `annotation` indicates that the value is optional then do not raise on EntryNotFound. - """ + """Try to get the output `port_name` from the root graph. + If `annotation` indicates that the value is optional + then do not raise on EntryNotFound. + """ try: return ptype_from_bytes(storage.read_output(Loc(), port_name)) - except EntryNotFound as exc: + except EntryNotFoundError as exc: if annotation and is_optional(annotation): return None - raise TierkreisError(f"Output {port_name} not found.") from exc + msg = f"Output {port_name} not found." + raise TierkreisError(msg) from exc def read_outputs[A: TModel, B: TModel]( - g: GraphData | GraphBuilder[A, B], storage: ControllerStorage + graph: GraphData | GraphBuilder[A, B], + storage: ControllerStorage, ) -> dict[str, PType] | PType: - """Read the outputs from the `storage`. + """Read the outputs of a workflow graph. - The bytes are parsed into Python types if possible.""" + The bytes are parsed into Python types if possible. + :param graph: The graph to read. + :type graph: GraphData | GraphBuilder + :param storage: The storage of the workflow run. + :type storage: ControllerStorage + :return: The output values. If the graph has a single output port named "value" it + is returned directly, otherwise a dictionary mapping output port names to their + values is returned. 
+ :rtype: dict[str, PType] | PType + """ output_annotation = None - if isinstance(g, GraphBuilder): - output_annotation = g.outputs_type - g = g.get_data() + if isinstance(graph, GraphBuilder): + output_annotation = graph.outputs_type + graph = graph.get_data() - out_ports = list(g.nodes[g.output_idx()].inputs.keys()) + out_ports = list(graph.nodes[graph.output_idx()].inputs.keys()) if len(out_ports) == 1 and "value" in out_ports: return _read_output(storage, "value", output_annotation) @@ -62,17 +77,36 @@ def read_outputs[A: TModel, B: TModel]( def read_loop_trace( - g: GraphData | GraphBuilder, + graph: GraphData | GraphBuilder, storage: ControllerStorage, node_name: str, output_name: str | None = None, ) -> list[PType | dict[str, list[PType]]]: - """Reads the trace of a loop from storage.""" - if isinstance(g, GraphBuilder): - g = g.get_data() + """Read the trace of a named loop. + + This is useful to track intermediate values in a loop. + + :param graph: The graph to read. + :type graph: GraphData | GraphBuilder + :param storage: The storage of the workflow run. + :type storage: ControllerStorage + :param node_name: The name of the loop node. + :type node_name: str + :param output_name: The name of the output port to trace, defaults to None + :type output_name: str | None, optional + :raises TierkreisError: If the loop name is not found in debug data. + :raises TierkreisError: If the output name is not found in loop node output. + :return: A list of traced values. If output_name is None, each entry is a dict + mapping output port names to their values at each iteration, otherwise a list + of values for the specified output port is returned. + :rtype: list[PType | dict[str, list[PType]]] + """ + if isinstance(graph, GraphBuilder): + graph = graph.get_data() loc = storage.loc_from_node_name(node_name) if loc is None: - raise TierkreisError(f"Loop name {node_name} not found in debug data.") + msg = f"Loop name {node_name} not found in debug data." 
+ raise TierkreisError(msg) output_names = storage.read_output_ports(loc) if output_name is None: traces = { @@ -80,9 +114,13 @@ def read_loop_trace( for name in output_names if name != "should_continue" } - return [dict(zip(traces.keys(), vals)) for vals in zip(*traces.values())] + return [ + dict(zip(traces.keys(), vals, strict=False)) + for vals in zip(*traces.values(), strict=False) + ] if output_name not in output_names: - raise TierkreisError(f"Output name {output_name} not found in loop node output") + msg = f"Output name {output_name} not found in loop node output" + raise TierkreisError(msg) results = storage.read_loop_trace(loc, output_name) return [ptype_from_bytes(r) for r in results] diff --git a/tierkreis/tierkreis/worker/__init__.py b/tierkreis/tierkreis/worker/__init__.py index e69de29bb..3ce15025b 100644 --- a/tierkreis/tierkreis/worker/__init__.py +++ b/tierkreis/tierkreis/worker/__init__.py @@ -0,0 +1 @@ +"""Tierkreis worker package for user defined tasks.""" diff --git a/tierkreis/tierkreis/worker/storage/__init__.py b/tierkreis/tierkreis/worker/storage/__init__.py index e69de29bb..d0bdf3a16 100644 --- a/tierkreis/tierkreis/worker/storage/__init__.py +++ b/tierkreis/tierkreis/worker/storage/__init__.py @@ -0,0 +1 @@ +"""Worker storage implementations.""" diff --git a/tierkreis/tierkreis/worker/storage/filestorage.py b/tierkreis/tierkreis/worker/storage/filestorage.py index d466b1991..a88ed0906 100644 --- a/tierkreis/tierkreis/worker/storage/filestorage.py +++ b/tierkreis/tierkreis/worker/storage/filestorage.py @@ -1,14 +1,24 @@ +"""Filestorage implementation analog to ControllerFileStorage.""" + +# ruff: noqa: D102 (class methods inherited from WorkerStorage) import json -from glob import glob import os +from glob import glob from pathlib import Path from tierkreis.consts import TKR_DIR_KEY from tierkreis.controller.data.location import WorkerCallArgs -from tierkreis.controller.storage.exceptions import EntryNotFound +from 
tierkreis.controller.storage.exceptions import EntryNotFoundError class WorkerFileStorage: + """File storage implementation for workers. + + :fields: + tierkreis_dir: The directory to use for storing tierkreis data, + defaults to ~/.tierkreis/checkpoints. + """ + def __init__(self, tierkreis_dir: Path | None = None) -> None: if tierkreis_dir is not None: self.tierkreis_dir = tierkreis_dir @@ -23,20 +33,20 @@ def resolve(self, path: Path | str) -> Path: def read_call_args(self, path: Path) -> WorkerCallArgs: try: - with open(self.resolve(path), "r") as fh: + with Path.open(self.resolve(path)) as fh: return WorkerCallArgs(**json.loads(fh.read())) except FileNotFoundError as exc: - raise EntryNotFound(path) from exc + raise EntryNotFoundError(path) from exc def read_input(self, path: Path) -> bytes: try: - with open(self.resolve(path), "rb") as fh: + with Path.open(self.resolve(path), "rb") as fh: return fh.read() except FileNotFoundError as exc: - raise EntryNotFound(path) from exc + raise EntryNotFoundError(path) from exc def write_output(self, path: Path, value: bytes) -> None: - with open(self.resolve(path), "wb+") as fh: + with Path.open(self.resolve(path), "wb+") as fh: fh.write(value) def glob(self, path_string: str) -> list[str]: @@ -46,5 +56,5 @@ def mark_done(self, path: Path) -> None: self.resolve(path).touch() def write_error(self, path: Path, error_logs: str) -> None: - with open(self.resolve(path), "w+") as f: + with Path.open(self.resolve(path), "w+") as f: f.write(error_logs) diff --git a/tierkreis/tierkreis/worker/storage/in_memory.py b/tierkreis/tierkreis/worker/storage/in_memory.py index 3bcc7e404..09bb3f9cc 100644 --- a/tierkreis/tierkreis/worker/storage/in_memory.py +++ b/tierkreis/tierkreis/worker/storage/in_memory.py @@ -1,21 +1,30 @@ +"""In-memory storage implementation analog to ControllerInMemoryStorage.""" + +# ruff: noqa: D102 (class methods inherited from WorkerStorage) import fnmatch import json import logging from pathlib import Path 
from tierkreis.controller.data.location import WorkerCallArgs -from tierkreis.controller.storage.exceptions import EntryNotFound +from tierkreis.controller.storage.exceptions import EntryNotFoundError from tierkreis.controller.storage.in_memory import ( ControllerInMemoryStorage, InMemoryFileData, ) from tierkreis.exceptions import TierkreisError - logger = logging.getLogger(__name__) class InMemoryWorkerStorage: + """In-memory storage implementation for workers. + + Delegates calls to the ControllerInMemoryStorage used for the workflow. + :fields: + controller_storage: The controller storage. + """ + def __init__(self, controller_storage: ControllerInMemoryStorage) -> None: self.controller_storage = controller_storage @@ -27,25 +36,25 @@ def read_call_args(self, path: Path) -> WorkerCallArgs: bs = self.controller_storage.files[path].value return WorkerCallArgs(**json.loads(bs)) except KeyError as exc: - raise EntryNotFound(path) from exc + raise EntryNotFoundError(path) from exc def read_input(self, path: Path) -> bytes: try: return self.controller_storage.files[path].value except KeyError as exc: - raise EntryNotFound(path) from exc + raise EntryNotFoundError(path) from exc def write_output(self, path: Path, value: bytes) -> None: self.controller_storage.files[path] = InMemoryFileData(value) def glob(self, path_string: str) -> list[str]: - files = [str(x) for x in self.controller_storage.files.keys()] - matching = fnmatch.filter(files, path_string) - return matching + files = [str(x) for x in self.controller_storage.files] + return fnmatch.filter(files, path_string) def mark_done(self, path: Path) -> None: self.controller_storage.touch(path) - def write_error(self, path: Path, error_logs: str) -> None: + def write_error(self, _: Path, error_logs: str) -> None: logger.error(error_logs) - raise TierkreisError("Error occured when running graph in-memory.") + msg = "Error occurred when running graph in-memory." 
+ raise TierkreisError(msg) diff --git a/tierkreis/tierkreis/worker/storage/protocol.py b/tierkreis/tierkreis/worker/storage/protocol.py index 040cede0a..c3bcca5b3 100644 --- a/tierkreis/tierkreis/worker/storage/protocol.py +++ b/tierkreis/tierkreis/worker/storage/protocol.py @@ -1,3 +1,5 @@ +"""Storage protocol for workers.""" + from pathlib import Path from typing import Protocol @@ -5,10 +7,91 @@ class WorkerStorage(Protocol): - def resolve(self, path: Path | str) -> Path: ... - def read_call_args(self, path: Path) -> WorkerCallArgs: ... - def read_input(self, path: Path) -> bytes: ... - def write_output(self, path: Path, value: bytes) -> None: ... - def glob(self, path_string: str) -> list[str]: ... - def mark_done(self, path: Path) -> None: ... - def write_error(self, path: Path, error_logs: str) -> None: ... + """Storage protocol for workers. + + :abstract: + """ + + def resolve(self, path: Path | str) -> Path: + """Resolve a path or str to a path in the storage. + + Transforms paths such that the storage can correctly read/write to them. + E.g. for a file storage, relative paths are resolved against a base directory. + + :param path: The path to resolve. + :type path: Path | str + :return: The resolved path according to the storage. + :rtype: Path + """ + ... + + def read_call_args(self, path: Path) -> WorkerCallArgs: + """Read the call args of a worker. + + The function name is part of the call args which is then used to + determine which function to call. + + :param path: The path to read from. + :type path: Path + :return: The call args of the worker function. + :rtype: WorkerCallArgs + """ + ... + + def read_input(self, path: Path) -> bytes: + """Read the input to a worker task. + + Input locations are defined in the call args. + + :param path: The path to read from. + :type path: Path + :return: The bytes read from the input location. + :rtype: bytes + """ + ... 
+ + def write_output(self, path: Path, value: bytes) -> None: + """Write the outputs of a worker task. + + Output locations are defined in the call args. + + :param path: The path to write to. + :type path: Path + :param value: The bytes to write. + :type value: bytes + """ + ... + + def glob(self, path_string: str) -> list[str]: + """Get a list of paths matching the path glob. + + Used in map nodes to find all input values for the individual tasks. + + :param path_string: The glob string to match. + :type path_string: str + :return: A list of matching path strings. + :rtype: list[str] + """ + ... + + def mark_done(self, path: Path) -> None: + """Mark the task node as done. + + Done paths are defined in the call args. + + :param path: The path to mark as done. + :type path: Path + """ + ... + + def write_error(self, path: Path, error_logs: str) -> None: + """Write an error to the logs. + + Logs are stored in a location defined in the call args. + + :param path: The path to write the error logs to. + :type path: Path + :param error_logs: The message to write. + :type error_logs: str + """ + ... 
diff --git a/tierkreis/tierkreis/worker/worker.py b/tierkreis/tierkreis/worker/worker.py index 9d533df9a..0f347f3a4 100644 --- a/tierkreis/tierkreis/worker/worker.py +++ b/tierkreis/tierkreis/worker/worker.py @@ -1,7 +1,10 @@ -from inspect import Signature, signature +"""Tierkreis worker implementation.""" + import logging +from collections.abc import Callable +from inspect import Signature, signature from pathlib import Path -from typing import Callable, TypeVar +from typing import NoReturn, TypeVar from tierkreis.controller.data.core import PortID from tierkreis.controller.data.location import WorkerCallArgs @@ -16,7 +19,7 @@ has_default, ptype_from_bytes, ) -from tierkreis.controller.storage.exceptions import EntryNotFound +from tierkreis.controller.storage.exceptions import EntryNotFoundError from tierkreis.exceptions import TierkreisError from tierkreis.logger_setup import add_handler_from_environment from tierkreis.namespace import Namespace, WorkerFunction @@ -29,7 +32,7 @@ class TierkreisWorkerError(TierkreisError): - pass + """Exception raised when a worker encounters an error.""" F = TypeVar("F", bound=Callable[..., PModel]) @@ -39,7 +42,8 @@ class Worker: """A worker bundles a set of functionality under a common namespace. The main usage of a worker is to convert python functions into atomic tasks, - which can then be executed by the :py:class:`tierkreis.controller.executor.uv_executor.UvExecutor` + which can then be executed by the + :py:class:`tierkreis.controller.executor.uv_executor.UvExecutor` or similar Executors. From the worker type stubs can be generated to statically check the function calls. @@ -53,10 +57,14 @@ class Worker: def exp(x: float, a: float) -> float: return value = a * np.exp(x) - :param name: The name of the worker. - :type name: str - :param storage: Storage layer for the worker to interact with the ControllerStorage. - :type storage: WorkerStorage + :fields: + name (str) The name of the worker. 
+ storage (WorkerStorage) Storage layer for the + worker to interact with the ControllerStorage. + namespace (Namespace) The namespace of the worker. + types (dict[MethodName, Signature]) Mapping function names to their signatures. + functions (dict[str, Callable[[WorkerCallArgs], None]]) + Mapping function names to their implementations. """ functions: dict[str, Callable[[WorkerCallArgs], None]] @@ -74,26 +82,33 @@ def __init__(self, name: str, storage: WorkerStorage | None = None) -> None: self.storage = storage def _load_args( - self, f: WorkerFunction, inputs: dict[str, Path] + self, + f: WorkerFunction, + inputs: dict[str, Path], ) -> dict[str, PType]: bs: dict[str, bytes] = {} for k, p in inputs.items(): try: bs[k] = self.storage.read_input(p) - except EntryNotFound: + except EntryNotFoundError as e: if not has_default(self.types[f.__name__].parameters[k]): - raise TierkreisError(f"Input {k} not found at {p}.") + msg = f"Input {k} not found at {p}." + raise TierkreisError(msg) from e args = {} for k, b in bs.items(): args[k] = ptype_from_bytes( - b, self.types[f.__name__].parameters[k].annotation + b, + self.types[f.__name__].parameters[k].annotation, ) return args def _save_results( - self, f: WorkerFunction, outputs: dict[PortID, Path], results: PModel - ): + self, + f: WorkerFunction, + outputs: dict[PortID, Path], + results: PModel, + ) -> None: d = dict_from_pmodel(results) ret = annotations_from_pmodel(signature(f).return_annotation) for result_name, path in outputs.items(): @@ -101,15 +116,24 @@ def _save_results( self.storage.write_output(path, bs) def add_types(self, func: WorkerFunction) -> None: + """Add the types of a function to the worker. + + :param func: The function to add types for. 
+ :type func: WorkerFunction + """ self.types[func.__name__] = signature(func) def primitive_task( self, ) -> Callable[[PrimitiveTask], None]: - """Registers a python function as a primitive task with the worker.""" + """Register a python function as a primitive task with the worker. + + :return: The wrapped task. + :rtype: Callable[[PrimitiveTask], None] + """ def function_decorator(func: PrimitiveTask) -> None: - def wrapper(args: WorkerCallArgs): + def wrapper(args: WorkerCallArgs) -> None: func(args, self.storage) self.functions[func.__name__] = wrapper @@ -117,13 +141,17 @@ def wrapper(args: WorkerCallArgs): return function_decorator def task(self) -> Callable[[F], F]: - """Registers a python function as a task with the worker.""" + """Register a python function as a task with the worker. + + :return: The wrapped function. + :rtype: Callable[[Callable[..., PModel]], Callable[..., PModel]] + """ def function_decorator(func: F) -> F: self.namespace.add_function(func) self.add_types(func) - def wrapper(node_definition: WorkerCallArgs): + def wrapper(node_definition: WorkerCallArgs) -> None: kwargs = self._load_args(func, node_definition.inputs) results = func(**kwargs) self._save_results(func, node_definition.outputs, results) @@ -143,27 +171,42 @@ def run(self, worker_definition_path: Path) -> None: node_definition = self.storage.read_call_args(worker_definition_path) logger.debug(node_definition.model_dump()) + def _check_function(msg: str) -> NoReturn: + raise TierkreisError(msg) + try: function = self.functions.get(node_definition.function_name, None) if function is None: - raise TierkreisError( - f"{self.name}: function name {node_definition.function_name} not found" + msg = ( + f"{self.name}: function name " + f"{node_definition.function_name} not found" ) - logger.info(f"running: {node_definition.function_name} in {self.name}") + _check_function(msg) + logger.info("running: %s in %s", node_definition.function_name, self.name) function(node_definition) 
self.storage.mark_done(node_definition.done_path) except Exception as err: - logger.error("encountered error", exc_info=err) + logger.exception("encountered error", exc_info=err) self.storage.write_error(node_definition.error_path, str(err)) + msg = ( + f"Worker {self.name} encountered error when executing " + f"{node_definition.function_name}." + ) raise TierkreisWorkerError( - f"Worker {self.name} encountered error when executing {node_definition.function_name}." + msg, ) from err def app(self, argv: list[str]) -> None: - """Wrapper for UV execution.""" + """Run the worker as uv app. + + Either generate stubs or run the worker. + + :param argv: The cli args. + :type argv: list[str] + """ handler = add_handler_from_environment(logger) if argv[1] == "--stubs-path": self.namespace.write_stubs(Path(argv[2])) diff --git a/tierkreis_visualization/tierkreis_visualization/app.py b/tierkreis_visualization/tierkreis_visualization/app.py index 12892ecfb..dad5bd6eb 100644 --- a/tierkreis_visualization/tierkreis_visualization/app.py +++ b/tierkreis_visualization/tierkreis_visualization/app.py @@ -1,24 +1,25 @@ import signal from sys import argv + from tierkreis.controller.data.graph import GraphData from tierkreis_visualization.app_config import ( App, StorageType, - graph_data_lifespan, dev_lifespan, + graph_data_lifespan, ) from tierkreis_visualization.config import CONFIG +from tierkreis_visualization.routers.frontend import assets +from tierkreis_visualization.routers.frontend import router as frontend_router +from tierkreis_visualization.routers.workflows import router as workflows_router from tierkreis_visualization.storage import ( file_storage_fn, from_graph_data_storage_fn, graph_data_storage_fn, ) -from tierkreis_visualization.routers.frontend import assets -from tierkreis_visualization.routers.workflows import router as workflows_router -from tierkreis_visualization.routers.frontend import router as frontend_router -def transform_to_sigkill(signum, frame): +def 
transform_to_sigkill(signum, frame) -> None: signal.raise_signal(signal.SIGKILL) diff --git a/tierkreis_visualization/tierkreis_visualization/app_config.py b/tierkreis_visualization/tierkreis_visualization/app_config.py index 6d4125093..17e8cce34 100644 --- a/tierkreis_visualization/tierkreis_visualization/app_config.py +++ b/tierkreis_visualization/tierkreis_visualization/app_config.py @@ -1,11 +1,14 @@ +import webbrowser +from collections.abc import Callable from contextlib import asynccontextmanager from enum import Enum -from typing import Callable, cast +from typing import cast from uuid import UUID -import webbrowser -from fastapi import FastAPI + import fastapi import starlette.datastructures +from fastapi import FastAPI + from tierkreis.controller.storage.filestorage import ControllerFileStorage from tierkreis.controller.storage.graphdata import GraphDataStorage from tierkreis.controller.storage.protocol import ControllerStorage @@ -35,7 +38,7 @@ class Request(fastapi.Request): @property def app(self) -> App: - return cast(App, super().app) + return cast("App", super().app) @asynccontextmanager diff --git a/tierkreis_visualization/tierkreis_visualization/cli.py b/tierkreis_visualization/tierkreis_visualization/cli.py index 997d4e0e9..ad0855eac 100644 --- a/tierkreis_visualization/tierkreis_visualization/cli.py +++ b/tierkreis_visualization/tierkreis_visualization/cli.py @@ -1,7 +1,6 @@ import argparse - -from tierkreis_visualization.main import start, dev, graph +from tierkreis_visualization.main import dev, graph, start def parse_args( diff --git a/tierkreis_visualization/tierkreis_visualization/data/eval.py b/tierkreis_visualization/tierkreis_visualization/data/eval.py index 511fc2ca3..bf6a74af7 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/eval.py +++ b/tierkreis_visualization/tierkreis_visualization/data/eval.py @@ -2,20 +2,19 @@ from typing import assert_never from tierkreis.controller.data.core import NodeIndex -from 
tierkreis.controller.data.location import Loc from tierkreis.controller.data.graph import GraphData, IfElse +from tierkreis.controller.data.location import Loc from tierkreis.controller.data.types import ptype_from_bytes from tierkreis.controller.storage.adjacency import in_edges from tierkreis.controller.storage.protocol import ControllerStorage - from tierkreis.exceptions import TierkreisError -from tierkreis_visualization.data.models import PyNode, NodeStatus, PyEdge +from tierkreis_visualization.data.models import NodeStatus, PyEdge, PyNode from tierkreis_visualization.data.outputs import outputs_from_loc from tierkreis_visualization.routers.models import PyGraph def node_status( - storage: ControllerStorage, node_location: Loc, errored_nodes: list[Loc] + storage: ControllerStorage, node_location: Loc, errored_nodes: list[Loc], ) -> NodeStatus: if storage.is_node_finished(node_location): return "Finished" @@ -38,7 +37,7 @@ def add_conditional_edges( i: NodeIndex, node: IfElse, py_edges: list[PyEdge], -): +) -> None: try: pred = json.loads(storage.read_output(loc.N(node.pred[0]), node.pred[1])) except (FileNotFoundError, TierkreisError): @@ -63,7 +62,7 @@ def add_conditional_edges( def get_eval_node( - storage: ControllerStorage, node_location: Loc, errored_nodes: list[Loc] + storage: ControllerStorage, node_location: Loc, errored_nodes: list[Loc], ) -> PyGraph: thunk = storage.read_output(node_location.N(-1), "body") graph = ptype_from_bytes(thunk, GraphData) diff --git a/tierkreis_visualization/tierkreis_visualization/data/function.py b/tierkreis_visualization/tierkreis_visualization/data/function.py index f37d0eb57..9d67e72ae 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/function.py +++ b/tierkreis_visualization/tierkreis_visualization/data/function.py @@ -1,4 +1,5 @@ from pydantic import BaseModel + from tierkreis.controller.data.location import Loc from tierkreis.controller.storage.protocol import ControllerStorage from 
tierkreis.exceptions import TierkreisError @@ -12,7 +13,8 @@ class FunctionDefinition(BaseModel): def get_function_node(storage: ControllerStorage, loc: Loc) -> FunctionDefinition: parent = loc.parent() if parent is None: - raise TierkreisError("Func node must have parent.") + msg = "Func node must have parent." + raise TierkreisError(msg) if not storage.node_has_error(loc): return FunctionDefinition() return FunctionDefinition(has_error=True, error_message=storage.read_errors(loc)) diff --git a/tierkreis_visualization/tierkreis_visualization/data/graph.py b/tierkreis_visualization/tierkreis_visualization/data/graph.py index db0f49820..8a30e553d 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/graph.py +++ b/tierkreis_visualization/tierkreis_visualization/data/graph.py @@ -1,5 +1,7 @@ from typing import assert_never + from fastapi import HTTPException + from tierkreis.controller.data.location import Loc from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis_visualization.data.eval import get_eval_node @@ -40,7 +42,7 @@ def get_node_data(storage: ControllerStorage, loc: Loc) -> PyGraph: case "function" | "const" | "ifelse" | "eifelse" | "input" | "output": raise HTTPException( - 400, detail="Only eval, loop and map nodes return a graph." 
+ 400, detail="Only eval, loop and map nodes return a graph.", ) case _: diff --git a/tierkreis_visualization/tierkreis_visualization/data/loop.py b/tierkreis_visualization/tierkreis_visualization/data/loop.py index a5994e95f..29516b93d 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/loop.py +++ b/tierkreis_visualization/tierkreis_visualization/data/loop.py @@ -1,10 +1,9 @@ from pydantic import BaseModel + from tierkreis.controller.data.location import Loc from tierkreis.controller.storage.protocol import ControllerStorage - - from tierkreis_visualization.data.eval import check_error -from tierkreis_visualization.data.models import PyNode, PyEdge +from tierkreis_visualization.data.models import PyEdge, PyNode from tierkreis_visualization.data.outputs import outputs_from_loc @@ -14,7 +13,7 @@ class LoopNodeData(BaseModel): def get_loop_node( - storage: ControllerStorage, node_location: Loc, errored_nodes: list[Loc] + storage: ControllerStorage, node_location: Loc, errored_nodes: list[Loc], ) -> LoopNodeData: i = 0 while storage.is_node_started(node_location.L(i + 1)): @@ -56,7 +55,7 @@ def get_loop_node( started_time=storage.read_started_time(new_location) or "", finished_time=storage.read_finished_time(new_location) or "", outputs=list(outputs), - ) + ), ) edges = [] for port_name in outputs: @@ -70,6 +69,6 @@ def get_loop_node( value=outputs_from_loc(storage, node_location.L(n), port_name), ) for n in range(i) - ] + ], ) return LoopNodeData(nodes=nodes, edges=edges) diff --git a/tierkreis_visualization/tierkreis_visualization/data/map.py b/tierkreis_visualization/tierkreis_visualization/data/map.py index dacba843b..e3f2867ac 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/map.py +++ b/tierkreis_visualization/tierkreis_visualization/data/map.py @@ -1,8 +1,9 @@ from pydantic import BaseModel + +from tierkreis.controller.data.graph import Map from tierkreis.controller.data.location import Loc from 
tierkreis.controller.storage.adjacency import outputs_iter from tierkreis.controller.storage.protocol import ControllerStorage -from tierkreis.controller.data.graph import Map from tierkreis.exceptions import TierkreisError from tierkreis_visualization.data.eval import check_error from tierkreis_visualization.data.models import PyEdge, PyNode @@ -14,11 +15,12 @@ class MapNodeData(BaseModel): def get_map_node( - storage: ControllerStorage, loc: Loc, map: Map, errored_nodes: list[Loc] + storage: ControllerStorage, loc: Loc, map: Map, errored_nodes: list[Loc], ) -> MapNodeData: parent = loc.parent() if parent is None: - raise TierkreisError("MAP node must have parent.") + msg = "MAP node must have parent." + raise TierkreisError(msg) node_ref = next(n for n, port in map.inputs.values() if port == "*") map_eles = outputs_iter(storage, parent.N(node_ref)) diff --git a/tierkreis_visualization/tierkreis_visualization/data/models.py b/tierkreis_visualization/tierkreis_visualization/data/models.py index 3b1b86aca..bebc6321d 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/models.py +++ b/tierkreis_visualization/tierkreis_visualization/data/models.py @@ -1,10 +1,12 @@ from typing import Literal + from pydantic import BaseModel + from tierkreis.controller.data.location import Loc NodeStatus = Literal["Not started", "Started", "Error", "Finished"] NodeType = Literal[ - "function", "ifelse", "map", "eval", "loop", "eifelse", "const", "output", "input" + "function", "ifelse", "map", "eval", "loop", "eifelse", "const", "output", "input", ] diff --git a/tierkreis_visualization/tierkreis_visualization/data/outputs.py b/tierkreis_visualization/tierkreis_visualization/data/outputs.py index 33d9d6b54..61f0e6804 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/outputs.py +++ b/tierkreis_visualization/tierkreis_visualization/data/outputs.py @@ -1,11 +1,12 @@ import array -from tierkreis.controller.storage.protocol import ControllerStorage + from 
tierkreis.controller.data.location import Loc +from tierkreis.controller.storage.protocol import ControllerStorage from tierkreis.exceptions import TierkreisError def outputs_from_loc( - storage: ControllerStorage, loc: Loc, port_name: str + storage: ControllerStorage, loc: Loc, port_name: str, ) -> str | None: try: raw_bytes = storage.read_output(loc, port_name) diff --git a/tierkreis_visualization/tierkreis_visualization/data/workflows.py b/tierkreis_visualization/tierkreis_visualization/data/workflows.py index 744bd065d..ae0a2dcdf 100644 --- a/tierkreis_visualization/tierkreis_visualization/data/workflows.py +++ b/tierkreis_visualization/tierkreis_visualization/data/workflows.py @@ -27,7 +27,7 @@ def get_workflows(storage_type: StorageType) -> list[WorkflowDisplay]: name="tmp", start_time=datetime.now().isoformat(), errors=[], - ) + ), ] return get_workflows_from_disk() @@ -47,8 +47,8 @@ def get_workflows_from_disk() -> list[WorkflowDisplay]: errors = list(set(errors)) workflows.append( WorkflowDisplay( - id=id, id_int=int(id), name=name, start_time=start, errors=errors - ) + id=id, id_int=int(id), name=name, start_time=start, errors=errors, + ), ) except (TypeError, ValueError): continue diff --git a/tierkreis_visualization/tierkreis_visualization/main.py b/tierkreis_visualization/tierkreis_visualization/main.py index bbeb3f857..cca299b99 100644 --- a/tierkreis_visualization/tierkreis_visualization/main.py +++ b/tierkreis_visualization/tierkreis_visualization/main.py @@ -5,7 +5,7 @@ def start() -> None: uvicorn.run( - "tierkreis_visualization.app:get_filestorage_app", timeout_graceful_shutdown=10 + "tierkreis_visualization.app:get_filestorage_app", timeout_graceful_shutdown=10, ) @@ -27,8 +27,8 @@ def graph(argv_index: int = 1) -> None: def openapi() -> None: - from tierkreis_visualization.openapi import generate_openapi from tierkreis_visualization.app import get_filestorage_app + from tierkreis_visualization.openapi import generate_openapi 
generate_openapi(get_filestorage_app()) diff --git a/tierkreis_visualization/tierkreis_visualization/openapi.py b/tierkreis_visualization/tierkreis_visualization/openapi.py index 0cda73317..2ff0848e6 100644 --- a/tierkreis_visualization/tierkreis_visualization/openapi.py +++ b/tierkreis_visualization/tierkreis_visualization/openapi.py @@ -1,12 +1,14 @@ import json from pathlib import Path + from fastapi import FastAPI -def generate_openapi(app: FastAPI): +def generate_openapi(app: FastAPI) -> None: """Write the openapi spec of `app` to a file. - This will run automatically in development mode.""" + This will run automatically in development mode. + """ from fastapi.openapi.utils import get_openapi spec = get_openapi( diff --git a/tierkreis_visualization/tierkreis_visualization/routers/frontend.py b/tierkreis_visualization/tierkreis_visualization/routers/frontend.py index 7261d8359..e82f86243 100644 --- a/tierkreis_visualization/tierkreis_visualization/routers/frontend.py +++ b/tierkreis_visualization/tierkreis_visualization/routers/frontend.py @@ -1,4 +1,5 @@ from pathlib import Path + from fastapi import APIRouter from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles @@ -9,7 +10,7 @@ @router.get("/{path:path}") -def read_root(): +def read_root(path): return FileResponse( - PACKAGE_DIR / "static" / "dist" / "index.html", media_type="text/html" + PACKAGE_DIR / "static" / "dist" / "index.html", media_type="text/html", ) diff --git a/tierkreis_visualization/tierkreis_visualization/routers/models.py b/tierkreis_visualization/tierkreis_visualization/routers/models.py index 977b4ce7e..8787d603d 100644 --- a/tierkreis_visualization/tierkreis_visualization/routers/models.py +++ b/tierkreis_visualization/tierkreis_visualization/routers/models.py @@ -1,6 +1,7 @@ from pydantic import BaseModel + from tierkreis.controller.data.location import Loc -from tierkreis_visualization.data.models import PyNode, PyEdge +from 
tierkreis_visualization.data.models import PyEdge, PyNode class PyGraph(BaseModel): diff --git a/tierkreis_visualization/tierkreis_visualization/routers/workflows.py b/tierkreis_visualization/tierkreis_visualization/routers/workflows.py index ad3c55184..d8e2a0f33 100644 --- a/tierkreis_visualization/tierkreis_visualization/routers/workflows.py +++ b/tierkreis_visualization/tierkreis_visualization/routers/workflows.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, HTTPException, Query, Response, status from starlette.responses import JSONResponse, PlainTextResponse from starlette.websockets import WebSocket, WebSocketDisconnect +from watchfiles import awatch # type: ignore from tierkreis.controller.data.location import Loc from tierkreis.controller.storage.graphdata import GraphDataStorage @@ -15,9 +16,6 @@ from tierkreis_visualization.app_config import Request from tierkreis_visualization.data.graph import get_node_data, parse_node_location from tierkreis_visualization.data.outputs import outputs_from_loc - -from watchfiles import awatch # type: ignore - from tierkreis_visualization.data.workflows import WorkflowDisplay, get_workflows from tierkreis_visualization.routers.models import GraphsResponse, PyGraph @@ -27,7 +25,7 @@ @router.websocket("/{workflow_id}/nodes/{node_location_str}") async def websocket_endpoint( - websocket: WebSocket, workflow_id: UUID, node_location_str: str + websocket: WebSocket, workflow_id: UUID, node_location_str: str, ) -> None: if workflow_id.int == 0: return @@ -72,9 +70,9 @@ def list_workflows(request: Request) -> list[WorkflowDisplay]: ) -@router.get("/{workflow_id}/graphs", response_model=GraphsResponse) +@router.get("/{workflow_id}/graphs") def list_nodes( - request: Request, workflow_id: UUID, locs: Annotated[list[Loc], Query()] + request: Request, workflow_id: UUID, locs: Annotated[list[Loc], Query()], ) -> GraphsResponse: storage = request.app.state.get_storage_fn(workflow_id) return GraphsResponse(graphs={loc: 
get_node_data(storage, loc) for loc in locs}) diff --git a/tierkreis_visualization/tierkreis_visualization/storage.py b/tierkreis_visualization/tierkreis_visualization/storage.py index 647f871b5..f588b9915 100644 --- a/tierkreis_visualization/tierkreis_visualization/storage.py +++ b/tierkreis_visualization/tierkreis_visualization/storage.py @@ -1,8 +1,8 @@ +import sys +from collections.abc import Callable from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from sys import argv -import sys -from typing import Callable from uuid import UUID from tierkreis.controller.data.graph import GraphData @@ -14,7 +14,7 @@ def file_storage_fn(tkr_dir: Path) -> Callable[[UUID], ControllerStorage]: def inner(workflow_id: UUID): return ControllerFileStorage( - workflow_id=workflow_id, tierkreis_directory=tkr_dir + workflow_id=workflow_id, tierkreis_directory=tkr_dir, ) return inner @@ -28,14 +28,16 @@ def graph_data_storage_fn( spec = spec_from_file_location("tkr_tmp.graph", mod_path) if spec is None: - raise ValueError(f"File is not a Python module: {mod_path}") + msg = f"File is not a Python module: {mod_path}" + raise ValueError(msg) module = module_from_spec(spec) sys.modules["tkr_tmp.graph"] = module loader = spec.loader if loader is None: - raise ValueError("Could not get loader from module.") + msg = "Could not get loader from module." 
+ raise ValueError(msg) loader.exec_module(module) graph = getattr(module, var).data diff --git a/tierkreis_visualization/tierkreis_visualization/visualize_graph.py b/tierkreis_visualization/tierkreis_visualization/visualize_graph.py index c6c87caf4..145839ea5 100644 --- a/tierkreis_visualization/tierkreis_visualization/visualize_graph.py +++ b/tierkreis_visualization/tierkreis_visualization/visualize_graph.py @@ -1,9 +1,9 @@ import asyncio + import uvicorn from tierkreis.builder import GraphBuilder from tierkreis.controller.data.graph import GraphData - from tierkreis_visualization.app import app_from_graph_data diff --git a/tierkreis_workers/aer_worker/main.py b/tierkreis_workers/aer_worker/main.py index 73120c43b..0a2f9d881 100644 --- a/tierkreis_workers/aer_worker/main.py +++ b/tierkreis_workers/aer_worker/main.py @@ -2,13 +2,14 @@ from sys import argv from typing import Any -from tierkreis import Worker from pytket._tket.circuit import Circuit from pytket.backends.backendresult import BackendResult -from pytket.extensions.qiskit.qiskit_convert import tk_to_qiskit from pytket.extensions.qiskit.backends.aer import AerBackend +from pytket.extensions.qiskit.qiskit_convert import tk_to_qiskit from qiskit import qasm3 +from tierkreis import Worker + worker = Worker("aer_worker") logger = logging.getLogger(__name__) @@ -109,7 +110,7 @@ def submit_single(circuit: Circuit, n_shots: int) -> BackendResult: return AerBackend().run_circuit(circuit, n_shots=n_shots) -def main(): +def main() -> None: worker.app(argv) diff --git a/tierkreis_workers/ibmq_worker/default_pass.py b/tierkreis_workers/ibmq_worker/default_pass.py index ef46f6658..8b807f50f 100644 --- a/tierkreis_workers/ibmq_worker/default_pass.py +++ b/tierkreis_workers/ibmq_worker/default_pass.py @@ -7,10 +7,10 @@ BasePass, CliffordSimp, DecomposeBoxes, + FullMappingPass, FullPeepholeOptimise, GreedyPauliSimp, KAKDecomposition, - FullMappingPass, RemoveBarriers, RemoveRedundancies, SequencePass, @@ -18,7 +18,6 @@ 
) from pytket.placement import GraphPlacement - IBMQ_GATE_SET: set[OpType] = { OpType.Rx, OpType.Rz, @@ -100,11 +99,11 @@ def default_compilation_pass( OpType.XXPhase, OpType.YYPhase, OpType.PhasedX, - } + }, ), ) passlist.append( - GreedyPauliSimp(thread_timeout=300, only_reduce=True, trials=10) + GreedyPauliSimp(thread_timeout=300, only_reduce=True, trials=10), ) assert arch is not None if not isinstance(arch, FullyConnected): @@ -115,7 +114,7 @@ def default_compilation_pass( arch, GraphPlacement(arch), [LexiLabellingMethod(), LexiRouteRoutingMethod(10)], - ) + ), ) if optimization_level == 1: passlist.append(SynthesiseTket()) @@ -125,7 +124,7 @@ def default_compilation_pass( KAKDecomposition(allow_swaps=False), CliffordSimp(False), SynthesiseTket(), - ] + ], ) if optimization_level == 3: # noqa: PLR2004 passlist.append(SynthesiseTket()) @@ -134,6 +133,6 @@ def default_compilation_pass( AutoRebase(primitive_gates), AutoSquash(primitive_1q_gates), RemoveRedundancies(), - ] + ], ) return SequencePass(passlist) diff --git a/tierkreis_workers/ibmq_worker/main.py b/tierkreis_workers/ibmq_worker/main.py index 7dadd365f..ca1f2ca9f 100644 --- a/tierkreis_workers/ibmq_worker/main.py +++ b/tierkreis_workers/ibmq_worker/main.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from sys import argv -from typing import Sequence +from default_pass import IBMQ_GATE_SET, default_compilation_pass from pytket._tket.circuit import Circuit from pytket.architecture import Architecture from pytket.backends.backendinfo import BackendInfo @@ -8,12 +9,9 @@ from pytket.extensions.qiskit.backends.ibm import IBMQBackend from pytket.passes import BasePass - from tierkreis import Worker from tierkreis.exceptions import TierkreisError -from default_pass import default_compilation_pass, IBMQ_GATE_SET - worker = Worker("ibmq_worker") @@ -35,15 +33,16 @@ def get_backend_info(device_name: str) -> BackendInfo: None, ) if info is None: + msg = f"Device {device_name} is not in the list of 
available IBMQ devices" raise TierkreisError( - f"Device {device_name} is not in the list of available IBMQ devices" + msg, ) return info @worker.task() def backend_pass_from_info( - backend_info: BackendInfo, optimisation_level: int = 2 + backend_info: BackendInfo, optimisation_level: int = 2, ) -> BasePass: """Returns a compilation pass according to the backend info. @@ -55,13 +54,13 @@ def backend_pass_from_info( :rtype: BasePass """ return IBMQBackend.pass_from_info( - backend_info, optimisation_level=optimisation_level + backend_info, optimisation_level=optimisation_level, ) @worker.task() def backend_default_compilation_pass( - device_name: str, optimisation_level: int = 2 + device_name: str, optimisation_level: int = 2, ) -> BasePass: """Returns the default compilation pass for a given device name. @@ -119,7 +118,7 @@ def compile( @worker.task() def compile_circuit_ibmq( - circuit: Circuit, device_name: str, optimisation_level: int = 2 + circuit: Circuit, device_name: str, optimisation_level: int = 2, ) -> Circuit: """Applies a predefined optimization pass for IBMQ devices. @@ -137,7 +136,7 @@ def compile_circuit_ibmq( @worker.task() def compile_circuits_ibmq( - circuits: list[Circuit], device_name: str, optimisation_level: int = 2 + circuits: list[Circuit], device_name: str, optimisation_level: int = 2, ) -> list[Circuit]: """Applies a predefined optimization pass for IBMQ devices. 
@@ -167,7 +166,7 @@ def run_circuit(circuit: Circuit, n_shots: int, device_name: str) -> BackendResu return backend.run_circuit(circuit, n_shots) -def main(): +def main() -> None: worker.app(argv) diff --git a/tierkreis_workers/nexus_worker/main.py b/tierkreis_workers/nexus_worker/main.py index 37a1038f9..d9bfb7118 100644 --- a/tierkreis_workers/nexus_worker/main.py +++ b/tierkreis_workers/nexus_worker/main.py @@ -3,7 +3,6 @@ from datetime import datetime from sys import argv from time import sleep -from qnexus.models.references import ExecutionResultRef import qnexus as qnx from pytket._tket.circuit import Circuit @@ -12,10 +11,10 @@ from qnexus import BackendConfig from qnexus.exceptions import ResourceFetchFailed from qnexus.models import QuantinuumConfig -from qnexus.models.references import ExecuteJobRef, ExecutionProgram -from tierkreis.exceptions import TierkreisError +from qnexus.models.references import ExecuteJobRef, ExecutionProgram, ExecutionResultRef from tierkreis import Worker +from tierkreis.exceptions import TierkreisError logger = logging.getLogger(__name__) worker = Worker("nexus_worker") @@ -32,9 +31,8 @@ def upload_circuit(project_name: str, circ: Circuit) -> ExecutionProgram: :return: A reference to the uploaded circuit. :rtype: ExecutionProgram """ - my_project_ref = qnx.projects.get_or_create(name=project_name) - circuit_name = circ.name if circ.name else f"circuit_{datetime.now()}" + circuit_name = circ.name or f"circuit_{datetime.now()}" qnx.context.set_active_project(my_project_ref) return qnx.circuits.upload(name=circuit_name, circuit=circ, project=my_project_ref) @@ -62,7 +60,6 @@ def start_execute_job( :return: A reference to the started execution job. 
:rtype: ExecuteJobRef """ - my_project_ref = qnx.projects.get_or_create(name=project_name) qnx.context.set_active_project(my_project_ref) return qnx.start_execute_job(circuits, n_shots, backend_config, job_name) @@ -78,15 +75,14 @@ def is_running(execute_ref: ExecuteJobRef) -> bool: :return: True if the job is still running, False otherwise. :rtype: bool """ - try: st = qnx.jobs.status(execute_ref).status - except ResourceFetchFailed as exc: - print(exc) + except ResourceFetchFailed: return True if st in [StatusEnum.CANCELLING, StatusEnum.CANCELLED, StatusEnum.ERROR]: - raise TierkreisError(f"Job status was {st}") + msg = f"Job status was {st}" + raise TierkreisError(msg) return st != StatusEnum.COMPLETED @@ -100,13 +96,13 @@ def get_results(execute_ref: ExecuteJobRef) -> list[BackendResult]: :return: A list of backend results for each circuit in the job. :rtype: list[BackendResult] """ - execute_job_result_refs = qnx.jobs.results(execute_ref) backend_results: list[BackendResult] = [] for i in range(len(execute_job_result_refs)): ref_result = execute_job_result_refs[i] if not isinstance(ref_result, ExecutionResultRef): - raise TierkreisError(f"Result incomplete: {ref_result}") + msg = f"Result incomplete: {ref_result}" + raise TierkreisError(msg) result = ref_result.download_result() assert isinstance(result, BackendResult) backend_results.append(result) @@ -118,19 +114,18 @@ def get_results(execute_ref: ExecuteJobRef) -> list[BackendResult]: @worker.task() def check_status(execute_ref: ExecuteJobRef) -> str: - warnings.warn("check_status is deprecated, use is_running instead") + warnings.warn("check_status is deprecated, use is_running instead", stacklevel=2) sleep(30) try: return str(qnx.jobs.status(execute_ref).status) - except ResourceFetchFailed as exc: - print(exc) + except ResourceFetchFailed: return str(StatusEnum.SUBMITTED) @worker.task() def submit(circuits: list[Circuit], n_shots: int) -> ExecuteJobRef: warnings.warn( - "submit is deprecated, use 
upload_circuit and start_execute_job instead" + "submit is deprecated, use upload_circuit and start_execute_job instead", stacklevel=2, ) my_project_ref = qnx.projects.get_or_create(name="Riken-Test") qnx.context.set_active_project(my_project_ref) @@ -142,20 +137,19 @@ def submit(circuits: list[Circuit], n_shots: int) -> ExecuteJobRef: name=f"My Circuit from {datetime.now()}", circuit=circ, project=my_project_ref, - ) + ), ) - execute_job_ref = qnx.start_execute_job( + return qnx.start_execute_job( programs=my_circuit_refs, name=f"My Execute Job from {datetime.now()}", n_shots=[n_shots] * len(my_circuit_refs), backend_config=QuantinuumConfig(device_name="reimei-E"), project=my_project_ref, ) - return execute_job_ref -def main(): +def main() -> None: worker.app(argv) diff --git a/tierkreis_workers/pytket_worker/compile_circuit.py b/tierkreis_workers/pytket_worker/compile_circuit.py index b9b26f6e9..8989169b0 100644 --- a/tierkreis_workers/pytket_worker/compile_circuit.py +++ b/tierkreis_workers/pytket_worker/compile_circuit.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from enum import Enum, auto -from typing import Sequence, assert_never +from typing import assert_never from pytket._tket.circuit import Circuit from pytket.architecture import Architecture, FullyConnected @@ -11,10 +12,10 @@ BasePass, CliffordSimp, DecomposeBoxes, + FullMappingPass, FullPeepholeOptimise, GreedyPauliSimp, KAKDecomposition, - FullMappingPass, RemoveBarriers, RemoveRedundancies, SequencePass, @@ -22,6 +23,7 @@ ) from pytket.placement import GraphPlacement from pytket.qasm.qasm import circuit_from_qasm_str, circuit_to_qasm_str + from tierkreis.exceptions import TierkreisError @@ -95,23 +97,27 @@ def compile_circuit( if input_format == CircuitFormat.QASM2: circuit = circuit_from_qasm_str(circuit) else: - raise TierkreisError("Invalid combination of input type and format.") + msg = "Invalid combination of input type and format." 
+ raise TierkreisError(msg) if isinstance(circuit, bytes): if input_format == CircuitFormat.QIR: try: from pytket_qirpass import qir_to_pytket except ModuleNotFoundError: - raise TierkreisError("Could not resolve pytket_qirpass") + msg = "Could not resolve pytket_qirpass" + raise TierkreisError(msg) circuit = qir_to_pytket(circuit) else: - raise TierkreisError("Invalid combination of input type and format.") + msg = "Invalid combination of input type and format." + raise TierkreisError(msg) assert isinstance(circuit, Circuit) qubits: set[int] = set() if coupling_map is not None: - qubits = set([q for pair in coupling_map for q in pair]) + qubits = {q for pair in coupling_map for q in pair} if len(qubits) < len(circuit.qubits): - raise TierkreisError("Circuit does not fit on device.") + msg = "Circuit does not fit on device." + raise TierkreisError(msg) arch = Architecture(coupling_map) else: arch = FullyConnected(len(qubits)) @@ -130,10 +136,12 @@ def compile_circuit( try: from pytket.qir.conversion.api import pytket_to_qir except ModuleNotFoundError: - raise TierkreisError("Could not resolve pytket_qirpass") + msg = "Could not resolve pytket_qirpass" + raise TierkreisError(msg) ret = pytket_to_qir(circuit) if ret is None: - raise TierkreisError("Could not transform circuit to QIR.") + msg = "Could not transform circuit to QIR." 
+ raise TierkreisError(msg) return ret case _: assert_never() @@ -185,11 +193,11 @@ def _default_pass( OpType.XXPhase, OpType.YYPhase, OpType.PhasedX, - } + }, ), ) passlist.append( - GreedyPauliSimp(thread_timeout=300, only_reduce=True, trials=10) + GreedyPauliSimp(thread_timeout=300, only_reduce=True, trials=10), ) assert arch is not None if not isinstance(arch, FullyConnected): @@ -201,7 +209,7 @@ def _default_pass( arch, GraphPlacement(arch), [LexiLabellingMethod(), LexiRouteRoutingMethod(10)], - ) + ), ) if optimization_level == 1: passlist.append(SynthesiseTket()) @@ -211,7 +219,7 @@ def _default_pass( KAKDecomposition(allow_swaps=False), CliffordSimp(False), SynthesiseTket(), - ] + ], ) if optimization_level == 3: # noqa: PLR2004 passlist.append(SynthesiseTket()) @@ -220,6 +228,6 @@ def _default_pass( AutoRebase(primitive_gates), AutoSquash(primitive_1q_gates), RemoveRedundancies(), - ] + ], ) return SequencePass(passlist) diff --git a/tierkreis_workers/pytket_worker/main.py b/tierkreis_workers/pytket_worker/main.py index 00bec66c0..2743b8d19 100644 --- a/tierkreis_workers/pytket_worker/main.py +++ b/tierkreis_workers/pytket_worker/main.py @@ -2,7 +2,12 @@ from sys import argv import qnexus as qnx - +from compile_circuit import ( + MINIMAL_GATE_SET, + CircuitFormat, + OptimizationLevel, + compile_circuit, +) from pytket._tket.circuit import Circuit from pytket._tket.unit_id import Bit from pytket.backends.backendinfo import BackendInfo @@ -20,13 +25,6 @@ from tierkreis import Worker from tierkreis.exceptions import TierkreisError -from compile_circuit import ( - MINIMAL_GATE_SET, - CircuitFormat, - OptimizationLevel, - compile_circuit, -) - worker = Worker("pytket_worker") @@ -46,7 +44,7 @@ def get_backend_info(config: BackendConfig) -> BackendInfo: :rtype: BackendInfo """ if not isinstance(config, IBMQConfig) or not isinstance(config, QuantinuumConfig): - raise NotImplementedError() + raise NotImplementedError device = next( filter( lambda x: x.device_name 
== config.backend_name, @@ -55,8 +53,9 @@ def get_backend_info(config: BackendConfig) -> BackendInfo: None, ) if device is None: + msg = f"Device {config.backend_name} is not in the list of available devices" raise TierkreisError( - f"Device {config.backend_name} is not in the list of available devices" + msg, ) return device.backend_info @@ -105,12 +104,15 @@ def compile_using_info( try: from pytket.extensions.qiskit.backends.ibm import IBMQBackend except ModuleNotFoundError as e: - raise TierkreisError( + msg = ( "Pytket worker could not import IBMQBackend." "Please mnake sure to install the extras to use this task." + ) + raise TierkreisError( + msg, ) from e compilation_pass = IBMQBackend.pass_from_info( - backend_info, optimisation_level, timeout + backend_info, optimisation_level, timeout, ) case QuantinuumConfig(): try: @@ -118,15 +120,18 @@ def compile_using_info( QuantinuumBackend, ) except ModuleNotFoundError as e: - raise TierkreisError( + msg = ( "Pytket worker could not import QuantinuumBackend." "Please mnake sure to install the extras to use this task." + ) + raise TierkreisError( + msg, ) from e compilation_pass = QuantinuumBackend.pass_from_info( - backend_info, optimisation_level=optimisation_level, timeout=timeout + backend_info, optimisation_level=optimisation_level, timeout=timeout, ) case _: - raise NotImplementedError() + raise NotImplementedError compilation_pass.apply(circuit) return circuit @@ -146,7 +151,7 @@ def add_measure_all(circuit: Circuit) -> Circuit: @worker.task() def append_pauli_measurement_impl( - circuit: Circuit, pauli_string: QubitPauliString + circuit: Circuit, pauli_string: QubitPauliString, ) -> Circuit: """Appends pauli measurements according to the pauli string to the circuit. 
@@ -179,7 +184,7 @@ def optimise_phase_gadgets(circuit: Circuit) -> Circuit: @worker.task() def apply_pass(circuit: Circuit, compiler_pass: BasePass) -> Circuit: - """Applies an arbitrary optimization pass to the circuit + """Applies an arbitrary optimization pass to the circuit. :param circuit: The original circuit. :type circuit: Circuit @@ -255,7 +260,7 @@ def compile_generic_with_fixed_pass( gate_set_op = MINIMAL_GATE_SET else: op_types = {op_type.name: op_type for op_type in OpType} - gate_set_op = set(op_types[gate] for gate in gate_set) + gate_set_op = {op_types[gate] for gate in gate_set} return compile_circuit( circuit, @@ -306,10 +311,12 @@ def to_qir_bytes(circuit: Circuit) -> bytes: try: from pytket.qir.conversion.api import pytket_to_qir except ModuleNotFoundError: - raise TierkreisError("Could not resolve pytket.qir") + msg = "Could not resolve pytket.qir" + raise TierkreisError(msg) ret = pytket_to_qir(circuit) if not isinstance(ret, bytes): - raise TierkreisError("Error when converting Circuit to QIR.") + msg = "Error when converting Circuit to QIR." + raise TierkreisError(msg) return ret @@ -325,7 +332,8 @@ def from_qir_bytes(qir: bytes) -> Circuit: try: from pytket_qirpass import qir_to_pytket except ModuleNotFoundError: - raise TierkreisError("Could not resolve pytket_qirpass") + msg = "Could not resolve pytket_qirpass" + raise TierkreisError(msg) return qir_to_pytket(qir) @@ -338,8 +346,7 @@ def expectation(backend_result: BackendResult) -> float: :return: The estimated expectation value. 
:rtype: float """ - expectation = expectation_from_counts(backend_result.get_counts()) - return expectation + return expectation_from_counts(backend_result.get_counts()) @worker.task() @@ -401,14 +408,14 @@ def backend_result_from_dict(data: dict[str, list[str]]) -> BackendResult: bit_register += [Bit(key, i) for i in range(len(values[0]))] bits.append([[int(b) for b in shot] for shot in values]) bit_strings = [ - [item for sublist in group for item in sublist] for group in zip(*bits) + [item for sublist in group for item in sublist] for group in zip(*bits, strict=False) ] return BackendResult( - shots=OutcomeArray.from_readouts(bit_strings), c_bits=bit_register + shots=OutcomeArray.from_readouts(bit_strings), c_bits=bit_register, ) -def main(): +def main() -> None: worker.app(argv) diff --git a/tierkreis_workers/pytket_worker/test_main.py b/tierkreis_workers/pytket_worker/test_main.py index 158d848c8..18634d24e 100644 --- a/tierkreis_workers/pytket_worker/test_main.py +++ b/tierkreis_workers/pytket_worker/test_main.py @@ -1,8 +1,8 @@ +from pytket._tket.circuit import Circuit from pytket._tket.unit_id import Bit from pytket.backends.backendresult import BackendResult -from pytket.utils.outcomearray import OutcomeArray -from pytket._tket.circuit import Circuit from pytket.extensions.qiskit.backends.aer import AerBackend +from pytket.utils.outcomearray import OutcomeArray from .main import backend_result_from_dict, backend_result_to_dict diff --git a/tierkreis_workers/quantinuum_worker/default_pass_quantinuum.py b/tierkreis_workers/quantinuum_worker/default_pass_quantinuum.py index df0a22e2a..ee6e8da98 100644 --- a/tierkreis_workers/quantinuum_worker/default_pass_quantinuum.py +++ b/tierkreis_workers/quantinuum_worker/default_pass_quantinuum.py @@ -1,19 +1,19 @@ +from pytket.circuit import OpType from pytket.passes import ( - BasePass, AutoRebase, AutoSquash, + BasePass, DecomposeBoxes, DecomposeTK2, FlattenRelabelRegistersPass, + GreedyPauliSimp, NormaliseTK2, + 
RemoveBarriers, RemovePhaseOps, RemoveRedundancies, SequencePass, - RemoveBarriers, - GreedyPauliSimp, ) from pytket.passes.resizeregpass import scratch_reg_resize_pass -from pytket.circuit import OpType def _gate_set() -> set[OpType]: @@ -66,7 +66,7 @@ def default_compilation_pass() -> BasePass: OpType.XXPhase, OpType.YYPhase, OpType.PhasedX, - } + }, ), GreedyPauliSimp( allow_zzphase=True, @@ -74,7 +74,7 @@ def default_compilation_pass() -> BasePass: thread_timeout=300, trials=10, ), - ] + ], ) passlist.extend(decomposition_passes) rebase_pass = AutoRebase( @@ -87,7 +87,7 @@ def default_compilation_pass() -> BasePass: RemoveRedundancies(), squash, RemoveRedundancies(), - ] + ], ) passlist.append(RemovePhaseOps()) passlist.append(FlattenRelabelRegistersPass("q")) diff --git a/tierkreis_workers/quantinuum_worker/main.py b/tierkreis_workers/quantinuum_worker/main.py index 4a23be227..3aa49848a 100644 --- a/tierkreis_workers/quantinuum_worker/main.py +++ b/tierkreis_workers/quantinuum_worker/main.py @@ -1,20 +1,19 @@ -from sys import argv import time +from sys import argv import qnexus as qnx -from qnexus.models import IssuerEnum -from qnexus.models.references import ExecutionResultRef +from default_pass_quantinuum import default_compilation_pass from pytket._tket.circuit import Circuit from pytket.backends.backendinfo import BackendInfo from pytket.backends.backendresult import BackendResult from pytket.extensions.quantinuum.backends.quantinuum import QuantinuumBackend from pytket.passes import BasePass +from qnexus.models import IssuerEnum +from qnexus.models.references import ExecutionResultRef from tierkreis import Worker from tierkreis.exceptions import TierkreisError -from default_pass_quantinuum import default_compilation_pass - worker = Worker("quantinuum_worker") @@ -31,18 +30,19 @@ def get_backend_info(device_name: str) -> BackendInfo: all_devices = qnx.devices.get_all([IssuerEnum.QUANTINUUM]) info = next(filter(lambda x: x.device_name == device_name, 
all_devices), None) if info is None: + msg = f"Device {device_name} is not in the list of available Quantinuum devices" raise TierkreisError( - f"Device {device_name} is not in the list of available Quantinuum devices" + msg, ) return info.backend_info @worker.task() def compile_using_info( - circuit: Circuit, backend_info: BackendInfo, optimisation_level: int = 2 + circuit: Circuit, backend_info: BackendInfo, optimisation_level: int = 2, ) -> Circuit: base_pass = QuantinuumBackend.pass_from_info( - backend_info, optimisation_level=optimisation_level + backend_info, optimisation_level=optimisation_level, ) base_pass.apply(circuit) return circuit @@ -50,7 +50,7 @@ def compile_using_info( @worker.task() def backend_pass_from_info( - backend_info: BackendInfo, optimisation_level: int = 2 + backend_info: BackendInfo, optimisation_level: int = 2, ) -> BasePass: """Returns a compilation pass according to the backend info. @@ -62,7 +62,7 @@ def backend_pass_from_info( :rtype: BasePass """ return QuantinuumBackend.pass_from_info( - backend_info, optimisation_level=optimisation_level + backend_info, optimisation_level=optimisation_level, ) @@ -132,14 +132,16 @@ def run_circuit(circuit: Circuit, n_shots: int, device_name: str) -> BackendResu qnx.jobs.wait_for(job_ref) ref_result = qnx.jobs.results(job_ref)[0] if not isinstance(ref_result, ExecutionResultRef): - raise TierkreisError(f"Result incomplete: {ref_result}") + msg = f"Result incomplete: {ref_result}" + raise TierkreisError(msg) backend_result = ref_result.download_result() if not isinstance(backend_result, BackendResult): - raise TierkreisError(f"Result was not a backend result: {backend_result}") + msg = f"Result was not a backend result: {backend_result}" + raise TierkreisError(msg) return backend_result -def main(): +def main() -> None: worker.app(argv) diff --git a/tierkreis_workers/qulacs_worker/main.py b/tierkreis_workers/qulacs_worker/main.py index 9a1905945..6a5fcdd20 100644 --- 
a/tierkreis_workers/qulacs_worker/main.py +++ b/tierkreis_workers/qulacs_worker/main.py @@ -1,12 +1,13 @@ from sys import argv from typing import Any -from tierkreis import Worker from pytket._tket.circuit import Circuit from pytket.backends.backend import Backend from pytket.backends.backendresult import BackendResult from pytket.extensions.qulacs.backends.qulacs_backend import QulacsBackend +from tierkreis import Worker + worker = Worker("qulacs_worker") @@ -15,8 +16,7 @@ def get_backend(result_type: str = "state_vector", gpu_sim: bool = False) -> Bac from pytket.extensions.qulacs.backends.qulacs_backend import QulacsGPUBackend return QulacsGPUBackend() - else: - return QulacsBackend(result_type) + return QulacsBackend(result_type) def get_config(seed: int | None = None) -> dict[str, Any]: