Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,24 @@ impl Hugr {
}
}

/// Load exactly one HUGR from an Envelope, together with the enclosed
/// extensions.
///
/// Every extension used in the HUGR's definition must be available in
/// order to load it. Some of them may be carried by the Envelope itself;
/// any others have to be supplied through the `extensions` parameter. When
/// `extensions` is `None`, the default [`crate::std_extensions::STD_REG`]
/// is used.
///
/// # Errors
///
/// Propagates any [`ReadError`] from [`Package::load`], and returns
/// [`ReadError::ExpectedSingleHugr`] when the Envelope does not contain
/// exactly one module.
pub fn load_with_exts(
    reader: impl io::BufRead,
    extensions: Option<&ExtensionRegistry>,
) -> Result<(Self, ExtensionRegistry), ReadError> {
    let pkg = Package::load(reader, extensions)?;
    // `exactly_one` hands back the remaining iterator on failure; counting it
    // gives the number of modules that were actually found.
    let hugr = pkg
        .modules
        .into_iter()
        .exactly_one()
        .map_err(|rest| ReadError::ExpectedSingleHugr {
            count: rest.count(),
        })?;
    Ok((hugr, pkg.extensions))
}

/// Read a HUGR from an Envelope encoded in a string.
///
/// Note that not all Envelopes are valid strings. In the general case,
Expand Down
3 changes: 3 additions & 0 deletions hugr-py/rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Supporting Rust library for the hugr Python bindings.

mod linking;
mod metadata;
mod model;
mod zstd_util;
Expand All @@ -8,6 +9,8 @@ use pyo3::pymodule;

#[pymodule]
mod _hugr {
#[pymodule_export]
use super::linking::linking;
#[pymodule_export]
use super::metadata::metadata;
#[pymodule_export]
Expand Down
76 changes: 76 additions & 0 deletions hugr-py/rust/linking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//! Bindings for linking utilities defined in the hugr-core crate

use pyo3::exceptions::PyException;
use pyo3::{create_exception, pymodule};

#[pymodule(submodule)]
#[pyo3(module = "hugr._hugr.linking")]
pub mod linking {
use hugr_core::envelope::EnvelopeConfig;
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy};
use hugr_core::{Hugr, HugrView};
use pyo3::exceptions::PyValueError;
use pyo3::types::{PyAnyMethods, PyModule};
use pyo3::{Bound, PyResult, Python, pyfunction};

/// Registers this submodule in Python's `sys.modules` under the key
/// `hugr._hugr.linking`.
///
/// Hack: workaround for <https://github.com/PyO3/pyo3/issues/759>
#[pymodule_init]
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
    Python::attach(|py| {
        let sys_modules = py.import("sys")?.getattr("modules")?;
        sys_modules.set_item("hugr._hugr.linking", m)
    })
}

#[pyfunction]
fn link_modules(module_into: &[u8], module_from: &[u8]) -> PyResult<Vec<u8>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I notice that we're going to have all three hugrs in memory, which seems like it could be a problem.
I don't really have strong enough Rust/PyO3-fu to say what the solution is, but possibly:

  1. take as arguments impl Reader so the files are read here and their contents can be dropped
  2. Take Package as argument so we can consume them when they're linked in and release the memory

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could instead take an owned Vec<u8> if that helps... See https://pyo3.rs/v0.28.2/conversions/tables.html for allowed types / types that have an implementation of the required traits out of the box.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that'd be good

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modified the param. It seems that the reader trait to read into a hugr is not implemented for a vec though, so I still use a slice. It might also not make much of a difference since Python has no concept of ownership, so I presume the contents are copied anyway. This might actually be more expensive now 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I benchmarked it with the QAOA example: using Vec<u8> I had about 0.048s per link call (from Python), whereas using &[u8] gave me about 0.031s per link call (probably due to avoiding an unnecessary copy). I reverted the change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's fair 👍

// Decode both envelopes. Each load also yields the extension registry that
// `load_with_exts` returns alongside the module.
let (mut hugr_into, mut exts_into) = Hugr::load_with_exts(module_into, None)
    .map_err(|err| PyValueError::new_err(format!("Loading of first envelope failed: {}", err)))?;
let (hugr_from, exts_from) = Hugr::load_with_exts(module_from, None)
    .map_err(|err| PyValueError::new_err(format!("Loading of second envelope failed: {}", err)))?;

// A module is treated as executable when its entrypoint is not the module root.
let into_executable = hugr_into.entrypoint() != hugr_into.module_root();
let from_executable = hugr_from.entrypoint() != hugr_from.module_root();
// The incoming entrypoint is only adopted when the target has none of its own;
// two executables cannot be merged.
let replacement_entrypoint = match (into_executable, from_executable) {
    (true, true) => {
        return Err(PyValueError::new_err(
            "Cannot link two executable modules together.",
        ));
    }
    (false, true) => Some(hugr_from.entrypoint()),
    _ => None,
};

let forest = hugr_into
    .link_module(hugr_from, &NameLinkingPolicy::default())
    .map_err(|err| super::HugrLinkingError::new_err(err.to_string()))?;
if let Some(new_entrypoint) = replacement_entrypoint {
    // Report a missing mapping as a regular Python exception rather than
    // panicking: a Rust panic crosses into Python as `PanicException`, which
    // is not caught by `except Exception`.
    let node = forest.node_map.get(&new_entrypoint).ok_or_else(|| {
        super::HugrLinkingError::new_err(
            "Entrypoint is to be replaced but was not found after linking",
        )
    })?;
    hugr_into.set_entrypoint(*node);
}
exts_into.extend(exts_from);

// Serialise the linked module as a binary envelope. A write failure is
// returned to the caller as a `ValueError` instead of aborting via `unwrap`.
let mut result = Vec::new();
hugr_into
    .store_with_exts(&mut result, EnvelopeConfig::binary(), &exts_into)
    .map_err(|err| {
        PyValueError::new_err(format!("Storing of linked envelope failed: {:?}", err))
    })?;

hugr_core::package::Package::load(&result[..], None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like something we don't want in production, I'd say make it a debug assert, but not sure how that works with maturin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oop yeah, that was something that was used for debugging. I will have a look at debug asserts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to work as you would expect in regular rust. I pushed it to a debug statement.

.map_err(|err| PyValueError::new_err(format!("Roundtrip failed: {:?}", err)))?;

Ok(result)
}
}

// Declares the Python exception type `HugrLinkingError`, derived from
// `PyException`. `linking::link_modules` raises it when `link_module` reports
// an error.
// NOTE(review): the module path given here is `_hugr.linking`, while the
// submodule is declared as `hugr._hugr.linking` — confirm which one is intended.
// NOTE(review): no `#[pymodule_export]` for this type is visible inside
// `mod linking` here — confirm it is exported so Python code can import it.
create_exception!(
_hugr.linking,
HugrLinkingError,
PyException,
"Base exception for HUGR linking errors."
);
3 changes: 3 additions & 0 deletions hugr-py/src/hugr/_hugr/linking.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Base exception for HUGR linking errors.
class HugrLinkingError(Exception): ...

# Links the module encoded in `module_from` into the one encoded in
# `module_into` and returns the result as envelope bytes.
def link_modules(module_into: bytes, module_from: bytes) -> bytes: ...
29 changes: 29 additions & 0 deletions hugr-py/src/hugr/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import hugr._serialization.extension as ext_s
import hugr.model as model
from hugr import ext
from hugr._hugr.linking import link_modules
from hugr.envelope import (
EnvelopeConfig,
_make_envelope,
Expand Down Expand Up @@ -188,6 +189,34 @@ def used_extensions(

return result

def link(self, *other: Package):
    """Link this package with other packages into a single-module package.

    The modules of this package and of every package in ``other`` are linked
    together, in order, into one module. The extensions of all packages are
    collected, keeping the first occurrence of each.

    Args:
        *other: Other packages to link with.

    Returns:
        A new package holding the single linked module (or no module at all
        if none of the packages contains one) and the combined extensions.
    """
    modules = self.modules[:]
    extensions = self.extensions[:]
    for pkg in other:
        modules.extend(pkg.modules)
        # Membership is checked against the growing list so that duplicates
        # within a single package are dropped as well.
        for candidate in pkg.extensions:
            if candidate in extensions:
                continue
            extensions.append(candidate)

    if not modules:
        return Package([], extensions)

    # Fold the remaining modules into the first one, one link call at a time.
    first, *rest = modules
    linked_bytes = first.to_bytes()
    for module in rest:
        linked_bytes = link_modules(linked_bytes, module.to_bytes())

    return Package([Hugr.from_bytes(linked_bytes)], extensions)


@dataclass(frozen=True)
class PackagePointer:
Expand Down
103 changes: 103 additions & 0 deletions hugr-py/tests/test_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest

from hugr import Hugr, tys
from hugr._hugr.linking import link_modules
from hugr.build import Module
from hugr.ops import FuncDefn
from hugr.package import Package
from hugr.std import float, int, logic, ptr


def build_module(*, entrypoint: bool) -> Hugr:
    """Build a module HUGR, optionally with a public ``main`` as entrypoint.

    When ``entrypoint`` is true, a public ``main`` function from Bool to Bool
    that returns its inputs is defined and made the HUGR's entrypoint.
    """
    module = Module()
    if not entrypoint:
        return module.hugr

    main_fn = module.define_function(
        "main",
        input_types=[tys.Bool],
        output_types=[tys.Bool],
        visibility="Public",
    )
    main_fn.set_outputs(*main_fn.inputs())
    module.hugr.entrypoint = main_fn.parent_node
    return module.hugr


def test_link_modules_no_entrypoints():
    """Linking two entrypoint-free modules keeps the entrypoint at the root."""
    into_bytes = build_module(entrypoint=False).to_bytes()
    from_bytes = build_module(entrypoint=False).to_bytes()

    result = Hugr.from_bytes(link_modules(into_bytes, from_bytes))

    assert result.entrypoint == result.module_root


def test_link_modules_entrypoint_lhs():
    """An entrypoint in the first module is still the entrypoint after linking."""
    into_bytes = build_module(entrypoint=True).to_bytes()
    from_bytes = build_module(entrypoint=False).to_bytes()

    result = Hugr.from_bytes(link_modules(into_bytes, from_bytes))

    assert result.entrypoint != result.module_root
    entry_op = result.entrypoint_op()
    assert isinstance(entry_op, FuncDefn)
    assert entry_op.f_name == "main"


def test_link_modules_entrypoint_rhs():
    """An entrypoint in the second module becomes the linked entrypoint."""
    into_bytes = build_module(entrypoint=False).to_bytes()
    from_bytes = build_module(entrypoint=True).to_bytes()

    result = Hugr.from_bytes(link_modules(into_bytes, from_bytes))

    assert result.entrypoint != result.module_root
    entry_op = result.entrypoint_op()
    assert isinstance(entry_op, FuncDefn)
    assert entry_op.f_name == "main"


def test_link_modules_multiple_entrypoints():
    """Linking two modules that both have an entrypoint raises ValueError."""
    executable_a = build_module(entrypoint=True)
    executable_b = build_module(entrypoint=True)

    expected_message = "Cannot link two executable modules together"
    with pytest.raises(ValueError, match=expected_message):
        link_modules(executable_a.to_bytes(), executable_b.to_bytes())


def test_link_packages_no_modules():
    """Linking packages without modules yields a package without modules."""
    empty_lhs = Package(modules=[])
    empty_rhs = Package(modules=[])

    linked = empty_lhs.link(empty_rhs)

    assert linked.modules == []


def test_link_packages_extensions():
    """Linked packages keep each extension once, in first-seen order."""
    int_only = [
        int.CONVERSIONS_EXTENSION,
        int.INT_TYPES_EXTENSION,
        int.INT_OPS_EXTENSION,
    ]
    float_only = [
        float.FLOAT_OPS_EXTENSION,
        float.FLOAT_TYPES_EXTENSION,
    ]
    # Present in both packages; must appear only once in the result.
    shared = [logic.EXTENSION, ptr.EXTENSION]

    pkg1 = Package(
        modules=[build_module(entrypoint=False)],
        extensions=[*int_only, *shared],
    )
    pkg2 = Package(
        modules=[build_module(entrypoint=False)],
        extensions=[*float_only, *shared],
    )

    result_pkg = pkg1.link(pkg2)

    assert result_pkg.extensions == [*int_only, *shared, *float_only]
Loading