Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions hugr-py/rust/linking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ use pyo3::exceptions::PyException;
use pyo3::{create_exception, pymodule};

#[pymodule(submodule)]
#[pyo3(module = "hugr._hugr.linking")]
#[pyo3(module = "hugr._hugr")]
pub mod linking {
#[pymodule_export]
use super::HugrLinkingError;

use hugr_core::envelope::EnvelopeConfig;
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy};
Expand Down Expand Up @@ -47,7 +50,7 @@ pub mod linking {

let forest = hugr_into
.link_module(hugr_from, &NameLinkingPolicy::default())
.map_err(|err| super::HugrLinkingError::new_err(err.to_string()))?;
.map_err(|err| HugrLinkingError::new_err(err.to_string()))?;
if let Some(new_entrypoint) = replacement_entrypoint {
let Some(node) = forest.node_map.get(&new_entrypoint) else {
panic!("Entrypoint is to be replaced but was not found after linking");
Expand All @@ -69,7 +72,7 @@ pub mod linking {
}

create_exception!(
_hugr.linking,
hugr._hugr.linking,
HugrLinkingError,
PyException,
"Base exception for HUGR linking errors."
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/rust/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Bindings for metadata keys defined in the hugr-core crate.

#[pyo3::pymodule(submodule)]
#[pyo3(module = "hugr._hugr.metadata")]
#[pyo3(module = "hugr._hugr")]
pub mod metadata {
use hugr_core::metadata::Metadata;
use pyo3::types::{PyAnyMethods, PyModule};
Expand Down
5 changes: 3 additions & 2 deletions hugr-py/rust/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::exceptions::{PyException, PyValueError};
use pyo3::{PyErr, PyResult, create_exception, pymodule};

#[pymodule(submodule)]
#[pyo3(module = "hugr._hugr")]
pub mod model {
use hugr_cli::CliArgs;
use hugr_model::v0::ast;
Expand Down Expand Up @@ -132,13 +133,13 @@ pub mod model {

// Define custom exceptions
create_exception!(
_hugr,
hugr._hugr.model,
HugrCliError,
PyException,
"Base exception for HUGR CLI errors."
);
create_exception!(
_hugr,
hugr._hugr.model,
HugrCliDescribeError,
HugrCliError,
"Exception for HUGR CLI describe command errors with partial output."
Expand Down
1 change: 1 addition & 0 deletions hugr-py/rust/zstd_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use pyo3::pymodule;

#[pymodule(submodule)]
#[pyo3(module = "hugr._hugr")]
pub mod zstd {
use pyo3::types::{PyAnyMethods, PyModule};
use pyo3::{Bound, Python};
Expand Down
25 changes: 23 additions & 2 deletions hugr-py/tests/test_linking.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
import pytest

from hugr import Hugr, tys
from hugr._hugr.linking import link_modules
from hugr._hugr.linking import HugrLinkingError, link_modules
from hugr.build import Module
from hugr.ops import FuncDefn
from hugr.package import Package
from hugr.std import float, int, logic, ptr


def build_module(*, entrypoint: bool) -> Hugr:
def build_module(*, entrypoint: bool, public_func: bool = False) -> Hugr:
builder = Module()
if entrypoint:
main = builder.define_function(
"main", input_types=[tys.Bool], output_types=[tys.Bool], visibility="Public"
)
main.set_outputs(*main.inputs())
builder.hugr.entrypoint = main.parent_node
elif public_func:
# Entrypoint is already public, so we only need to add one
# if no entrypoint was generated.
func = builder.define_function(
"public_func",
input_types=[tys.Bool],
output_types=[tys.Bool],
visibility="Public",
)
func.set_outputs(*func.inputs())

return builder.hugr

Expand Down Expand Up @@ -58,6 +68,17 @@ def test_link_modules_multiple_entrypoints():
link_modules(hugr1.to_bytes(), hugr2.to_bytes())


def test_link_modules_linking_error():
hugr1 = build_module(entrypoint=False, public_func=True)
hugr2 = build_module(entrypoint=False, public_func=True)

with pytest.raises(
HugrLinkingError,
match=r"Source \(Node\([0-9]+\)\) and target \(Node\([0-9]+\)\) both contained FuncDefn with same public name public_func", # noqa: E501
):
link_modules(hugr1.to_bytes(), hugr2.to_bytes())


def test_link_packages_no_modules():
pkg1 = Package(modules=[])
pkg2 = Package(modules=[])
Expand Down
Loading