diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 01d13aec6..dcbeff4f2 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -159,9 +159,24 @@ impl Hugr { reader: impl io::BufRead, extensions: Option<&ExtensionRegistry>, ) -> Result { + let (hugr, _) = Self::load_with_exts(reader, extensions)?; + Ok(hugr) + } + + /// Read a HUGR from an Envelope, and return the enclosed extensions. + /// + /// To load a HUGR, all the extensions used in its definition must be + /// available. The Envelope may include some of the extensions, but any + /// additional extensions must be provided in the `extensions` parameter. If + /// `extensions` is `None`, the default [`crate::std_extensions::STD_REG`] + /// is used. + pub fn load_with_exts( + reader: impl io::BufRead, + extensions: Option<&ExtensionRegistry>, + ) -> Result<(Self, ExtensionRegistry), ReadError> { let pkg = Package::load(reader, extensions)?; match pkg.modules.into_iter().exactly_one() { - Ok(hugr) => Ok(hugr), + Ok(hugr) => Ok((hugr, pkg.extensions)), Err(e) => Err(ReadError::ExpectedSingleHugr { count: e.count() }), } } @@ -647,11 +662,16 @@ fn make_module_hugr(root_op: OpType, nodes: usize, ports: usize) -> Option #[cfg(test)] pub(crate) mod test { + use crate::Extension; + use crate::extension::prelude::qb_t; + use crate::extension::prelude::usize_t; use std::{fs::File, io::BufReader}; use super::*; + use crate::builder::test::simple_package; use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder}; + use crate::extension::ExtensionId; use crate::extension::prelude::bool_t; use crate::ops::OpaqueOp; use crate::ops::handle::NodeHandle; @@ -839,4 +859,31 @@ pub(crate) mod test { } } } + + #[rstest] + fn load_extensions() { + let my_ext_id = ExtensionId::new("test.ext").unwrap(); + let my_ext = Extension::new_test_arc(my_ext_id, |ext, extension_ref| { + ext.add_op( + "MyOp".into(), + String::new(), + Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]), + extension_ref, + ) + .unwrap(); + }); + + let mut package = simple_package(); + package.extensions.register(my_ext).unwrap(); + let mut hugr_str = Vec::new(); + package + .store(&mut hugr_str, EnvelopeConfig::default()) + .unwrap(); + + let (_, exts) = Hugr::load_with_exts(hugr_str.as_slice(), None).unwrap(); + assert_eq!(exts.len(), 1); + assert_matches!(exts.get("test.ext"), Some(ext) => { + assert!(ext.get_op("MyOp").is_some()); + }); + } } diff --git a/hugr-py/rust/lib.rs b/hugr-py/rust/lib.rs index 3b65ec8dc..3606f10b9 100644 --- a/hugr-py/rust/lib.rs +++ b/hugr-py/rust/lib.rs @@ -1,5 +1,6 @@ //! Supporting Rust library for the hugr Python bindings. +mod linking; mod metadata; mod model; mod zstd_util; @@ -8,6 +9,8 @@ use pyo3::pymodule; #[pymodule] mod _hugr { + #[pymodule_export] + use super::linking::linking; #[pymodule_export] use super::metadata::metadata; #[pymodule_export] diff --git a/hugr-py/rust/linking.rs b/hugr-py/rust/linking.rs new file mode 100644 index 000000000..a07e5f55e --- /dev/null +++ b/hugr-py/rust/linking.rs @@ -0,0 +1,76 @@ +//! Bindings for linking utilities defined in the hugr-core crate + +use pyo3::exceptions::PyException; +use pyo3::{create_exception, pymodule}; + +#[pymodule(submodule)] +#[pyo3(module = "hugr._hugr.linking")] +pub mod linking { + use hugr_core::envelope::EnvelopeConfig; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy}; + use hugr_core::{Hugr, HugrView}; + use pyo3::exceptions::PyValueError; + use pyo3::types::{PyAnyMethods, PyModule}; + use pyo3::{Bound, PyResult, Python, pyfunction}; + + /// Hack: workaround for + #[pymodule_init] + fn init(m: &Bound<'_, PyModule>) -> PyResult<()> { + Python::attach(|py| { + py.import("sys")? + .getattr("modules")? + .set_item("hugr._hugr.linking", m) + }) + } + + #[pyfunction] + fn link_modules(module_into: &[u8], module_from: &[u8]) -> PyResult> { + let (mut hugr_into, mut exts_into) = + Hugr::load_with_exts(module_into, None).map_err(|err| { + PyValueError::new_err(format!("Loading of first envelope failed: {}", err)) + })?; + let (hugr_from, exts_from) = Hugr::load_with_exts(module_from, None).map_err(|err| { + PyValueError::new_err(format!("Loading of second envelope failed: {}", err)) + })?; + let into_executable = hugr_into.entrypoint() != hugr_into.module_root(); + let from_executable = hugr_from.entrypoint() != hugr_from.module_root(); + let replacement_entrypoint = if into_executable && from_executable { + return Err(PyValueError::new_err( + "Cannot link two executable modules together.", + )); + } else if !into_executable && from_executable { + Some(hugr_from.entrypoint()) + } else { + None + }; + + let forest = hugr_into + .link_module(hugr_from, &NameLinkingPolicy::default()) + .map_err(|err| super::HugrLinkingError::new_err(err.to_string()))?; + if let Some(new_entrypoint) = replacement_entrypoint { + let Some(node) = forest.node_map.get(&new_entrypoint) else { + panic!("Entrypoint is to be replaced but was not found after linking"); + }; + hugr_into.set_entrypoint(*node); + } + exts_into.extend(exts_from); + + let mut result = Vec::new(); + hugr_into + .store_with_exts(&mut result, EnvelopeConfig::binary(), &exts_into) + .unwrap(); + + // Sanity check roundtrip + debug_assert!(hugr_core::package::Package::load(&result[..], None).is_ok()); + + Ok(result) + } +} + +create_exception!( + _hugr.linking, + HugrLinkingError, + PyException, + "Base exception for HUGR linking errors." +); diff --git a/hugr-py/src/hugr/_hugr/linking.pyi b/hugr-py/src/hugr/_hugr/linking.pyi new file mode 100644 index 000000000..e56eb69a3 --- /dev/null +++ b/hugr-py/src/hugr/_hugr/linking.pyi @@ -0,0 +1,3 @@ +class HugrLinkingError(Exception): ... + +def link_modules(module_into: bytes, module_from: bytes) -> bytes: ... diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py index 7c392a5f8..c1f31edac 100644 --- a/hugr-py/src/hugr/package.py +++ b/hugr-py/src/hugr/package.py @@ -10,6 +10,7 @@ import hugr._serialization.extension as ext_s import hugr.model as model from hugr import ext +from hugr._hugr.linking import link_modules from hugr.envelope import ( EnvelopeConfig, _make_envelope, @@ -188,6 +189,34 @@ def used_extensions( return result + def link(self, *other: Package): + """Link this package with other packages, returning a new package containing the + extensions of all packages, as well as a single module created from linking the + modules from all packages. + + Args: + *other: Other packages to link with. + + Returns: + A new package containing the modules and extensions of all packages. + """ + modules = self.modules[:] + extensions = self.extensions[:] + for pkg in other: + modules.extend(pkg.modules) + for new_ext in pkg.extensions: + if new_ext not in extensions: + extensions.append(new_ext) + + if len(modules) == 0: + return Package([], extensions) + + result_module_bytes = modules[0].to_bytes() + for module in modules[1:]: + result_module_bytes = link_modules(result_module_bytes, module.to_bytes()) + + return Package([Hugr.from_bytes(result_module_bytes)], extensions) + @dataclass(frozen=True) class PackagePointer: diff --git a/hugr-py/tests/test_linking.py b/hugr-py/tests/test_linking.py new file mode 100644 index 000000000..6c5ea4e0b --- /dev/null +++ b/hugr-py/tests/test_linking.py @@ -0,0 +1,103 @@ +import pytest + +from hugr import Hugr, tys +from hugr._hugr.linking import link_modules +from hugr.build import Module +from hugr.ops import FuncDefn +from hugr.package import Package +from hugr.std import float, int, logic, ptr + + +def build_module(*, entrypoint: bool) -> Hugr: + builder = Module() + if entrypoint: + main = builder.define_function( + "main", input_types=[tys.Bool], output_types=[tys.Bool], visibility="Public" + ) + main.set_outputs(*main.inputs()) + builder.hugr.entrypoint = main.parent_node + + return builder.hugr + + +def test_link_modules_no_entrypoints(): + hugr1 = build_module(entrypoint=False) + hugr2 = build_module(entrypoint=False) + + linked = Hugr.from_bytes(link_modules(hugr1.to_bytes(), hugr2.to_bytes())) + assert linked.entrypoint == linked.module_root + + +def test_link_modules_entrypoint_lhs(): + hugr1 = build_module(entrypoint=True) + hugr2 = build_module(entrypoint=False) + + linked = Hugr.from_bytes(link_modules(hugr1.to_bytes(), hugr2.to_bytes())) + assert linked.entrypoint != linked.module_root + entrypoint = linked.entrypoint_op() + assert isinstance(entrypoint, FuncDefn) + assert entrypoint.f_name == "main" + + +def test_link_modules_entrypoint_rhs(): + hugr1 = build_module(entrypoint=False) + hugr2 = build_module(entrypoint=True) + + linked = Hugr.from_bytes(link_modules(hugr1.to_bytes(), hugr2.to_bytes())) + assert linked.entrypoint != linked.module_root + entrypoint = linked.entrypoint_op() + assert isinstance(entrypoint, FuncDefn) + assert entrypoint.f_name == "main" + + +def test_link_modules_multiple_entrypoints(): + hugr1 = build_module(entrypoint=True) + hugr2 = build_module(entrypoint=True) + + with pytest.raises(ValueError, match="Cannot link two executable modules together"): + link_modules(hugr1.to_bytes(), hugr2.to_bytes()) + + +def test_link_packages_no_modules(): + pkg1 = Package(modules=[]) + pkg2 = Package(modules=[]) + + result_pkg = pkg1.link(pkg2) + + assert result_pkg.modules == [] + + +def test_link_packages_extensions(): + pkg1 = Package( + modules=[build_module(entrypoint=False)], + extensions=[ + int.CONVERSIONS_EXTENSION, + int.INT_TYPES_EXTENSION, + int.INT_OPS_EXTENSION, + # Shared + logic.EXTENSION, + ptr.EXTENSION, + ], + ) + pkg2 = Package( + modules=[build_module(entrypoint=False)], + extensions=[ + float.FLOAT_OPS_EXTENSION, + float.FLOAT_TYPES_EXTENSION, + # Shared + logic.EXTENSION, + ptr.EXTENSION, + ], + ) + + result_pkg = pkg1.link(pkg2) + + assert result_pkg.extensions == [ + int.CONVERSIONS_EXTENSION, + int.INT_TYPES_EXTENSION, + int.INT_OPS_EXTENSION, + logic.EXTENSION, + ptr.EXTENSION, + float.FLOAT_OPS_EXTENSION, + float.FLOAT_TYPES_EXTENSION, + ]