diff --git a/hugr-py/rust/linking.rs b/hugr-py/rust/linking.rs index a07e5f55e..036fab837 100644 --- a/hugr-py/rust/linking.rs +++ b/hugr-py/rust/linking.rs @@ -4,8 +4,11 @@ use pyo3::exceptions::PyException; use pyo3::{create_exception, pymodule}; #[pymodule(submodule)] -#[pyo3(module = "hugr._hugr.linking")] +#[pyo3(module = "hugr._hugr")] pub mod linking { + #[pymodule_export] + use super::HugrLinkingError; + use hugr_core::envelope::EnvelopeConfig; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::linking::{HugrLinking, NameLinkingPolicy}; @@ -47,7 +50,7 @@ pub mod linking { let forest = hugr_into .link_module(hugr_from, &NameLinkingPolicy::default()) - .map_err(|err| super::HugrLinkingError::new_err(err.to_string()))?; + .map_err(|err| HugrLinkingError::new_err(err.to_string()))?; if let Some(new_entrypoint) = replacement_entrypoint { let Some(node) = forest.node_map.get(&new_entrypoint) else { panic!("Entrypoint is to be replaced but was not found after linking"); @@ -69,7 +72,7 @@ pub mod linking { } create_exception!( - _hugr.linking, + hugr._hugr.linking, HugrLinkingError, PyException, "Base exception for HUGR linking errors." diff --git a/hugr-py/rust/metadata.rs b/hugr-py/rust/metadata.rs index 874bf2f42..e5524ea87 100644 --- a/hugr-py/rust/metadata.rs +++ b/hugr-py/rust/metadata.rs @@ -1,7 +1,7 @@ //! Bindings for metadata keys defined in the hugr-core crate. #[pyo3::pymodule(submodule)] -#[pyo3(module = "hugr._hugr.metadata")] +#[pyo3(module = "hugr._hugr")] pub mod metadata { use hugr_core::metadata::Metadata; use pyo3::types::{PyAnyMethods, PyModule}; diff --git a/hugr-py/rust/model.rs b/hugr-py/rust/model.rs index 032de2d92..d71da9665 100644 --- a/hugr-py/rust/model.rs +++ b/hugr-py/rust/model.rs @@ -6,6 +6,7 @@ use pyo3::exceptions::{PyException, PyValueError}; use pyo3::{PyErr, PyResult, create_exception, pymodule}; #[pymodule(submodule)] +#[pyo3(module = "hugr._hugr")] pub mod model { use hugr_cli::CliArgs; use hugr_model::v0::ast; @@ -132,13 +133,13 @@ pub mod model { // Define custom exceptions create_exception!( - _hugr, + hugr._hugr.model, HugrCliError, PyException, "Base exception for HUGR CLI errors." ); create_exception!( - _hugr, + hugr._hugr.model, HugrCliDescribeError, HugrCliError, "Exception for HUGR CLI describe command errors with partial output." diff --git a/hugr-py/rust/zstd_util.rs b/hugr-py/rust/zstd_util.rs index 66aaa9b11..90011e502 100644 --- a/hugr-py/rust/zstd_util.rs +++ b/hugr-py/rust/zstd_util.rs @@ -3,6 +3,7 @@ use pyo3::pymodule; #[pymodule(submodule)] +#[pyo3(module = "hugr._hugr")] pub mod zstd { use pyo3::types::{PyAnyMethods, PyModule}; use pyo3::{Bound, Python}; diff --git a/hugr-py/tests/test_linking.py b/hugr-py/tests/test_linking.py index 6c5ea4e0b..57a5c9ed6 100644 --- a/hugr-py/tests/test_linking.py +++ b/hugr-py/tests/test_linking.py @@ -1,14 +1,14 @@ import pytest from hugr import Hugr, tys -from hugr._hugr.linking import link_modules +from hugr._hugr.linking import HugrLinkingError, link_modules from hugr.build import Module from hugr.ops import FuncDefn from hugr.package import Package from hugr.std import float, int, logic, ptr -def build_module(*, entrypoint: bool) -> Hugr: +def build_module(*, entrypoint: bool, public_func: bool = False) -> Hugr: builder = Module() if entrypoint: main = builder.define_function( @@ -16,6 +16,16 @@ def build_module(*, entrypoint: bool) -> Hugr: ) main.set_outputs(*main.inputs()) builder.hugr.entrypoint = main.parent_node + elif public_func: + # Entrypoint is already public, so we only need to add one + # if no entrypoint was generated. + func = builder.define_function( + "public_func", + input_types=[tys.Bool], + output_types=[tys.Bool], + visibility="Public", + ) + func.set_outputs(*func.inputs()) return builder.hugr @@ -58,6 +68,17 @@ def test_link_modules_multiple_entrypoints(): link_modules(hugr1.to_bytes(), hugr2.to_bytes()) +def test_link_modules_linking_error(): + hugr1 = build_module(entrypoint=False, public_func=True) + hugr2 = build_module(entrypoint=False, public_func=True) + + with pytest.raises( + HugrLinkingError, + match=r"Source \(Node\([0-9]+\)\) and target \(Node\([0-9]+\)\) both contained FuncDefn with same public name public_func", # noqa: E501 + ): + link_modules(hugr1.to_bytes(), hugr2.to_bytes()) + + def test_link_packages_no_modules(): pkg1 = Package(modules=[]) pkg2 = Package(modules=[])