Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,24 @@ impl Hugr {
reader: impl io::BufRead,
extensions: Option<&ExtensionRegistry>,
) -> Result<Self, ReadError> {
let (hugr, _) = Self::load_with_exts(reader, extensions)?;
Ok(hugr)
}

/// Read a HUGR from an Envelope, and return the enclosed extensions.
///
/// To load a HUGR, all the extensions used in its definition must be
/// available. The Envelope may include some of the extensions, but any
/// additional extensions must be provided in the `extensions` parameter. If
/// `extensions` is `None`, the default [`crate::std_extensions::STD_REG`]
/// is used.
pub fn load_with_exts(
reader: impl io::BufRead,
extensions: Option<&ExtensionRegistry>,
) -> Result<(Self, ExtensionRegistry), ReadError> {
let pkg = Package::load(reader, extensions)?;
match pkg.modules.into_iter().exactly_one() {
Ok(hugr) => Ok(hugr),
Ok(hugr) => Ok((hugr, pkg.extensions)),
Err(e) => Err(ReadError::ExpectedSingleHugr { count: e.count() }),
}
}
Expand Down Expand Up @@ -647,11 +662,16 @@ fn make_module_hugr(root_op: OpType, nodes: usize, ports: usize) -> Option<Hugr>

#[cfg(test)]
pub(crate) mod test {
use crate::Extension;
use crate::extension::prelude::qb_t;
use crate::extension::prelude::usize_t;
use std::{fs::File, io::BufReader};

use super::*;

use crate::builder::test::simple_package;
use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
use crate::extension::ExtensionId;
use crate::extension::prelude::bool_t;
use crate::ops::OpaqueOp;
use crate::ops::handle::NodeHandle;
Expand Down Expand Up @@ -839,4 +859,31 @@ pub(crate) mod test {
}
}
}

#[rstest]
fn load_extensions() {
let my_ext_id = ExtensionId::new("test.ext").unwrap();
let my_ext = Extension::new_test_arc(my_ext_id, |ext, extension_ref| {
ext.add_op(
"MyOp".into(),
String::new(),
Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
extension_ref,
)
.unwrap();
});

let mut package = simple_package();
package.extensions.register(my_ext).unwrap();
let mut hugr_str = Vec::new();
package
.store(&mut hugr_str, EnvelopeConfig::default())
.unwrap();

let (_, exts) = Hugr::load_with_exts(hugr_str.as_slice(), None).unwrap();
assert_eq!(exts.len(), 1);
assert_matches!(exts.get("test.ext"), Some(ext) => {
assert!(ext.get_op("MyOp").is_some());
});
}
}
3 changes: 3 additions & 0 deletions hugr-py/rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Supporting Rust library for the hugr Python bindings.

mod linking;
mod metadata;
mod model;
mod zstd_util;
Expand All @@ -8,6 +9,8 @@ use pyo3::pymodule;

#[pymodule]
mod _hugr {
#[pymodule_export]
use super::linking::linking;
#[pymodule_export]
use super::metadata::metadata;
#[pymodule_export]
Expand Down
76 changes: 76 additions & 0 deletions hugr-py/rust/linking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//! Bindings for linking utilities defined in the hugr-core crate

use pyo3::exceptions::PyException;
use pyo3::{create_exception, pymodule};

#[pymodule(submodule)]
#[pyo3(module = "hugr._hugr.linking")]
pub mod linking {
use hugr_core::envelope::EnvelopeConfig;
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy};
use hugr_core::{Hugr, HugrView};
use pyo3::exceptions::PyValueError;
use pyo3::types::{PyAnyMethods, PyModule};
use pyo3::{Bound, PyResult, Python, pyfunction};

/// Hack: workaround for <https://github.com/PyO3/pyo3/issues/759>
#[pymodule_init]
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
Python::attach(|py| {
py.import("sys")?
.getattr("modules")?
.set_item("hugr._hugr.linking", m)
})
}

#[pyfunction]
fn link_modules(module_into: &[u8], module_from: &[u8]) -> PyResult<Vec<u8>> {
let (mut hugr_into, mut exts_into) =
Hugr::load_with_exts(module_into, None).map_err(|err| {
PyValueError::new_err(format!("Loading of first envelope failed: {}", err))
})?;
let (hugr_from, exts_from) = Hugr::load_with_exts(module_from, None).map_err(|err| {
PyValueError::new_err(format!("Loading of second envelope failed: {}", err))
})?;
let into_executable = hugr_into.entrypoint() != hugr_into.module_root();
let from_executable = hugr_from.entrypoint() != hugr_from.module_root();
let replacement_entrypoint = if into_executable && from_executable {
return Err(PyValueError::new_err(
"Cannot link two executable modules together.",
));
} else if !into_executable && from_executable {
Some(hugr_from.entrypoint())
} else {
None
};

let forest = hugr_into
.link_module(hugr_from, &NameLinkingPolicy::default())
.map_err(|err| super::HugrLinkingError::new_err(err.to_string()))?;
if let Some(new_entrypoint) = replacement_entrypoint {
let Some(node) = forest.node_map.get(&new_entrypoint) else {
panic!("Entrypoint is to be replaced but was not found after linking");
};
hugr_into.set_entrypoint(*node);
}
exts_into.extend(exts_from);

let mut result = Vec::new();
hugr_into
.store_with_exts(&mut result, EnvelopeConfig::binary(), &exts_into)
.unwrap();

// Sanity check roundtrip
debug_assert!(hugr_core::package::Package::load(&result[..], None).is_ok());

Ok(result)
}
}

create_exception!(
_hugr.linking,
HugrLinkingError,
PyException,
"Base exception for HUGR linking errors."
);
3 changes: 3 additions & 0 deletions hugr-py/src/hugr/_hugr/linking.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class HugrLinkingError(Exception): ...

def link_modules(module_into: bytes, module_from: bytes) -> bytes: ...
29 changes: 29 additions & 0 deletions hugr-py/src/hugr/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hugr._serialization.extension as ext_s
import hugr.model as model
from hugr import ext
from hugr._hugr.linking import link_modules
from hugr.envelope import (
EnvelopeConfig,
_make_envelope,
Expand Down Expand Up @@ -188,6 +189,34 @@ def used_extensions(

return result

def link(self, *other: Package):
"""Link this package with other packages, returning a new package containing the
extensions of all packages, as well as a single module created from linking the
modules from all packages.

Args:
*other: Other packages to link with.

Returns:
A new package containing the modules and extensions of all packages.
"""
modules = self.modules[:]
extensions = self.extensions[:]
for pkg in other:
modules.extend(pkg.modules)
for new_ext in pkg.extensions:
if new_ext not in extensions:
extensions.append(new_ext)

if len(modules) == 0:
return Package([], extensions)

result_module_bytes = modules[0].to_bytes()
for module in modules[1:]:
result_module_bytes = link_modules(result_module_bytes, module.to_bytes())

return Package([Hugr.from_bytes(result_module_bytes)], extensions)


@dataclass(frozen=True)
class PackagePointer:
Expand Down
103 changes: 103 additions & 0 deletions hugr-py/tests/test_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest

from hugr import Hugr, tys
from hugr._hugr.linking import link_modules
from hugr.build import Module
from hugr.ops import FuncDefn
from hugr.package import Package
from hugr.std import float, int, logic, ptr


def build_module(*, entrypoint: bool) -> Hugr:
builder = Module()
if entrypoint:
main = builder.define_function(
"main", input_types=[tys.Bool], output_types=[tys.Bool], visibility="Public"
)
main.set_outputs(*main.inputs())
builder.hugr.entrypoint = main.parent_node

return builder.hugr


def test_link_modules_no_entrypoints():
hugr1 = build_module(entrypoint=False)
hugr2 = build_module(entrypoint=False)

linked = Hugr.from_bytes(link_modules(hugr1.to_bytes(), hugr2.to_bytes()))
assert linked.entrypoint == linked.module_root


def test_link_modules_entrypoint_lhs():
hugr1 = build_module(entrypoint=True)
hugr2 = build_module(entrypoint=False)

linked = Hugr.from_bytes(link_modules(hugr1.to_bytes(), hugr2.to_bytes()))
assert linked.entrypoint != linked.module_root
entrypoint = linked.entrypoint_op()
assert isinstance(entrypoint, FuncDefn)
assert entrypoint.f_name == "main"


def test_link_modules_entrypoint_rhs():
hugr1 = build_module(entrypoint=False)
hugr2 = build_module(entrypoint=True)

linked = Hugr.from_bytes(link_modules(hugr1.to_bytes(), hugr2.to_bytes()))
assert linked.entrypoint != linked.module_root
entrypoint = linked.entrypoint_op()
assert isinstance(entrypoint, FuncDefn)
assert entrypoint.f_name == "main"


def test_link_modules_multiple_entrypoints():
hugr1 = build_module(entrypoint=True)
hugr2 = build_module(entrypoint=True)

with pytest.raises(ValueError, match="Cannot link two executable modules together"):
link_modules(hugr1.to_bytes(), hugr2.to_bytes())


def test_link_packages_no_modules():
pkg1 = Package(modules=[])
pkg2 = Package(modules=[])

result_pkg = pkg1.link(pkg2)

assert result_pkg.modules == []


def test_link_packages_extensions():
pkg1 = Package(
modules=[build_module(entrypoint=False)],
extensions=[
int.CONVERSIONS_EXTENSION,
int.INT_TYPES_EXTENSION,
int.INT_OPS_EXTENSION,
# Shared
logic.EXTENSION,
ptr.EXTENSION,
],
)
pkg2 = Package(
modules=[build_module(entrypoint=False)],
extensions=[
float.FLOAT_OPS_EXTENSION,
float.FLOAT_TYPES_EXTENSION,
# Shared
logic.EXTENSION,
ptr.EXTENSION,
],
)

result_pkg = pkg1.link(pkg2)

assert result_pkg.extensions == [
int.CONVERSIONS_EXTENSION,
int.INT_TYPES_EXTENSION,
int.INT_OPS_EXTENSION,
logic.EXTENSION,
ptr.EXTENSION,
float.FLOAT_OPS_EXTENSION,
float.FLOAT_TYPES_EXTENSION,
]
Loading