From 937cd143d1a666ef3a2044d4f2cfd90f662b5692 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 11 Mar 2025 19:14:20 +0000 Subject: [PATCH 001/123] Add TypeTransformer and (Sum/Function)Type(Arg/Row)::transform --- hugr-core/src/types.rs | 70 +++++++++++++++++++++++++++++++ hugr-core/src/types/custom.rs | 7 +++- hugr-core/src/types/signature.rs | 10 ++++- hugr-core/src/types/type_param.rs | 24 ++++++++++- hugr-core/src/types/type_row.rs | 16 ++++++- 5 files changed, 122 insertions(+), 5 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index d69c7ef7d6..9b5faefeaf 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -267,6 +267,23 @@ impl SumType { SumType::General { rows } => Either::Right(rows.iter()), } } + + /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) + /// + /// Returns true if any part of the sum type (may have) changed, or false + /// for definitely no change. + pub fn transform(&mut self, tr: &T) -> Result { + Ok(match self { + SumType::Unit { .. } => false, + SumType::General { rows } => { + let mut any_changed = false; + for r in rows.iter_mut() { + any_changed |= r.transform(tr)?; + } + any_changed + } + }) + } } impl From for TypeBase { @@ -528,6 +545,34 @@ impl TypeBase { } } } + + /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) + /// + /// Returns true if the Type (may have) changed, or false if it definitely didn't. + pub fn transform(&mut self, tr: &T) -> Result { + match &mut self.0 { + TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => Ok(false), + TypeEnum::Extension(custom_type) => { + Ok(if let Some(nt) = tr.apply_custom(custom_type)? { + *self = nt.into_(); + true + } else { + let mut args_changed = false; + for a in custom_type.args_mut() { + args_changed |= a.transform(tr)? + } + if args_changed { + *custom_type = custom_type + .get_type_def(&custom_type.get_extension()?)? + .instantiate(custom_type.args())?; + } + args_changed + }) + } + TypeEnum::Function(fty) => fty.transform(tr), + TypeEnum::Sum(sum_type) => sum_type.transform(tr), + } + } } impl Type { @@ -666,6 +711,31 @@ impl<'a> Substitution<'a> { } } +/// A transformation that can be applied to a [Type] or [TypeArg]. +/// More general in some ways than a Substitution: can fail with a +/// [Self::Err], may change [TypeBound::Copyable] to [TypeBound::Any], +/// and applies to arbitrary extension types rather than type variables. +pub trait TypeTransformer { + /// Error returned when a [CustomType] cannot be transformed, or a type + /// containing it (e.g. if changing a [TypeArg::Type] from copyable to + /// linear invalidates a parameterized type). + type Err: std::error::Error + From; + + /// Applies the transformation to an extension type. + /// + /// Note that if the [CustomType] has type arguments, these will *not* + /// have been transformed first (this might not produce a valid type + /// due to changes in [TypeBound]). + /// + /// Returns a type to use instead, or None to indicate no change + /// (in which case, the TypeArgs will be transformed instead. + /// To prevent transforming the arguments, return `t.clone().into()`.) + fn apply_custom(&self, t: &CustomType) -> Result, Self::Err>; + + // Note: in future releases more methods may be added here to transform other types. + // By defaulting such trait methods to Ok(None), backwards compatibility will be preserved. +} + pub(crate) fn check_typevar_decl( decls: &[TypeParam], idx: usize, diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 1de1957169..d456053869 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -86,7 +86,10 @@ impl CustomType { def.check_custom(self) } - fn get_type_def<'a>(&self, ext: &'a Arc) -> Result<&'a TypeDef, SignatureError> { + pub(super) fn get_type_def<'a>( + &self, + ext: &'a Arc, + ) -> Result<&'a TypeDef, SignatureError> { ext.get_type(&self.id) .ok_or(SignatureError::ExtensionTypeNotFound { exn: self.extension.clone(), @@ -94,7 +97,7 @@ impl CustomType { }) } - fn get_extension(&self) -> Result, SignatureError> { + pub(super) fn get_extension(&self) -> Result, SignatureError> { self.extension_ref .upgrade() .ok_or(SignatureError::MissingTypeExtension { diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index c65edafc3e..9e62519d4e 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -7,7 +7,7 @@ use std::fmt::{self, Display}; use super::type_param::TypeParam; use super::type_row::TypeRowBase; -use super::{MaybeRV, NoRV, RowVariable, Substitution, Type, TypeRow}; +use super::{MaybeRV, NoRV, RowVariable, Substitution, Type, TypeRow, TypeTransformer}; use crate::core::PortIndex; use crate::extension::resolution::{ @@ -72,6 +72,14 @@ impl FuncTypeBase { } } + /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) + /// + /// Returns true if the function type (may have) changed, or false if it definitely didn't. + pub fn transform(&mut self, tr: &T) -> Result { + // TODO handle extension sets? + Ok(self.input.transform(tr)? | self.output.transform(tr)?) + } + /// Create a new signature with specified inputs and outputs. pub fn new(input: impl Into>, output: impl Into>) -> Self { Self { diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 4671e5f8c1..4653dc6ddb 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -11,7 +11,9 @@ use std::num::NonZeroU64; use thiserror::Error; use super::row_var::MaybeRV; -use super::{check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound}; +use super::{ + check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound, TypeTransformer, +}; use crate::extension::ExtensionSet; use crate::extension::SignatureError; @@ -367,6 +369,26 @@ impl TypeArg { } => t.apply_var(*idx, cached_decl), } } + + /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) + /// + /// Returns true if the TypeArg (may have) changed, or false if it definitely didn't. + pub fn transform(&mut self, tr: &T) -> Result { + match self { + TypeArg::Type { ty } => ty.transform(tr), + TypeArg::Sequence { elems } => { + let mut any_ch = false; + for e in elems.iter_mut() { + any_ch |= e.transform(tr)?; + } + Ok(any_ch) + } + TypeArg::BoundedNat { .. } + | TypeArg::String { .. } + | TypeArg::Extensions { .. } + | TypeArg::Variable { .. } => Ok(false), + } + } } impl TypeArgVariable { diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 38a4b05206..540f4e8a1b 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -7,7 +7,10 @@ use std::{ ops::{Deref, DerefMut}, }; -use super::{type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Type, TypeBase}; +use super::{ + type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Type, TypeBase, + TypeTransformer, +}; use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; use itertools::Itertools; @@ -75,6 +78,17 @@ impl TypeRowBase { .into() } + /// Applies a [TypeTransformer] to all the types in the row. (Mutates in-place.) + /// + /// Returns true if any type (may have) changed, or false if all were definitely unchanged. + pub fn transform(&mut self, tr: &T) -> Result { + let mut any_ch = false; + for t in self.iter_mut() { + any_ch |= t.transform(tr)?; + } + Ok(any_ch) + } + delegate! { to self.types { /// Iterator over the types in the row. From 815536b82bd7680d2a575026907a9dcf2aac8eec Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Mar 2025 09:10:32 +0000 Subject: [PATCH 002/123] trait Transformable to common up all the 'let mut any_change = false' loops --- hugr-core/src/types.rs | 56 ++++++++++++++++++------------- hugr-core/src/types/signature.rs | 19 ++++++----- hugr-core/src/types/type_param.rs | 18 ++++------ hugr-core/src/types/type_row.rs | 19 ++++------- 4 files changed, 56 insertions(+), 56 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 9b5faefeaf..c196b89ad4 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -267,22 +267,14 @@ impl SumType { SumType::General { rows } => Either::Right(rows.iter()), } } +} - /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) - /// - /// Returns true if any part of the sum type (may have) changed, or false - /// for definitely no change. - pub fn transform(&mut self, tr: &T) -> Result { - Ok(match self { - SumType::Unit { .. } => false, - SumType::General { rows } => { - let mut any_changed = false; - for r in rows.iter_mut() { - any_changed |= r.transform(tr)?; - } - any_changed - } - }) +impl Transformable for SumType { + fn transform(&mut self, tr: &T) -> Result { + match self { + SumType::Unit { .. } => Ok(false), + SumType::General { rows } => rows.iter_mut().transform(tr), + } } } @@ -545,11 +537,10 @@ impl TypeBase { } } } +} - /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) - /// - /// Returns true if the Type (may have) changed, or false if it definitely didn't. - pub fn transform(&mut self, tr: &T) -> Result { +impl Transformable for TypeBase { + fn transform(&mut self, tr: &T) -> Result { match &mut self.0 { TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => Ok(false), TypeEnum::Extension(custom_type) => { @@ -557,10 +548,7 @@ impl TypeBase { *self = nt.into_(); true } else { - let mut args_changed = false; - for a in custom_type.args_mut() { - args_changed |= a.transform(tr)? - } + let args_changed = custom_type.args_mut().into_iter().transform(tr)?; if args_changed { *custom_type = custom_type .get_type_def(&custom_type.get_extension()?)? @@ -736,6 +724,28 @@ pub trait TypeTransformer { // By defaulting such trait methods to Ok(None), backwards compatibility will be preserved. } +/// Trait for things that can be transformed by applying a [TypeTransformer]. +/// (A destructive / in-place mutation.) +pub trait Transformable { + /// Applies a [TypeTransformer] to this instance. + /// + /// Returns true if any part may have changed, or false for definitely no change. + /// + /// If an Err occurs, `self` may be left in an inconsistent state (e.g. partially + /// transformed). + fn transform(&mut self, t: &T) -> Result; +} + +impl<'a, E: Transformable + 'a, I: Iterator> Transformable for I { + fn transform(&mut self, tr: &T) -> Result { + let mut any_change = false; + for item in self { + any_change |= item.transform(tr)?; + } + Ok(any_change) + } +} + pub(crate) fn check_typevar_decl( decls: &[TypeParam], idx: usize, diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 9e62519d4e..cd6b06bb3a 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -7,7 +7,9 @@ use std::fmt::{self, Display}; use super::type_param::TypeParam; use super::type_row::TypeRowBase; -use super::{MaybeRV, NoRV, RowVariable, Substitution, Type, TypeRow, TypeTransformer}; +use super::{ + MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeRow, TypeTransformer, +}; use crate::core::PortIndex; use crate::extension::resolution::{ @@ -72,14 +74,6 @@ impl FuncTypeBase { } } - /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) - /// - /// Returns true if the function type (may have) changed, or false if it definitely didn't. - pub fn transform(&mut self, tr: &T) -> Result { - // TODO handle extension sets? - Ok(self.input.transform(tr)? | self.output.transform(tr)?) - } - /// Create a new signature with specified inputs and outputs. pub fn new(input: impl Into>, output: impl Into>) -> Self { Self { @@ -150,6 +144,13 @@ impl FuncTypeBase { } } +impl Transformable for FuncTypeBase { + fn transform(&mut self, tr: &T) -> Result { + // TODO handle extension sets? + Ok(self.input.transform(tr)? | self.output.transform(tr)?) + } +} + impl FuncValueType { /// If this FuncValueType contains any row variables, return one. pub fn find_rowvar(&self) -> Option { diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 4653dc6ddb..4b64661e9b 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -12,7 +12,8 @@ use thiserror::Error; use super::row_var::MaybeRV; use super::{ - check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound, TypeTransformer, + check_typevar_decl, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, + TypeTransformer, }; use crate::extension::ExtensionSet; use crate::extension::SignatureError; @@ -369,20 +370,13 @@ impl TypeArg { } => t.apply_var(*idx, cached_decl), } } +} - /// Applies a [TypeTransformer] to this instance. (Mutates in-place.) - /// - /// Returns true if the TypeArg (may have) changed, or false if it definitely didn't. - pub fn transform(&mut self, tr: &T) -> Result { +impl Transformable for TypeArg { + fn transform(&mut self, tr: &T) -> Result { match self { TypeArg::Type { ty } => ty.transform(tr), - TypeArg::Sequence { elems } => { - let mut any_ch = false; - for e in elems.iter_mut() { - any_ch |= e.transform(tr)?; - } - Ok(any_ch) - } + TypeArg::Sequence { elems } => elems.iter_mut().transform(tr), TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Extensions { .. } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 540f4e8a1b..2c15c85df2 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -8,7 +8,7 @@ use std::{ }; use super::{ - type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Type, TypeBase, + type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeTransformer, }; use crate::{extension::SignatureError, utils::display_list}; @@ -78,17 +78,6 @@ impl TypeRowBase { .into() } - /// Applies a [TypeTransformer] to all the types in the row. (Mutates in-place.) - /// - /// Returns true if any type (may have) changed, or false if all were definitely unchanged. - pub fn transform(&mut self, tr: &T) -> Result { - let mut any_ch = false; - for t in self.iter_mut() { - any_ch |= t.transform(tr)?; - } - Ok(any_ch) - } - delegate! { to self.types { /// Iterator over the types in the row. @@ -110,6 +99,12 @@ impl TypeRowBase { } } +impl Transformable for TypeRowBase { + fn transform(&mut self, tr: &T) -> Result { + self.iter_mut().transform(tr) + } +} + impl TypeRow { delegate! { to self.types { From cc1e8f2a00fbf08e2fbe43266a1bd15b1a3b0060 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Mar 2025 09:18:39 +0000 Subject: [PATCH 003/123] Transformable v2: implement for [E], works without iter_mut/etc. in almost all cases --- hugr-core/src/types.rs | 8 ++++---- hugr-core/src/types/type_param.rs | 2 +- hugr-core/src/types/type_row.rs | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index c196b89ad4..4cb7a40d72 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -273,7 +273,7 @@ impl Transformable for SumType { fn transform(&mut self, tr: &T) -> Result { match self { SumType::Unit { .. } => Ok(false), - SumType::General { rows } => rows.iter_mut().transform(tr), + SumType::General { rows } => rows.transform(tr), } } } @@ -548,7 +548,7 @@ impl Transformable for TypeBase { *self = nt.into_(); true } else { - let args_changed = custom_type.args_mut().into_iter().transform(tr)?; + let args_changed = custom_type.args_mut().transform(tr)?; if args_changed { *custom_type = custom_type .get_type_def(&custom_type.get_extension()?)? @@ -736,10 +736,10 @@ pub trait Transformable { fn transform(&mut self, t: &T) -> Result; } -impl<'a, E: Transformable + 'a, I: Iterator> Transformable for I { +impl Transformable for [E] { fn transform(&mut self, tr: &T) -> Result { let mut any_change = false; - for item in self { + for item in self.into_iter() { any_change |= item.transform(tr)?; } Ok(any_change) diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 4b64661e9b..db2efecc65 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -376,7 +376,7 @@ impl Transformable for TypeArg { fn transform(&mut self, tr: &T) -> Result { match self { TypeArg::Type { ty } => ty.transform(tr), - TypeArg::Sequence { elems } => elems.iter_mut().transform(tr), + TypeArg::Sequence { elems } => elems.transform(tr), TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Extensions { .. } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index 2c15c85df2..24616b7200 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -101,7 +101,7 @@ impl TypeRowBase { impl Transformable for TypeRowBase { fn transform(&mut self, tr: &T) -> Result { - self.iter_mut().transform(tr) + self.to_mut().transform(tr) } } From 48881842497b670044a36c93c8d5ca2072ed0225 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 18:42:49 +0000 Subject: [PATCH 004/123] first test, fix CustomType bound caching --- hugr-core/src/types.rs | 53 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 4cb7a40d72..52761a99ce 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -550,9 +550,11 @@ impl Transformable for TypeBase { } else { let args_changed = custom_type.args_mut().transform(tr)?; if args_changed { - *custom_type = custom_type - .get_type_def(&custom_type.get_extension()?)? - .instantiate(custom_type.args())?; + *self = Self::new_extension( + custom_type + .get_type_def(&custom_type.get_extension()?)? + .instantiate(custom_type.args())?, + ); } args_changed }) @@ -773,12 +775,11 @@ pub(crate) fn check_typevar_decl( #[cfg(test)] pub(crate) mod test { - use std::sync::Weak; use super::*; - use crate::extension::prelude::usize_t; - use crate::type_row; + use crate::extension::{prelude::usize_t, TypeDefBound}; + use crate::{hugr::IdentList, type_row, Extension}; #[test] fn construct() { @@ -836,6 +837,46 @@ pub(crate) mod test { } } + struct FnTransformer(T); + impl Option> TypeTransformer for FnTransformer { + type Err = SignatureError; + + fn apply_custom(&self, t: &CustomType) -> Result, Self::Err> { + Ok((self.0)(t)) + } + } + #[test] + fn transform() { + const LIN: SmolStr = SmolStr::new_inline("MyLinear"); + const COLN: SmolStr = SmolStr::new_inline("ColnOfAny"); + + let e = Extension::new_test_arc(IdentList::new("TestExt").unwrap(), |e, w| { + e.add_type(LIN, vec![], String::new(), TypeDefBound::any(), w) + .unwrap(); + e.add_type( + COLN, + vec![TypeBound::Any.into()], + String::new(), + TypeDefBound::from_params(vec![0]), + w, + ) + .unwrap(); + }); + let lin = e.get_type(&LIN).unwrap().instantiate([]).unwrap(); + let coln = e.get_type(&COLN).unwrap(); + + let lin_to_usize = FnTransformer(|ct: &CustomType| (*ct == lin).then_some(usize_t())); + let mut t = Type::new_extension(lin.clone()); + assert_eq!(t.transform(&lin_to_usize), Ok(true)); + assert_eq!(t, usize_t()); + let mut t = + Type::new_extension(coln.instantiate([Type::from(lin.clone()).into()]).unwrap()); + assert_eq!(t.transform(&lin_to_usize), Ok(true)); + let expected = Type::new_extension(coln.instantiate([usize_t().into()]).unwrap()); + assert_eq!(t, expected); + assert_eq!(t.transform(&lin_to_usize), Ok(false)); + } + mod proptest { use crate::proptest::RecursionDepth; From 7002e4d98861da0b1b49266c8e01741649012da2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 19:11:29 +0000 Subject: [PATCH 005/123] Second test, fix SumType caching --- hugr-core/src/types.rs | 92 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 52761a99ce..4f0ae2232f 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -560,7 +560,11 @@ impl Transformable for TypeBase { }) } TypeEnum::Function(fty) => fty.transform(tr), - TypeEnum::Sum(sum_type) => sum_type.transform(tr), + TypeEnum::Sum(sum_type) => { + let ch = sum_type.transform(tr)?; + self.1 = self.0.least_upper_bound(); + Ok(ch) + } } } } @@ -778,7 +782,9 @@ pub(crate) mod test { use std::sync::Weak; use super::*; - use crate::extension::{prelude::usize_t, TypeDefBound}; + use crate::extension::prelude::{qb_t, usize_t}; + use crate::extension::TypeDefBound; + use crate::types::type_param::TypeArgError; use crate::{hugr::IdentList, type_row, Extension}; #[test] @@ -877,6 +883,88 @@ pub(crate) mod test { assert_eq!(t.transform(&lin_to_usize), Ok(false)); } + #[test] + fn transform_copyable_to_linear() { + const CPY: SmolStr = SmolStr::new_inline("MyCopyable"); + const COLN: SmolStr = SmolStr::new_inline("ColnOfCopyableElems"); + let e = Extension::new_test_arc(IdentList::new("TestExt").unwrap(), |e, w| { + e.add_type(CPY, vec![], String::new(), TypeDefBound::copyable(), w) + .unwrap(); + e.add_type( + COLN, + vec![TypeParam::new_list(TypeBound::Copyable)], + String::new(), + TypeDefBound::copyable(), + w, + ) + .unwrap(); + }); + + let cpy = e.get_type(&CPY).unwrap().instantiate([]).unwrap(); + let mk_opt = |t: Type| Type::new_sum([type_row![], TypeRow::from(t)]); + + let cpy_to_qb = FnTransformer(|ct: &CustomType| (ct == &cpy).then_some(qb_t())); + + let mut t = mk_opt(cpy.clone().into()); + assert_eq!(t.transform(&cpy_to_qb), Ok(true)); + assert_eq!(t, mk_opt(qb_t())); + + let coln = e.get_type(&COLN).unwrap(); + let c_of_cpy = coln + .instantiate([TypeArg::Sequence { + elems: vec![Type::from(cpy.clone()).into()], + }]) + .unwrap(); + + let mut t = Type::new_extension(c_of_cpy.clone()); + assert_eq!( + t.transform(&cpy_to_qb), + Err(SignatureError::from(TypeArgError::TypeMismatch { + param: TypeBound::Copyable.into(), + arg: qb_t().into() + })) + ); + + let mut t = Type::new_extension( + coln.instantiate([TypeArg::Sequence { + elems: vec![mk_opt(Type::from(cpy.clone())).into()], + }]) + .unwrap(), + ); + assert_eq!( + t.transform(&cpy_to_qb), + Err(SignatureError::from(TypeArgError::TypeMismatch { + param: TypeBound::Copyable.into(), + arg: mk_opt(qb_t()).into() + })) + ); + + // Finally, check handling Coln overrides handling of Cpy + let cpy_to_qb2 = FnTransformer(|ct: &CustomType| { + if ct == &cpy { + Some(qb_t()) + } else { + (ct == &c_of_cpy).then_some(usize_t()) + } + }); + let mut t = Type::new_extension( + coln.instantiate([TypeArg::Sequence { + elems: vec![Type::from(c_of_cpy.clone()).into(); 2], + }]) + .unwrap(), + ); + assert_eq!(t.transform(&cpy_to_qb2), Ok(true)); + assert_eq!( + t, + Type::new_extension( + coln.instantiate([TypeArg::Sequence { + elems: vec![usize_t().into(); 2] + }]) + .unwrap() + ) + ); + } + mod proptest { use crate::proptest::RecursionDepth; From 6cc7ceac35928183cb938bf19dde0762bdf04fd1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 19:25:34 +0000 Subject: [PATCH 006/123] Make clippy happy --- hugr-core/src/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 4f0ae2232f..7a1d106ece 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -745,7 +745,7 @@ pub trait Transformable { impl Transformable for [E] { fn transform(&mut self, tr: &T) -> Result { let mut any_change = false; - for item in self.into_iter() { + for item in self { any_change |= item.transform(tr)?; } Ok(any_change) From d43784bfbfda391ac5bdf0c23fc183aadc10fa72 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 11 Mar 2025 15:48:30 +0000 Subject: [PATCH 007/123] Add HugrMut::optype_mut (v2, allow mutating root if RootHandle == Node) --- hugr-core/src/hugr/internal.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index fb801a3252..77e46a4a35 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -10,7 +10,7 @@ use itertools::Itertools; use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; use crate::ops::handle::NodeHandle; -use crate::ops::OpTrait; +use crate::ops::{OpTag, OpTrait}; use crate::{Direction, Hugr, Node}; use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; @@ -309,6 +309,22 @@ pub trait HugrMutInternals: RootTagged { } self.hugr_mut().replace_op(node, op) } + + /// Gets a mutable reference to the optype. + /// + /// Changing this may invalidate the ports, which may need to be resized to + /// match the OpType signature. + /// + /// Will panic for the root node unless [Self::RootHandle] is [OpTag::Any], + /// as mutation could invalidate the bound. + fn optype_mut(&mut self, node: Node) -> &mut OpType { + if Self::RootHandle::TAG.is_superset(OpTag::Any) { + panic_invalid_node(self, node); + } else { + panic_invalid_non_root(self, node); + } + self.hugr_mut().op_types.get_mut(node.pg_index()) + } } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -406,8 +422,7 @@ impl + AsMut> HugrMutInterna fn replace_op(&mut self, node: Node, op: impl Into) -> Result { // We know RootHandle=Node here so no need to check - let cur = self.hugr_mut().op_types.get_mut(node.pg_index()); - Ok(std::mem::replace(cur, op.into())) + Ok(std::mem::replace(self.optype_mut(node), op.into())) } } From 422c496d6946d382408f6bd147e8c5605d4a85ac Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 5 Mar 2025 18:28:01 +0000 Subject: [PATCH 008/123] WIP add hugr-passes/src/lower_types.rs (w/ change_node, subst_ty) --- hugr-core/src/types/poly_func.rs | 2 +- hugr-passes/src/lib.rs | 1 + hugr-passes/src/lower_types.rs | 120 +++++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 hugr-passes/src/lower_types.rs diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 4b73c71af4..67bc7fbf57 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -119,7 +119,7 @@ impl PolyFuncTypeBase { /// # Errors /// If there is not exactly one [TypeArg] for each binder ([Self::params]), /// or an arg does not fit into its corresponding [TypeParam] - pub(crate) fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { + pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. check_type_args(args, &self.params)?; diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 38ef01505b..fd09cda3e4 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -30,6 +30,7 @@ pub use monomorphize::monomorphize; pub use monomorphize::{MonomorphizeError, MonomorphizePass}; pub mod nest_cfgs; pub mod non_local; +pub mod lower_types; pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs new file mode 100644 index 0000000000..c5a065ea66 --- /dev/null +++ b/hugr-passes/src/lower_types.rs @@ -0,0 +1,120 @@ +use std::collections::HashMap; + +use hugr_core::{hugr::hugrmut::HugrMut, ops::{ExtensionOp, OpType}, types::{CustomType, FuncValueType, Type, TypeArg, TypeBound, TypeEnum, TypeRV, TypeRowRV}}; + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct LowerTypes { + // TODO allow Fn() to cope with parametrized CustomTypes? What about Aliases? + type_fn: Arc Option>, + type_map: HashMap, + copy_dup: HashMap, // TODO what about e.g. arrays that have gone from copyable to linear because their elements have?! + //op_map: HashMap + // 1. is input op always a single OpType, or a schema/predicate? + // 2. output might not be an op - might be a node with children + // 3. do we need checking BEFORE reparametrization as well as after? (after only if not reparametrized?) +} + +impl LowerTypes { + pub fn lower_type(&mut self, src: CustomType, dest: Type) { + if src.bound() == TypeBound::Copyable && !dest.copyable() { + // Of course we could try, and fail only if we encounter outports that are not singly-used! + panic!("Cannot lower copyable type to linear without copy/dup - use lower_type_linearize instead"); + } + self.type_map.insert(src, dest); + } + + pub fn lower_type_linearize(&mut self, src: CustomType, dest: Type, copy: OpType, dup: OpType) { + self.type_map.insert(src.clone(), dest); + self.copy_dup.insert(src, (copy, dup)); + } + + pub fn run_no_validate(&self, h: &mut impl HugrMut) { + for n in h.nodes().collect::>() { + let n_op = match h.get_optype(n) { + OpType::ExtensionOp(eop) => { + // YEUCH eop.def() is &OpDef but we need the Arc + let ext = eop.def().extension().upgrade().unwrap(); + let def = ext.get_op(eop.def().name()).unwrap(); + ExtensionOp::new(def.clone(), self.subst_tas(eop.args())).unwrap() // TODO return error + }, + OpType::Module(_) | OpType::AliasDecl(_) => todo!(), + OpType::FuncDefn(func_defn) => todo!(), + OpType::FuncDecl(func_decl) => todo!(), + OpType::AliasDecl(alias_decl) => todo!(), + OpType::AliasDefn(alias_defn) => todo!(), + OpType::Const(_) => todo!(), + OpType::Input(input) => todo!(), + OpType::Output(output) => todo!(), + OpType::Call(call) => todo!(), + OpType::CallIndirect(call_indirect) => todo!(), + OpType::LoadConstant(load_constant) => todo!(), + OpType::LoadFunction(load_function) => todo!(), + OpType::DFG(dfg) => todo!(), + OpType::OpaqueOp(opaque_op) => todo!(), + OpType::Tag(tag) => todo!(), + OpType::DataflowBlock(dataflow_block) => todo!(), + OpType::ExitBlock(exit_block) => todo!(), + OpType::TailLoop(tail_loop) => todo!(), + OpType::CFG(cfg) => todo!(), + OpType::Conditional(conditional) => todo!(), + OpType::Case(case) => todo!(), + _ => todo!(), + }; + h.replace_op(n, n_op).unwrap(); + // TODO now sort out outputs - insert copy/dup + } + pub fn change_type_hugr(mut hugr: impl HugrMut, change: &mut impl Changer, reg: &ExtensionRegistry) -> Result<()> { + // let ext_op_params: HashMap> = hugr.nodes().filter(|&x| hugr.get_optype(x).is_extension_op()).map(|x| op_params(hugr, x)).collect(); + + Ok(()) + } + } + + + fn subst_tas(&self, args: &[TypeArg]) -> Vec { + args.iter().map(|ta| self.subst_ta(ta)).collect() + } + + fn subst_ta(&self, arg: &TypeArg) -> TypeArg { + match arg { + TypeArg::Type { ty } => TypeArg::Type {ty: self.subst_ty(ty)}, + TypeArg::BoundedNat { .. } | + TypeArg::String { .. } | + TypeArg::Extensions { .. } | + TypeArg::Variable { .. } => arg.clone(), // Or panic on Variable? + TypeArg::Sequence { elems } => TypeArg::Sequence { elems: self.subst_tas(elems) }, + _ => todo!(), + } + } + + fn subst_ty(&self, ty: &Type) -> Type { + match ty.as_type_enum() { + TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => ty.clone(), + TypeEnum::Extension(ct) => { + if let Some(r) = self.type_map.get(ct) { + return r.clone() + } + let ext = ct.extension_ref().upgrade().unwrap(); + let def = ext.get_type(ct.name()).unwrap(); + def.instantiate( self.subst_tas(ct.args())).unwrap().into() // TODO return error + } + TypeEnum::Function(fty) => Type::new_function(FuncValueType::new( + self.subst_tys(&fty.input), self.subst_tys(&fty.output))), + TypeEnum::Sum(s) => Type::new_sum(s.variants().map(|v| self.subst_tys(v))) + } + } + + fn subst_tys(&self, r: &TypeRowRV) -> TypeRowRV { + r.iter().map(|t| { + match t.clone().try_into_type() { + Ok(t) => self.subst_ty(&t).into(), + Err(rv) => { + // YEUCH Type::new(TypeEnum) is crate-private so: + let mut t=t.clone(); + *(t.as_type_enum_mut()) = TypeEnum::RowVar(rv); + t + } + } + }).collect::>().into() + } +} From 76ed3910f99ffd863216619e113b8005e7a04617 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 11 Mar 2025 15:49:23 +0000 Subject: [PATCH 009/123] Add def_arc --- hugr-core/src/ops/custom.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 11ec390a65..96e884f6d2 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -85,6 +85,12 @@ impl ExtensionOp { self.def.as_ref() } + /// Gets an Arc to the [`OpDef`] of this instance, i.e. usable to create + /// new instances. + pub fn def_arc(&self) -> &Arc { + &self.def + } + /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult { self.def().constant_fold(self.args(), consts) From ed5b5dd442668413803e667b3469763c858cb64c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 10 Mar 2025 15:18:46 +0000 Subject: [PATCH 010/123] change_node, change_type, and the rest --- hugr-passes/src/lib.rs | 2 +- hugr-passes/src/lower_types.rs | 327 +++++++++++++++++++++++++-------- 2 files changed, 255 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index fd09cda3e4..d5d765e367 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -28,9 +28,9 @@ pub use monomorphize::remove_polyfuncs; #[allow(deprecated)] pub use monomorphize::monomorphize; pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub mod lower_types; pub mod nest_cfgs; pub mod non_local; -pub mod lower_types; pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index c5a065ea66..fec8c99f13 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,17 +1,40 @@ +use hugr_core::ops::constant::{CustomConst, Sum}; +use hugr_core::types::{Signature, SumType, TypeRow}; +use thiserror::Error; + use std::collections::HashMap; +use std::sync::Arc; -use hugr_core::{hugr::hugrmut::HugrMut, ops::{ExtensionOp, OpType}, types::{CustomType, FuncValueType, Type, TypeArg, TypeBound, TypeEnum, TypeRV, TypeRowRV}}; +use hugr_core::{ + extension::SignatureError, + hugr::hugrmut::HugrMut, + ops::{ + AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, + ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, + TailLoop, Value, CFG, DFG, + }, + types::{CustomType, FuncValueType, Type, TypeArg, TypeBound, TypeEnum, TypeRV, TypeRowRV}, + Node, +}; -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone)] pub struct LowerTypes { - // TODO allow Fn() to cope with parametrized CustomTypes? What about Aliases? - type_fn: Arc Option>, + type_fn: Arc Option>, type_map: HashMap, copy_dup: HashMap, // TODO what about e.g. arrays that have gone from copyable to linear because their elements have?! //op_map: HashMap // 1. is input op always a single OpType, or a schema/predicate? // 2. output might not be an op - might be a node with children // 3. do we need checking BEFORE reparametrization as well as after? (after only if not reparametrized?) + #[allow(unused)] + const_fn: Arc Option>, +} + +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum ChangeTypeError { + #[error(transparent)] + SignatureError(#[from] SignatureError), } impl LowerTypes { @@ -28,93 +51,251 @@ impl LowerTypes { self.copy_dup.insert(src, (copy, dup)); } - pub fn run_no_validate(&self, h: &mut impl HugrMut) { - for n in h.nodes().collect::>() { - let n_op = match h.get_optype(n) { - OpType::ExtensionOp(eop) => { - // YEUCH eop.def() is &OpDef but we need the Arc - let ext = eop.def().extension().upgrade().unwrap(); - let def = ext.get_op(eop.def().name()).unwrap(); - ExtensionOp::new(def.clone(), self.subst_tas(eop.args())).unwrap() // TODO return error - }, - OpType::Module(_) | OpType::AliasDecl(_) => todo!(), - OpType::FuncDefn(func_defn) => todo!(), - OpType::FuncDecl(func_decl) => todo!(), - OpType::AliasDecl(alias_decl) => todo!(), - OpType::AliasDefn(alias_defn) => todo!(), - OpType::Const(_) => todo!(), - OpType::Input(input) => todo!(), - OpType::Output(output) => todo!(), - OpType::Call(call) => todo!(), - OpType::CallIndirect(call_indirect) => todo!(), - OpType::LoadConstant(load_constant) => todo!(), - OpType::LoadFunction(load_function) => todo!(), - OpType::DFG(dfg) => todo!(), - OpType::OpaqueOp(opaque_op) => todo!(), - OpType::Tag(tag) => todo!(), - OpType::DataflowBlock(dataflow_block) => todo!(), - OpType::ExitBlock(exit_block) => todo!(), - OpType::TailLoop(tail_loop) => todo!(), - OpType::CFG(cfg) => todo!(), - OpType::Conditional(conditional) => todo!(), - OpType::Case(case) => todo!(), - _ => todo!(), - }; - h.replace_op(n, n_op).unwrap(); - // TODO now sort out outputs - insert copy/dup + pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; } - pub fn change_type_hugr(mut hugr: impl HugrMut, change: &mut impl Changer, reg: &ExtensionRegistry) -> Result<()> { - // let ext_op_params: HashMap> = hugr.nodes().filter(|&x| hugr.get_optype(x).is_extension_op()).map(|x| op_params(hugr, x)).collect(); - - Ok(()) + Ok(changed) + } + + fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { + match hugr.optype_mut(n) { + OpType::FuncDefn(FuncDefn { signature, .. }) + | OpType::FuncDecl(FuncDecl { signature, .. }) => self.change_sig(signature.body_mut()), + OpType::LoadConstant(LoadConstant { datatype: ty }) + | OpType::AliasDefn(AliasDefn { definition: ty, .. }) => self.change_type(ty), + + OpType::ExitBlock(ExitBlock { cfg_outputs: types }) + | OpType::Input(Input { types }) + | OpType::Output(Output { types }) => self.change_type_row(types), + OpType::LoadFunction(LoadFunction { + func_sig, + type_args, + instantiation, + }) + | OpType::Call(Call { + func_sig, + type_args, + instantiation, + }) => { + let change = + self.change_sig(func_sig.body_mut())? | self.change_type_args(type_args)?; + let new_inst = func_sig + .instantiate(&type_args) + .map_err(ChangeTypeError::SignatureError)?; + *instantiation = new_inst; + Ok(change) + } + OpType::Case(Case { signature }) + | OpType::CFG(CFG { signature }) + | OpType::DFG(DFG { signature }) + | OpType::CallIndirect(CallIndirect { signature }) => self.change_sig(signature), + OpType::Tag(Tag { variants, .. }) => { + let mut ch = false; + for v in variants.iter_mut() { + ch |= self.change_type_row(v)?; + } + Ok(ch) + } + OpType::Conditional(Conditional { + other_inputs: row1, + outputs: row2, + sum_rows, + .. + }) + | OpType::DataflowBlock(DataflowBlock { + inputs: row1, + other_outputs: row2, + sum_rows, + .. + }) => { + let mut ch = self.change_type_row(row1)? | self.change_type_row(row2)?; + for r in sum_rows.iter_mut() { + ch |= self.change_type_row(r)?; + } + Ok(ch) + } + OpType::TailLoop(TailLoop { + just_inputs, + just_outputs, + rest, + .. + }) => Ok(self.change_type_row(just_inputs)? + | self.change_type_row(just_outputs)? + | self.change_type_row(rest)?), + + OpType::Const(Const { value, .. }) => self.change_value(value), + OpType::ExtensionOp(ext_op) => { + let def = ext_op.def_arc(); + let mut args = ext_op.args().to_vec(); + let change = self.change_type_args(args.as_mut_slice())?; + if change { + *ext_op = ExtensionOp::new(def.clone(), args)?; + } + // let params = ext_op_params[node].to_owned(); + todo!("Also check whether we should lower op") + } + OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), + + OpType::AliasDecl(_) | OpType::Module(_) => Ok(false), + _ => todo!(), } } + fn change_value(&self, value: &mut Value) -> Result { + match value { + Value::Sum(Sum { + values, sum_type, .. + }) => { + let mut any_change = false; + for value in values { + any_change |= self.change_value(value)?; + } + any_change |= self.change_sumtype(sum_type)?; + Ok(any_change) + } + Value::Extension { e } => { + if let Some(new_const) = self.subst_custom_const(e.value())? { + *value = new_const; + Ok(true) + } else { + Ok(false) + } + } + Value::Function { hugr } => self.run_no_validate(&mut **hugr), + } + } - fn subst_tas(&self, args: &[TypeArg]) -> Vec { - args.iter().map(|ta| self.subst_ta(ta)).collect() + fn subst_custom_const(&self, _cst: &dyn CustomConst) -> Result, ChangeTypeError> { + todo!() } - fn subst_ta(&self, arg: &TypeArg) -> TypeArg { + fn change_type_arg(&self, arg: &mut TypeArg) -> Result { match arg { - TypeArg::Type { ty } => TypeArg::Type {ty: self.subst_ty(ty)}, - TypeArg::BoundedNat { .. } | - TypeArg::String { .. } | - TypeArg::Extensions { .. } | - TypeArg::Variable { .. } => arg.clone(), // Or panic on Variable? - TypeArg::Sequence { elems } => TypeArg::Sequence { elems: self.subst_tas(elems) }, + TypeArg::Type { ty } => self.change_type(ty), + TypeArg::BoundedNat { .. } + | TypeArg::String { .. } + | TypeArg::Extensions { .. } + | TypeArg::Variable { .. } => Ok(false), + TypeArg::Sequence { elems } => self.change_type_args(elems), _ => todo!(), } } - fn subst_ty(&self, ty: &Type) -> Type { - match ty.as_type_enum() { - TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => ty.clone(), + fn change_type(&self, ty: &mut Type) -> Result { + // There is no as_type_enum_mut because mutation could invalidate the cache of TypeBound + let new_ty = match ty.as_type_enum() { + TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => return Ok(false), TypeEnum::Extension(ct) => { - if let Some(r) = self.type_map.get(ct) { - return r.clone() + if let Some(t) = self.subst_custom_type(ct)? { + t + } else { + return Ok(false); } - let ext = ct.extension_ref().upgrade().unwrap(); - let def = ext.get_type(ct.name()).unwrap(); - def.instantiate( self.subst_tas(ct.args())).unwrap().into() // TODO return error } - TypeEnum::Function(fty) => Type::new_function(FuncValueType::new( - self.subst_tys(&fty.input), self.subst_tys(&fty.output))), - TypeEnum::Sum(s) => Type::new_sum(s.variants().map(|v| self.subst_tys(v))) + TypeEnum::Function(fty) => { + if let Some(fty) = self.subst_fty(&**fty)? { + Type::new_function(fty) + } else { + return Ok(false); + } + } + TypeEnum::Sum(s) => { + let mut st = s.clone(); + if !self.change_sumtype(&mut st)? { + return Ok(false); + }; + st.into() + } + }; + *ty = new_ty; + Ok(true) + } + + fn change_tyrv(&self, ty: &mut TypeRV) -> Result { + // There is no as_type_enum_mut because mutation could invalidate the cache of TypeBound + let new_ty = match ty.as_type_enum() { + TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => return Ok(false), + TypeEnum::Extension(ct) => self.subst_custom_type(ct)?.map(TypeRV::from), + TypeEnum::Function(fty) => self.subst_fty(&**fty)?.map(TypeRV::new_function), + TypeEnum::Sum(s) => { + let mut st = s.clone(); + self.change_sumtype(&mut st)?.then(|| TypeRV::from(st)) + } + }; + if let Some(new_ty) = new_ty { + *ty = new_ty; + return Ok(true); + }; + return Ok(false); + } + + fn subst_custom_type(&self, ct: &CustomType) -> Result, ChangeTypeError> { + let mut nargs = ct.args().to_vec(); + let ch = self.change_type_args(&mut nargs)?; + let ext = ct.extension_ref().upgrade().unwrap(); + let ct = ext.get_type(ct.name()).unwrap().instantiate(nargs)?; + Ok(if let Some(r) = (self.type_fn)(&ct) { + Some(r) + } else if let Some(r) = self.type_map.get(&ct) { + Some(r.clone()) + } else if ch { + Some(ct.into()) + } else { + None + }) + } + + fn change_sig(&self, ft: &mut Signature) -> Result { + // TODO runtime_reqs? + Ok(self.change_type_row(&mut ft.input)? | self.change_type_row(&mut ft.output)?) + } + + fn subst_fty(&self, fty: &FuncValueType) -> Result, ChangeTypeError> { + let mut fty = fty.clone(); + if !self.change_type_row_rv(&mut fty.input)? & !self.change_type_row_rv(&mut fty.output)? { + return Ok(None); } + // TODO what about runtime_req if we are changing ops?? + Ok(Some(fty)) } - - fn subst_tys(&self, r: &TypeRowRV) -> TypeRowRV { - r.iter().map(|t| { - match t.clone().try_into_type() { - Ok(t) => self.subst_ty(&t).into(), - Err(rv) => { - // YEUCH Type::new(TypeEnum) is crate-private so: - let mut t=t.clone(); - *(t.as_type_enum_mut()) = TypeEnum::RowVar(rv); - t + + fn change_sumtype(&self, st: &mut SumType) -> Result { + Ok(match st { + SumType::Unit { .. } => false, + SumType::General { rows } => { + let mut ch = false; + for row in rows.iter_mut() { + ch |= self.change_type_row_rv(row)? } + ch } - }).collect::>().into() + _ => todo!("Unexpected SumType {st:?}"), + }) + } + + fn change_type_args(&self, tas: &mut [TypeArg]) -> Result { + let mut ch = false; + for ta in tas.iter_mut() { + ch |= self.change_type_arg(ta)?; + } + Ok(ch) + } + + fn change_type_row(&self, row: &mut TypeRow) -> Result { + let mut ch = false; + for t in row.iter_mut() { + ch |= self.change_type(t)?; + } + Ok(ch) + } + + fn change_type_row_rv(&self, row: &mut TypeRowRV) -> Result { + let mut ch = false; + for t in row.iter_mut() { + ch |= self.change_tyrv(t)?; + } + Ok(ch) } } From adcbbf69f5365d838c1ece1e941e50bee9a36607 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 11 Mar 2025 19:28:56 +0000 Subject: [PATCH 011/123] Use TypeTransformer framework, removing most type_stuff from lower_types.rs --- hugr-passes/src/lower_types.rs | 194 ++++++--------------------------- 1 file changed, 34 insertions(+), 160 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index fec8c99f13..8b1817cd1b 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,5 +1,5 @@ use hugr_core::ops::constant::{CustomConst, Sum}; -use hugr_core::types::{Signature, SumType, TypeRow}; +use hugr_core::types::{Transformable, TypeTransformer}; use thiserror::Error; use std::collections::HashMap; @@ -13,7 +13,7 @@ use hugr_core::{ ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }, - types::{CustomType, FuncValueType, Type, TypeArg, TypeBound, TypeEnum, TypeRV, TypeRowRV}, + types::{CustomType, Type, TypeBound}, Node, }; @@ -30,6 +30,20 @@ pub struct LowerTypes { const_fn: Arc Option>, } +impl TypeTransformer for LowerTypes { + type Err = ChangeTypeError; + + fn apply_custom(&self, ct: &CustomType) -> Result, Self::Err> { + Ok(if let Some(r) = (self.type_fn)(ct) { + Some(r) + } else if let Some(r) = self.type_map.get(ct) { + Some(r.clone()) + } else { + None + }) + } +} + #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum ChangeTypeError { @@ -62,13 +76,13 @@ impl LowerTypes { fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) - | OpType::FuncDecl(FuncDecl { signature, .. }) => self.change_sig(signature.body_mut()), + | OpType::FuncDecl(FuncDecl { signature, .. }) => signature.body_mut().transform(self), OpType::LoadConstant(LoadConstant { datatype: ty }) - | OpType::AliasDefn(AliasDefn { definition: ty, .. }) => self.change_type(ty), + | OpType::AliasDefn(AliasDefn { definition: ty, .. }) => ty.transform(self), OpType::ExitBlock(ExitBlock { cfg_outputs: types }) | OpType::Input(Input { types }) - | OpType::Output(Output { types }) => self.change_type_row(types), + | OpType::Output(Output { types }) => types.transform(self), OpType::LoadFunction(LoadFunction { func_sig, type_args, @@ -79,25 +93,20 @@ impl LowerTypes { type_args, instantiation, }) => { - let change = - self.change_sig(func_sig.body_mut())? | self.change_type_args(type_args)?; - let new_inst = func_sig - .instantiate(&type_args) - .map_err(ChangeTypeError::SignatureError)?; - *instantiation = new_inst; + let change = func_sig.body_mut().transform(self)? | type_args.transform(self)?; + if change { + let new_inst = func_sig + .instantiate(&type_args) + .map_err(ChangeTypeError::SignatureError)?; + *instantiation = new_inst; + } Ok(change) } OpType::Case(Case { signature }) | OpType::CFG(CFG { signature }) | OpType::DFG(DFG { signature }) - | OpType::CallIndirect(CallIndirect { signature }) => self.change_sig(signature), - OpType::Tag(Tag { variants, .. }) => { - let mut ch = false; - for v in variants.iter_mut() { - ch |= self.change_type_row(v)?; - } - Ok(ch) - } + | OpType::CallIndirect(CallIndirect { signature }) => signature.transform(self), + OpType::Tag(Tag { variants, .. }) => variants.transform(self), OpType::Conditional(Conditional { other_inputs: row1, outputs: row2, @@ -109,28 +118,21 @@ impl LowerTypes { other_outputs: row2, sum_rows, .. - }) => { - let mut ch = self.change_type_row(row1)? | self.change_type_row(row2)?; - for r in sum_rows.iter_mut() { - ch |= self.change_type_row(r)?; - } - Ok(ch) - } + }) => Ok(row1.transform(self)? | row2.transform(self)? | sum_rows.transform(self)?), OpType::TailLoop(TailLoop { just_inputs, just_outputs, rest, .. - }) => Ok(self.change_type_row(just_inputs)? - | self.change_type_row(just_outputs)? - | self.change_type_row(rest)?), + }) => Ok(just_inputs.transform(self)? + | just_outputs.transform(self)? + | rest.transform(self)?), OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => { let def = ext_op.def_arc(); let mut args = ext_op.args().to_vec(); - let change = self.change_type_args(args.as_mut_slice())?; - if change { + if args.transform(self)? { *ext_op = ExtensionOp::new(def.clone(), args)?; } // let params = ext_op_params[node].to_owned(); @@ -152,7 +154,7 @@ impl LowerTypes { for value in values { any_change |= self.change_value(value)?; } - any_change |= self.change_sumtype(sum_type)?; + any_change |= sum_type.transform(self)?; Ok(any_change) } Value::Extension { e } => { @@ -170,132 +172,4 @@ impl LowerTypes { fn subst_custom_const(&self, _cst: &dyn CustomConst) -> Result, ChangeTypeError> { todo!() } - - fn change_type_arg(&self, arg: &mut TypeArg) -> Result { - match arg { - TypeArg::Type { ty } => self.change_type(ty), - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Extensions { .. } - | TypeArg::Variable { .. } => Ok(false), - TypeArg::Sequence { elems } => self.change_type_args(elems), - _ => todo!(), - } - } - - fn change_type(&self, ty: &mut Type) -> Result { - // There is no as_type_enum_mut because mutation could invalidate the cache of TypeBound - let new_ty = match ty.as_type_enum() { - TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => return Ok(false), - TypeEnum::Extension(ct) => { - if let Some(t) = self.subst_custom_type(ct)? { - t - } else { - return Ok(false); - } - } - TypeEnum::Function(fty) => { - if let Some(fty) = self.subst_fty(&**fty)? { - Type::new_function(fty) - } else { - return Ok(false); - } - } - TypeEnum::Sum(s) => { - let mut st = s.clone(); - if !self.change_sumtype(&mut st)? { - return Ok(false); - }; - st.into() - } - }; - *ty = new_ty; - Ok(true) - } - - fn change_tyrv(&self, ty: &mut TypeRV) -> Result { - // There is no as_type_enum_mut because mutation could invalidate the cache of TypeBound - let new_ty = match ty.as_type_enum() { - TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => return Ok(false), - TypeEnum::Extension(ct) => self.subst_custom_type(ct)?.map(TypeRV::from), - TypeEnum::Function(fty) => self.subst_fty(&**fty)?.map(TypeRV::new_function), - TypeEnum::Sum(s) => { - let mut st = s.clone(); - self.change_sumtype(&mut st)?.then(|| TypeRV::from(st)) - } - }; - if let Some(new_ty) = new_ty { - *ty = new_ty; - return Ok(true); - }; - return Ok(false); - } - - fn subst_custom_type(&self, ct: &CustomType) -> Result, ChangeTypeError> { - let mut nargs = ct.args().to_vec(); - let ch = self.change_type_args(&mut nargs)?; - let ext = ct.extension_ref().upgrade().unwrap(); - let ct = ext.get_type(ct.name()).unwrap().instantiate(nargs)?; - Ok(if let Some(r) = (self.type_fn)(&ct) { - Some(r) - } else if let Some(r) = self.type_map.get(&ct) { - Some(r.clone()) - } else if ch { - Some(ct.into()) - } else { - None - }) - } - - fn change_sig(&self, ft: &mut Signature) -> Result { - // TODO runtime_reqs? - Ok(self.change_type_row(&mut ft.input)? | self.change_type_row(&mut ft.output)?) - } - - fn subst_fty(&self, fty: &FuncValueType) -> Result, ChangeTypeError> { - let mut fty = fty.clone(); - if !self.change_type_row_rv(&mut fty.input)? & !self.change_type_row_rv(&mut fty.output)? { - return Ok(None); - } - // TODO what about runtime_req if we are changing ops?? - Ok(Some(fty)) - } - - fn change_sumtype(&self, st: &mut SumType) -> Result { - Ok(match st { - SumType::Unit { .. } => false, - SumType::General { rows } => { - let mut ch = false; - for row in rows.iter_mut() { - ch |= self.change_type_row_rv(row)? - } - ch - } - _ => todo!("Unexpected SumType {st:?}"), - }) - } - - fn change_type_args(&self, tas: &mut [TypeArg]) -> Result { - let mut ch = false; - for ta in tas.iter_mut() { - ch |= self.change_type_arg(ta)?; - } - Ok(ch) - } - - fn change_type_row(&self, row: &mut TypeRow) -> Result { - let mut ch = false; - for t in row.iter_mut() { - ch |= self.change_type(t)?; - } - Ok(ch) - } - - fn change_type_row_rv(&self, row: &mut TypeRowRV) -> Result { - let mut ch = false; - for t in row.iter_mut() { - ch |= self.change_tyrv(t)?; - } - Ok(ch) - } } From d84ae4ef36fec263c20a88fea989ec5e6dfefd46 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Mar 2025 16:07:36 +0000 Subject: [PATCH 012/123] Add a load of copy/discard lowering stuff, OpReplacement --- hugr-passes/src/lower_types.rs | 164 +++++++++++++++++++++++++++++---- 1 file changed, 144 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 8b1817cd1b..93c9ad56b4 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,8 +1,12 @@ use hugr_core::ops::constant::{CustomConst, Sum}; -use hugr_core::types::{Transformable, TypeTransformer}; +use hugr_core::ops::OpTrait; +use hugr_core::types::{Transformable, TypeEnum, TypeTransformer}; +use hugr_core::{Hugr, IncomingPort, OutgoingPort}; use thiserror::Error; +use std::borrow::Cow; use std::collections::HashMap; +use std::ops::Deref; use std::sync::Arc; use hugr_core::{ @@ -17,30 +21,59 @@ use hugr_core::{ Node, }; +#[derive(Clone, Debug, PartialEq)] +enum OpReplacement { + SingleOp(OpType), + CompoundOp(Box), +} + +impl OpReplacement { + fn add(&self, hugr: &mut impl HugrMut, parent: Node) -> Node { + match self.clone() { + OpReplacement::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), + OpReplacement::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + } + } + + // n must be non-root. I mean, it's an ExtensionOp... + fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + let new_optype = match self.clone() { + OpReplacement::SingleOp(op_type) => op_type, + OpReplacement::CompoundOp(new_h) => { + let new_root = hugr + .insert_hugr(hugr.get_parent(n).unwrap(), *new_h) + .new_root; + for ch in hugr.children(new_root).collect::>() { + hugr.set_parent(ch, n); + } + hugr.remove_node(new_root) + } + }; + *hugr.optype_mut(n) = new_optype; + } +} + #[derive(Clone)] pub struct LowerTypes { - type_fn: Arc Option>, + /// Handles simple cases like T1 -> T2. + /// If T1 is Copyable and T2 Linear, then error will be raised if we find e.g. + /// ArrayOfCopyables(T1). This would require an additional entry for that. + /// No support yet for mapping parametrically. type_map: HashMap, - copy_dup: HashMap, // TODO what about e.g. arrays that have gone from copyable to linear because their elements have?! - //op_map: HashMap + copy_discard: HashMap, // TODO what about e.g. arrays that have gone from copyable to linear because their elements have?! + op_map: HashMap, // 1. is input op always a single OpType, or a schema/predicate? // 2. output might not be an op - might be a node with children // 3. do we need checking BEFORE reparametrization as well as after? (after only if not reparametrized?) - #[allow(unused)] const_fn: Arc Option>, + check_sig: bool, } impl TypeTransformer for LowerTypes { type Err = ChangeTypeError; fn apply_custom(&self, ct: &CustomType) -> Result, Self::Err> { - Ok(if let Some(r) = (self.type_fn)(ct) { - Some(r) - } else if let Some(r) = self.type_map.get(ct) { - Some(r.clone()) - } else { - None - }) + Ok(self.type_map.get(ct).cloned()) } } @@ -60,19 +93,106 @@ impl LowerTypes { self.type_map.insert(src, dest); } - pub fn lower_type_linearize(&mut self, src: CustomType, dest: Type, copy: OpType, dup: OpType) { + pub fn lower_type_linearize( + &mut self, + src: CustomType, + dest: Type, + copy: OpReplacement, + discard: OpReplacement, + ) { self.type_map.insert(src.clone(), dest); - self.copy_dup.insert(src, (copy, dup)); + self.copy_discard.insert(src, (copy, discard)); } pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { + let expected_sig = self.check_sig.then(|| { + let mut dfsig = hugr.get_optype(n).dataflow_signature().map(Cow::into_owned); + if let Some(sig) = dfsig.as_mut() { + sig.transform(self); + } + dfsig + }); changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + // (If check_sig) then verify that the Signature still has the same arity/wires, + // with only the expected changes to types within. + if let Some(dfsig) = expected_sig { + assert_eq!(new_dfsig.as_ref().map(Cow::deref), dfsig.as_ref()); + } + let Some(new_sig) = new_dfsig.filter(|_| changed) else { + continue; + }; + let new_sig = new_sig.into_owned(); + for outp in new_sig.output_ports() { + if new_sig.out_port_type(outp).unwrap().copyable() { + continue; + }; + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() == 1 { + continue; + }; + hugr.disconnect(n, outp); + if targets.len() == 0 { + self.do_discard(hugr, n, outp) + } else { + self.do_copy_chain(hugr, n, outp, &targets) + } + } } Ok(changed) } + fn do_copy_chain( + &self, + hugr: &mut impl HugrMut, + mut src_node: Node, + mut src_port: OutgoingPort, + inps: &[(Node, IncomingPort)], + ) { + assert!(inps.len() > 1); + for (tgt_node, tgt_port) in &inps[..inps.len() - 1] { + (src_node, src_port) = self.do_copy(hugr, src_node, src_port, *tgt_node, *tgt_port); + } + let (tgt_node, tgt_port) = inps.last().unwrap(); + hugr.connect(src_node, src_port, *tgt_node, *tgt_port) + } + + fn do_copy( + &self, + hugr: &mut impl HugrMut, + src_node: Node, + src_port: OutgoingPort, + tgt_node: Node, + tgt_port: IncomingPort, + ) -> (Node, OutgoingPort) { + let sig = hugr.get_optype(src_node).dataflow_signature().unwrap(); + let typ = sig.out_port_type(src_port).unwrap(); + if let TypeEnum::Extension(exty) = typ.as_type_enum() { + if let Some((copy, _)) = self.copy_discard.get(exty) { + let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); + hugr.connect(src_node, src_port, n, 0); + hugr.connect(n, 0, tgt_node, tgt_port); + return (n, 1.into()); + } + } + todo!("Containers/arrays/etc.") + } + + fn do_discard(&self, hugr: &mut impl HugrMut, src_node: Node, src_port: OutgoingPort) { + let sig = hugr.get_optype(src_node).dataflow_signature().unwrap(); + let typ = sig.out_port_type(src_port).unwrap(); + if let TypeEnum::Extension(exty) = typ.as_type_enum() { + if let Some((_, discard)) = self.copy_discard.get(exty) { + let n = discard.add(hugr, hugr.get_parent(src_node).unwrap()); + hugr.connect(src_node, src_port, n, 0); + return; + } + } + todo!("Containers/arrays/etc.") + } + fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) @@ -130,13 +250,17 @@ impl LowerTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => { - let def = ext_op.def_arc(); - let mut args = ext_op.args().to_vec(); - if args.transform(self)? { - *ext_op = ExtensionOp::new(def.clone(), args)?; + if let Some(replacement) = self.op_map.get(ext_op) { + replacement.replace(hugr, n); // Copy/discard insertion done by caller + Ok(true) + } else { + let def = ext_op.def_arc(); + let mut args = ext_op.args().to_vec(); + Ok(args.transform(self)? && { + *ext_op = ExtensionOp::new(def.clone(), args)?; + true + }) } - // let params = ext_op_params[node].to_owned(); - todo!("Also check whether we should lower op") } OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), From bfa52cf1428a8dff5905a88a457d87081c03ee89 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Mar 2025 16:13:36 +0000 Subject: [PATCH 013/123] OpHashWrapper --- hugr-passes/src/lower_types.rs | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 93c9ad56b4..1ed602b1ea 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,6 +1,7 @@ +use hugr_core::extension::ExtensionId; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::OpTrait; -use hugr_core::types::{Transformable, TypeEnum, TypeTransformer}; +use hugr_core::types::{Transformable, TypeArg, TypeEnum, TypeTransformer}; use hugr_core::{Hugr, IncomingPort, OutgoingPort}; use thiserror::Error; @@ -21,6 +22,23 @@ use hugr_core::{ Node, }; +#[derive(Clone, Hash, PartialEq, Eq)] +struct OpHashWrapper { + ext_name: ExtensionId, + op_name: String, // Only because SmolStr not in hugr-passes yet + args: Vec, +} + +impl From<&ExtensionOp> for OpHashWrapper { + fn from(op: &ExtensionOp) -> Self { + Self { + ext_name: op.def().extension_id().clone(), + op_name: op.def().name().to_string(), + args: op.args().to_vec(), + } + } +} + #[derive(Clone, Debug, PartialEq)] enum OpReplacement { SingleOp(OpType), @@ -61,7 +79,7 @@ pub struct LowerTypes { /// No support yet for mapping parametrically. type_map: HashMap, copy_discard: HashMap, // TODO what about e.g. arrays that have gone from copyable to linear because their elements have?! - op_map: HashMap, + op_map: HashMap, // 1. is input op always a single OpType, or a schema/predicate? // 2. output might not be an op - might be a node with children // 3. do we need checking BEFORE reparametrization as well as after? (after only if not reparametrized?) @@ -104,6 +122,10 @@ impl LowerTypes { self.copy_discard.insert(src, (copy, discard)); } + pub fn lower_op(&mut self, src: &ExtensionOp, tgt: OpReplacement) { + self.op_map.insert(OpHashWrapper::from(src), tgt); + } + pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { @@ -250,7 +272,7 @@ impl LowerTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => { - if let Some(replacement) = self.op_map.get(ext_op) { + if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { replacement.replace(hugr, n); // Copy/discard insertion done by caller Ok(true) } else { From b3735713e01efde6474d22a1c7121fd2e3b81675 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 13:00:00 +0000 Subject: [PATCH 014/123] Parametrized type support --- hugr-passes/src/lower_types.rs | 182 +++++++++++++++++++++------------ 1 file changed, 116 insertions(+), 66 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 1ed602b1ea..a1f48bc98c 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,26 +1,20 @@ -use hugr_core::extension::ExtensionId; -use hugr_core::ops::constant::{CustomConst, Sum}; -use hugr_core::ops::OpTrait; -use hugr_core::types::{Transformable, TypeArg, TypeEnum, TypeTransformer}; -use hugr_core::{Hugr, IncomingPort, OutgoingPort}; -use thiserror::Error; - use std::borrow::Cow; use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use hugr_core::{ - extension::SignatureError, - hugr::hugrmut::HugrMut, - ops::{ - AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, - ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, - TailLoop, Value, CFG, DFG, - }, - types::{CustomType, Type, TypeBound}, - Node, +use thiserror::Error; + +use hugr_core::extension::{ExtensionId, SignatureError, TypeDef}; +use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::ops::constant::{CustomConst, Sum}; +use hugr_core::ops::{ + AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, + FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, + Value, CFG, DFG, }; +use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; +use hugr_core::{Hugr, IncomingPort, Node, OutgoingPort}; #[derive(Clone, Hash, PartialEq, Eq)] struct OpHashWrapper { @@ -40,7 +34,7 @@ impl From<&ExtensionOp> for OpHashWrapper { } #[derive(Clone, Debug, PartialEq)] -enum OpReplacement { +pub enum OpReplacement { SingleOp(OpType), CompoundOp(Box), } @@ -76,9 +70,26 @@ pub struct LowerTypes { /// Handles simple cases like T1 -> T2. /// If T1 is Copyable and T2 Linear, then error will be raised if we find e.g. /// ArrayOfCopyables(T1). This would require an additional entry for that. - /// No support yet for mapping parametrically. type_map: HashMap, - copy_discard: HashMap, // TODO what about e.g. arrays that have gone from copyable to linear because their elements have?! + /// Parametric types are handled by a function which receives the lowered typeargs. + param_types: HashMap<(ExtensionId, String), Arc Type>>, + // Keyed by lowered type, as only needed when there is an op outputting such + copy_discard: HashMap, + // Copy/discard of parametric types handled by a function that receives the new/lowered type. + // We do not allow linearization to "parametrized" non-extension types, at least not + // in one step. We could do that using a trait, but it seems enough of a corner case. + // Instead that can be achieved by *firstly* lowering to a custom linear type, with copy/dup + // inserted; *secondly* by lowering that to the desired non-extension linear type, + // including lowering of the copy/dup operations to...whatever. + copy_discard_parametric: HashMap< + (ExtensionId, String), + // TODO should pass &LowerTypes, or at least some way to call copy_op / discard_op, to these + ( + Arc OpReplacement>, + Arc OpReplacement>, + ), + >, + // Handles simple cases Op1 -> Op2. TODO handle parametric ops op_map: HashMap, // 1. is input op always a single OpType, or a schema/predicate? // 2. output might not be an op - might be a node with children @@ -91,7 +102,22 @@ impl TypeTransformer for LowerTypes { type Err = ChangeTypeError; fn apply_custom(&self, ct: &CustomType) -> Result, Self::Err> { - Ok(self.type_map.get(ct).cloned()) + Ok(if let Some(res) = self.type_map.get(ct) { + Some(res.clone()) + } else if let Some(dest_fn) = self + .param_types + .get(&(ct.extension().clone(), ct.name().to_string())) + { + // `ct` has not had args transformed + let mut nargs = ct.args().to_vec(); + // We don't care if `nargs` are changed, we're just calling `dest_fn` + nargs + .iter_mut() + .try_for_each(|ta| ta.transform(self).map(|_ch| ()))?; + Some(dest_fn(&nargs)) + } else { + None + }) } } @@ -104,24 +130,42 @@ pub enum ChangeTypeError { impl LowerTypes { pub fn lower_type(&mut self, src: CustomType, dest: Type) { - if src.bound() == TypeBound::Copyable && !dest.copyable() { - // Of course we could try, and fail only if we encounter outports that are not singly-used! - panic!("Cannot lower copyable type to linear without copy/dup - use lower_type_linearize instead"); - } self.type_map.insert(src, dest); } - pub fn lower_type_linearize( + pub fn lower_parametric_type( &mut self, - src: CustomType, - dest: Type, - copy: OpReplacement, - discard: OpReplacement, + src: TypeDef, + dest_fn: Box Type>, ) { - self.type_map.insert(src.clone(), dest); + // No way to check that dest_fn never produces a linear type. + // We could require copy/discard-generators if src is Copyable, or *might be* + // (depending on arguments - i.e. if src's TypeDefBound is anything other than + // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying + // overapproximation. We could just require copy/discard-generators in *all cases* + // (e.g. funcs that just panic!)... + self.param_types.insert( + (src.extension_id().clone(), src.name().to_string()), + Arc::from(dest_fn), + ); + } + + pub fn linearize(&mut self, src: Type, copy: OpReplacement, discard: OpReplacement) { self.copy_discard.insert(src, (copy, discard)); } + pub fn linearize_parametric( + &mut self, + src: TypeDef, + copy_fn: Box OpReplacement>, + discard_fn: Box OpReplacement>, + ) { + self.copy_discard_parametric.insert( + (src.extension_id().clone(), src.name().to_string()), + (Arc::from(copy_fn), Arc::from(discard_fn)), + ); + } + pub fn lower_op(&mut self, src: &ExtensionOp, tgt: OpReplacement) { self.op_map.insert(OpHashWrapper::from(src), tgt); } @@ -129,13 +173,15 @@ impl LowerTypes { pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { - let expected_sig = self.check_sig.then(|| { + let expected_sig = if self.check_sig { let mut dfsig = hugr.get_optype(n).dataflow_signature().map(Cow::into_owned); if let Some(sig) = dfsig.as_mut() { - sig.transform(self); + sig.transform(self)?; } - dfsig - }); + Some(dfsig) + } else { + None + }; changed |= self.change_node(hugr, n)?; let new_dfsig = hugr.get_optype(n).dataflow_signature(); // (If check_sig) then verify that the Signature still has the same arity/wires, @@ -156,10 +202,19 @@ impl LowerTypes { continue; }; hugr.disconnect(n, outp); + let sig = hugr.get_optype(n).dataflow_signature().unwrap(); + let typ = sig.out_port_type(outp).unwrap(); if targets.len() == 0 { - self.do_discard(hugr, n, outp) + let discard = self + .discard_op(typ) + .expect("Don't know how to discard {typ:?}"); // TODO return error + + let disc = discard.add(hugr, hugr.get_parent(n).unwrap()); + hugr.connect(n, outp, disc, 0); } else { - self.do_copy_chain(hugr, n, outp, &targets) + // TODO return error + let copy = self.copy_op(typ).expect("Don't know how to copy {typ:?}"); + self.do_copy_chain(hugr, n, outp, copy, &targets) } } } @@ -171,48 +226,43 @@ impl LowerTypes { hugr: &mut impl HugrMut, mut src_node: Node, mut src_port: OutgoingPort, + copy: OpReplacement, inps: &[(Node, IncomingPort)], ) { assert!(inps.len() > 1); + // Could sanity-check signature here? for (tgt_node, tgt_port) in &inps[..inps.len() - 1] { - (src_node, src_port) = self.do_copy(hugr, src_node, src_port, *tgt_node, *tgt_port); + let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); + hugr.connect(src_node, src_port, n, 0); + hugr.connect(n, 0, *tgt_node, *tgt_port); + (src_node, src_port) = (n, 1.into()); } let (tgt_node, tgt_port) = inps.last().unwrap(); hugr.connect(src_node, src_port, *tgt_node, *tgt_port) } - fn do_copy( - &self, - hugr: &mut impl HugrMut, - src_node: Node, - src_port: OutgoingPort, - tgt_node: Node, - tgt_port: IncomingPort, - ) -> (Node, OutgoingPort) { - let sig = hugr.get_optype(src_node).dataflow_signature().unwrap(); - let typ = sig.out_port_type(src_port).unwrap(); - if let TypeEnum::Extension(exty) = typ.as_type_enum() { - if let Some((copy, _)) = self.copy_discard.get(exty) { - let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); - hugr.connect(src_node, src_port, n, 0); - hugr.connect(n, 0, tgt_node, tgt_port); - return (n, 1.into()); - } + pub fn copy_op(&self, typ: &Type) -> Option { + if let Some((copy, _)) = self.copy_discard.get(typ) { + return Some(copy.clone()); } - todo!("Containers/arrays/etc.") + let TypeEnum::Extension(exty) = typ.as_type_enum() else { + return None; + }; + self.copy_discard_parametric + .get(&(exty.extension().clone(), exty.name().to_string())) + .map(|(copy_fn, _)| copy_fn(exty.args())) } - fn do_discard(&self, hugr: &mut impl HugrMut, src_node: Node, src_port: OutgoingPort) { - let sig = hugr.get_optype(src_node).dataflow_signature().unwrap(); - let typ = sig.out_port_type(src_port).unwrap(); - if let TypeEnum::Extension(exty) = typ.as_type_enum() { - if let Some((_, discard)) = self.copy_discard.get(exty) { - let n = discard.add(hugr, hugr.get_parent(src_node).unwrap()); - hugr.connect(src_node, src_port, n, 0); - return; - } + pub fn discard_op(&self, typ: &Type) -> Option { + if let Some((_, discard)) = self.copy_discard.get(typ) { + return Some(discard.clone()); } - todo!("Containers/arrays/etc.") + let TypeEnum::Extension(exty) = typ.as_type_enum() else { + return None; + }; + self.copy_discard_parametric + .get(&(exty.extension().clone(), exty.name().to_string())) + .map(|(_, discard_fn)| discard_fn(exty.args())) } fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { From a8e613a1bc28dd2c1d02c7872767266a74508124 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 13:33:42 +0000 Subject: [PATCH 015/123] remove copy_discard stuff --- hugr-passes/src/lower_types.rs | 118 +-------------------------------- 1 file changed, 3 insertions(+), 115 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index a1f48bc98c..b94f6f0d15 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -13,8 +13,8 @@ use hugr_core::ops::{ FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; -use hugr_core::{Hugr, IncomingPort, Node, OutgoingPort}; +use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeTransformer}; +use hugr_core::{Hugr, Node}; #[derive(Clone, Hash, PartialEq, Eq)] struct OpHashWrapper { @@ -40,13 +40,6 @@ pub enum OpReplacement { } impl OpReplacement { - fn add(&self, hugr: &mut impl HugrMut, parent: Node) -> Node { - match self.clone() { - OpReplacement::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - OpReplacement::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, - } - } - // n must be non-root. I mean, it's an ExtensionOp... fn replace(&self, hugr: &mut impl HugrMut, n: Node) { let new_optype = match self.clone() { @@ -73,22 +66,6 @@ pub struct LowerTypes { type_map: HashMap, /// Parametric types are handled by a function which receives the lowered typeargs. param_types: HashMap<(ExtensionId, String), Arc Type>>, - // Keyed by lowered type, as only needed when there is an op outputting such - copy_discard: HashMap, - // Copy/discard of parametric types handled by a function that receives the new/lowered type. - // We do not allow linearization to "parametrized" non-extension types, at least not - // in one step. We could do that using a trait, but it seems enough of a corner case. - // Instead that can be achieved by *firstly* lowering to a custom linear type, with copy/dup - // inserted; *secondly* by lowering that to the desired non-extension linear type, - // including lowering of the copy/dup operations to...whatever. - copy_discard_parametric: HashMap< - (ExtensionId, String), - // TODO should pass &LowerTypes, or at least some way to call copy_op / discard_op, to these - ( - Arc OpReplacement>, - Arc OpReplacement>, - ), - >, // Handles simple cases Op1 -> Op2. TODO handle parametric ops op_map: HashMap, // 1. is input op always a single OpType, or a schema/predicate? @@ -142,30 +119,13 @@ impl LowerTypes { // We could require copy/discard-generators if src is Copyable, or *might be* // (depending on arguments - i.e. if src's TypeDefBound is anything other than // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying - // overapproximation. We could just require copy/discard-generators in *all cases* - // (e.g. funcs that just panic!)... + // overapproximation. self.param_types.insert( (src.extension_id().clone(), src.name().to_string()), Arc::from(dest_fn), ); } - pub fn linearize(&mut self, src: Type, copy: OpReplacement, discard: OpReplacement) { - self.copy_discard.insert(src, (copy, discard)); - } - - pub fn linearize_parametric( - &mut self, - src: TypeDef, - copy_fn: Box OpReplacement>, - discard_fn: Box OpReplacement>, - ) { - self.copy_discard_parametric.insert( - (src.extension_id().clone(), src.name().to_string()), - (Arc::from(copy_fn), Arc::from(discard_fn)), - ); - } - pub fn lower_op(&mut self, src: &ExtensionOp, tgt: OpReplacement) { self.op_map.insert(OpHashWrapper::from(src), tgt); } @@ -189,82 +149,10 @@ impl LowerTypes { if let Some(dfsig) = expected_sig { assert_eq!(new_dfsig.as_ref().map(Cow::deref), dfsig.as_ref()); } - let Some(new_sig) = new_dfsig.filter(|_| changed) else { - continue; - }; - let new_sig = new_sig.into_owned(); - for outp in new_sig.output_ports() { - if new_sig.out_port_type(outp).unwrap().copyable() { - continue; - }; - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() == 1 { - continue; - }; - hugr.disconnect(n, outp); - let sig = hugr.get_optype(n).dataflow_signature().unwrap(); - let typ = sig.out_port_type(outp).unwrap(); - if targets.len() == 0 { - let discard = self - .discard_op(typ) - .expect("Don't know how to discard {typ:?}"); // TODO return error - - let disc = discard.add(hugr, hugr.get_parent(n).unwrap()); - hugr.connect(n, outp, disc, 0); - } else { - // TODO return error - let copy = self.copy_op(typ).expect("Don't know how to copy {typ:?}"); - self.do_copy_chain(hugr, n, outp, copy, &targets) - } - } } Ok(changed) } - fn do_copy_chain( - &self, - hugr: &mut impl HugrMut, - mut src_node: Node, - mut src_port: OutgoingPort, - copy: OpReplacement, - inps: &[(Node, IncomingPort)], - ) { - assert!(inps.len() > 1); - // Could sanity-check signature here? - for (tgt_node, tgt_port) in &inps[..inps.len() - 1] { - let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); - hugr.connect(src_node, src_port, n, 0); - hugr.connect(n, 0, *tgt_node, *tgt_port); - (src_node, src_port) = (n, 1.into()); - } - let (tgt_node, tgt_port) = inps.last().unwrap(); - hugr.connect(src_node, src_port, *tgt_node, *tgt_port) - } - - pub fn copy_op(&self, typ: &Type) -> Option { - if let Some((copy, _)) = self.copy_discard.get(typ) { - return Some(copy.clone()); - } - let TypeEnum::Extension(exty) = typ.as_type_enum() else { - return None; - }; - self.copy_discard_parametric - .get(&(exty.extension().clone(), exty.name().to_string())) - .map(|(copy_fn, _)| copy_fn(exty.args())) - } - - pub fn discard_op(&self, typ: &Type) -> Option { - if let Some((_, discard)) = self.copy_discard.get(typ) { - return Some(discard.clone()); - } - let TypeEnum::Extension(exty) = typ.as_type_enum() else { - return None; - }; - self.copy_discard_parametric - .get(&(exty.extension().clone(), exty.name().to_string())) - .map(|(_, discard_fn)| discard_fn(exty.args())) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) From 24ed15c22bb383f19e010a676deb33094393e0ac Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 13:44:17 +0000 Subject: [PATCH 016/123] Assume less in OpReplacement::replace --- hugr-passes/src/lower_types.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index b94f6f0d15..5f569e4cb4 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -40,18 +40,18 @@ pub enum OpReplacement { } impl OpReplacement { - // n must be non-root. I mean, it's an ExtensionOp... fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { OpReplacement::SingleOp(op_type) => op_type, OpReplacement::CompoundOp(new_h) => { - let new_root = hugr - .insert_hugr(hugr.get_parent(n).unwrap(), *new_h) - .new_root; - for ch in hugr.children(new_root).collect::>() { + let new_root = hugr.insert_hugr(n, *new_h).new_root; + let children = hugr.children(new_root).collect::>(); + let root_opty = hugr.remove_node(new_root); + for ch in children { hugr.set_parent(ch, n); } - hugr.remove_node(new_root) + root_opty } }; *hugr.optype_mut(n) = new_optype; From 9153fcfaf54022c471206ed3df3200841e1179c0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 14:07:46 +0000 Subject: [PATCH 017/123] parametrized ops --- hugr-passes/src/lower_types.rs | 81 +++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 5f569e4cb4..fb5eddbe79 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use thiserror::Error; -use hugr_core::extension::{ExtensionId, SignatureError, TypeDef}; +use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::{ @@ -33,6 +33,31 @@ impl From<&ExtensionOp> for OpHashWrapper { } } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct ParametricType(ExtensionId, String); // SmolStr not in hugr-passes + +impl From<&TypeDef> for ParametricType { + fn from(value: &TypeDef) -> Self { + Self(value.extension_id().clone(), value.name().to_string()) + } +} + +impl From<&CustomType> for ParametricType { + fn from(value: &CustomType) -> Self { + Self(value.extension().clone(), value.name().to_string()) + } +} + +// Separate from above for clarity +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct ParametricOp(ExtensionId, String); + +impl From<&OpDef> for ParametricOp { + fn from(value: &OpDef) -> Self { + Self(value.extension_id().clone(), value.name().to_string()) + } +} + #[derive(Clone, Debug, PartialEq)] pub enum OpReplacement { SingleOp(OpType), @@ -65,12 +90,11 @@ pub struct LowerTypes { /// ArrayOfCopyables(T1). This would require an additional entry for that. type_map: HashMap, /// Parametric types are handled by a function which receives the lowered typeargs. - param_types: HashMap<(ExtensionId, String), Arc Type>>, + param_types: HashMap Type>>, // Handles simple cases Op1 -> Op2. TODO handle parametric ops op_map: HashMap, - // 1. is input op always a single OpType, or a schema/predicate? - // 2. output might not be an op - might be a node with children - // 3. do we need checking BEFORE reparametrization as well as after? (after only if not reparametrized?) + // Called after lowering typeargs; None means to use original OpDef + param_ops: HashMap Option>>, const_fn: Arc Option>, check_sig: bool, } @@ -81,10 +105,7 @@ impl TypeTransformer for LowerTypes { fn apply_custom(&self, ct: &CustomType) -> Result, Self::Err> { Ok(if let Some(res) = self.type_map.get(ct) { Some(res.clone()) - } else if let Some(dest_fn) = self - .param_types - .get(&(ct.extension().clone(), ct.name().to_string())) - { + } else if let Some(dest_fn) = self.param_types.get(&ct.into()) { // `ct` has not had args transformed let mut nargs = ct.args().to_vec(); // We don't care if `nargs` are changed, we're just calling `dest_fn` @@ -107,12 +128,14 @@ pub enum ChangeTypeError { impl LowerTypes { pub fn lower_type(&mut self, src: CustomType, dest: Type) { + // We could check that 'dest' is copyable or 'src' is linear, but since we can't + // check that for parametric types, we'll be consistent and not check here either. self.type_map.insert(src, dest); } pub fn lower_parametric_type( &mut self, - src: TypeDef, + src: &TypeDef, dest_fn: Box Type>, ) { // No way to check that dest_fn never produces a linear type. @@ -120,16 +143,21 @@ impl LowerTypes { // (depending on arguments - i.e. if src's TypeDefBound is anything other than // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying // overapproximation. - self.param_types.insert( - (src.extension_id().clone(), src.name().to_string()), - Arc::from(dest_fn), - ); + self.param_types.insert(src.into(), Arc::from(dest_fn)); } pub fn lower_op(&mut self, src: &ExtensionOp, tgt: OpReplacement) { self.op_map.insert(OpHashWrapper::from(src), tgt); } + pub fn lower_parametric_op( + &mut self, + src: &OpDef, + dest_fn: Box Option>, + ) { + self.param_ops.insert(src.into(), Arc::from(dest_fn)); + } + pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { @@ -209,19 +237,30 @@ impl LowerTypes { | rest.transform(self)?), OpType::Const(Const { value, .. }) => self.change_value(value), - OpType::ExtensionOp(ext_op) => { + OpType::ExtensionOp(ext_op) => Ok( if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { replacement.replace(hugr, n); // Copy/discard insertion done by caller - Ok(true) + true } else { let def = ext_op.def_arc(); let mut args = ext_op.args().to_vec(); - Ok(args.transform(self)? && { - *ext_op = ExtensionOp::new(def.clone(), args)?; + let ch = args.transform(self)?; + if let Some(replacement) = self + .param_ops + .get(&def.as_ref().into()) + .and_then(|rep_fn| rep_fn(&args)) + { + replacement.replace(hugr, n); true - }) - } - } + } else { + if ch { + *ext_op = ExtensionOp::new(def.clone(), args)?; + } + ch + } + }, + ), + OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), OpType::AliasDecl(_) | OpType::Module(_) => Ok(false), From a89879caf01450e394c83bdf8609af036041ebfc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 14:22:10 +0000 Subject: [PATCH 018/123] Comments, renaming, use const_fn --- hugr-passes/src/lower_types.rs | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index fb5eddbe79..192e32bc95 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -34,7 +34,7 @@ impl From<&ExtensionOp> for OpHashWrapper { } #[derive(Clone, Debug, PartialEq, Eq, Hash)] -struct ParametricType(ExtensionId, String); // SmolStr not in hugr-passes +struct ParametricType(ExtensionId, String); impl From<&TypeDef> for ParametricType { fn from(value: &TypeDef) -> Self { @@ -91,9 +91,9 @@ pub struct LowerTypes { type_map: HashMap, /// Parametric types are handled by a function which receives the lowered typeargs. param_types: HashMap Type>>, - // Handles simple cases Op1 -> Op2. TODO handle parametric ops + // Handles simple cases Op1 -> Op2. op_map: HashMap, - // Called after lowering typeargs; None means to use original OpDef + // Called after lowering typeargs; return None to use original OpDef param_ops: HashMap Option>>, const_fn: Arc Option>, check_sig: bool, @@ -161,7 +161,7 @@ impl LowerTypes { pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { - let expected_sig = if self.check_sig { + let expected_dfsig = if self.check_sig { let mut dfsig = hugr.get_optype(n).dataflow_signature().map(Cow::into_owned); if let Some(sig) = dfsig.as_mut() { sig.transform(self)?; @@ -174,8 +174,8 @@ impl LowerTypes { let new_dfsig = hugr.get_optype(n).dataflow_signature(); // (If check_sig) then verify that the Signature still has the same arity/wires, // with only the expected changes to types within. - if let Some(dfsig) = expected_sig { - assert_eq!(new_dfsig.as_ref().map(Cow::deref), dfsig.as_ref()); + if let Some(expected_sig) = expected_dfsig { + assert_eq!(new_dfsig.as_ref().map(Cow::deref), expected_sig.as_ref()); } } Ok(changed) @@ -281,7 +281,7 @@ impl LowerTypes { Ok(any_change) } Value::Extension { e } => { - if let Some(new_const) = self.subst_custom_const(e.value())? { + if let Some(new_const) = (self.const_fn)(e.value()) { *value = new_const; Ok(true) } else { @@ -291,8 +291,4 @@ impl LowerTypes { Value::Function { hugr } => self.run_no_validate(&mut **hugr), } } - - fn subst_custom_const(&self, _cst: &dyn CustomConst) -> Result, ChangeTypeError> { - todo!() - } } From c78d88aec440fe6e33fc0509f66c220c6378f22a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 17 Mar 2025 14:25:10 +0000 Subject: [PATCH 019/123] Comment const_fn TODO --- hugr-passes/src/lower_types.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 192e32bc95..ef681ca0f2 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -95,6 +95,9 @@ pub struct LowerTypes { op_map: HashMap, // Called after lowering typeargs; return None to use original OpDef param_ops: HashMap Option>>, + // TODO should probably have a map, or two, here - from CustomType and from ParametricType. + // Whereupon the closure should be given a callback to self.change_value, too, in case of nested + // values for collections. const_fn: Arc Option>, check_sig: bool, } From bf6a9a450f1cff94021a440e7de1847fffab1774 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 09:30:18 +0000 Subject: [PATCH 020/123] Test panics on unexpected argument - simpler, better --- hugr-core/src/types.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 7a1d106ece..9ee6893ce4 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -941,11 +941,8 @@ pub(crate) mod test { // Finally, check handling Coln overrides handling of Cpy let cpy_to_qb2 = FnTransformer(|ct: &CustomType| { - if ct == &cpy { - Some(qb_t()) - } else { - (ct == &c_of_cpy).then_some(usize_t()) - } + assert_ne!(ct, &cpy); + (ct == &c_of_cpy).then_some(usize_t()) }); let mut t = Type::new_extension( coln.instantiate([TypeArg::Sequence { From 5ecd9e67a47dd7c0e3a867a90592f30821ef4d74 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 09:50:27 +0000 Subject: [PATCH 021/123] test functiontype --- hugr-core/src/types.rs | 2 +- hugr-core/src/types/signature.rs | 31 ++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 9ee6893ce4..e46f485c40 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -843,7 +843,7 @@ pub(crate) mod test { } } - struct FnTransformer(T); + pub(super) struct FnTransformer(pub(super) T); impl Option> TypeTransformer for FnTransformer { type Err = SignatureError; diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index cd6b06bb3a..636d05d68d 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -339,7 +339,9 @@ impl PartialEq> for Cow<'_, FuncTy #[cfg(test)] mod test { - use crate::{extension::prelude::usize_t, type_row}; + use crate::extension::prelude::{bool_t, qb_t, usize_t}; + use crate::type_row; + use crate::types::{test::FnTransformer, CustomType, TypeEnum}; use super::*; #[test] @@ -367,4 +369,31 @@ mod test { (&type_row![Type::UNIT], &vec![usize_t()].into()) ); } + + #[test] + fn test_transform() { + let TypeEnum::Extension(usz_t) = usize_t().as_type_enum().clone() else { + panic!() + }; + let tr = FnTransformer(|ct: &CustomType| (ct == &usz_t).then_some(bool_t())); + let row_with = || TypeRow::from(vec![usize_t(), qb_t(), bool_t()]); + let row_after = || TypeRow::from(vec![bool_t(), qb_t(), bool_t()]); + let mut sig = Signature::new(row_with(), row_after()); + let mut exp = Signature::new(row_after(), row_after()); + assert_eq!(sig.transform(&tr), Ok(true)); + assert_eq!(sig, exp); + assert_eq!(sig.transform(&tr), Ok(false)); + assert_eq!(sig, exp); + let exp = Type::new_function(exp); + for fty in [ + FuncValueType::new(row_after(), row_with()), + FuncValueType::new(row_with(), row_with()), + ] { + let mut t = Type::new_function(fty); + assert_eq!(t.transform(&tr), Ok(true)); + assert_eq!(t, exp); + assert_eq!(t.transform(&tr), Ok(false)); + assert_eq!(t, exp); + } + } } From d9a6d2909b6072405250d284f3c1196a22031748 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 10:37:36 +0000 Subject: [PATCH 022/123] clippy that new test --- hugr-core/src/types/signature.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 636d05d68d..28c39fa088 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -379,7 +379,7 @@ mod test { let row_with = || TypeRow::from(vec![usize_t(), qb_t(), bool_t()]); let row_after = || TypeRow::from(vec![bool_t(), qb_t(), bool_t()]); let mut sig = Signature::new(row_with(), row_after()); - let mut exp = Signature::new(row_after(), row_after()); + let exp = Signature::new(row_after(), row_after()); assert_eq!(sig.transform(&tr), Ok(true)); assert_eq!(sig, exp); assert_eq!(sig.transform(&tr), Ok(false)); From 84fe82d509ff34ea9c8d0f02242af23e99b15645 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 14:16:49 +0000 Subject: [PATCH 023/123] WIP setup for test --- hugr-passes/src/lower_types.rs | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index ef681ca0f2..f44c0814f6 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -102,6 +102,19 @@ pub struct LowerTypes { check_sig: bool, } +impl Default for LowerTypes { + fn default() -> Self { + Self { + type_map: Default::default(), + param_types: Default::default(), + op_map: Default::default(), + param_ops: Default::default(), + const_fn: Arc::new(|_| None), + check_sig: false, + } + } +} + impl TypeTransformer for LowerTypes { type Err = ChangeTypeError; @@ -295,3 +308,115 @@ impl LowerTypes { } } } + +#[cfg(test)] +mod test { + use hugr_core::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ + prelude::{bool_t, option_type, UnwrapBuilder}, + TypeDefBound, Version, + }, + hugr::IdentList, + ops::ExtensionOp, + std_extensions::{ + arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, + collections::array::{array_type, ArrayOpDef}, + }, + types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}, + Extension, + }; + + use super::{LowerTypes, OpReplacement}; + + #[test] + fn lower() { + let ext = Extension::new_arc( + IdentList::new("TestExt").unwrap(), + Version::new(0, 0, 1), + |ext, w| { + let pv_of_var = ext + .add_type( + "PackedVec".into(), + vec![TypeBound::Any.into()], + String::new(), + TypeDefBound::from_params(vec![0]), + w, + ) + .unwrap() + .instantiate(vec![Type::new_var_use(0, TypeBound::Any).into()]) + .unwrap(); + ext.add_op( + "read".into(), + "".into(), + PolyFuncType::new( + vec![TypeBound::Any.into()], + Signature::new( + vec![pv_of_var.into(), INT_TYPES[6].to_owned()], + Type::new_var_use(0, TypeBound::Any), + ), + ), + w, + ) + .unwrap(); + ext.add_op( + "lowered_read_bool".into(), + "".into(), + Signature::new(vec![INT_TYPES[6].to_owned(); 2], bool_t()), + w, + ) + .unwrap(); + }, + ); + fn lowered_read(args: &[TypeArg]) -> Option { + let [TypeArg::Type { ty }] = args else { + panic!("Illegal TypeArgs") + }; + let mut dfb = DFGBuilder::new(Signature::new( + vec![array_type(64, ty.clone()), INT_TYPES[6].to_owned()], + ty.clone(), + )) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) + .unwrap(); + Some(OpReplacement::CompoundOp(Box::new( + dfb.finish_hugr_with_outputs([res]).unwrap(), + ))) + } + let pv = ext.get_type("PackedVec").unwrap(); + let read = ext.get_op("read").unwrap(); + let mut lw = LowerTypes::default(); + lw.lower_type( + pv.instantiate([bool_t().into()]).unwrap(), + INT_TYPES[6].to_owned(), + ); + lw.lower_parametric_type( + pv, + Box::new(|args: &[TypeArg]| { + let [TypeArg::Type { ty }] = args else { + panic!("Illegal TypeArgs") + }; + array_type(64, ty.clone()) + }), + ); + lw.lower_op( + &ExtensionOp::new(read.clone(), [bool_t().into()]).unwrap(), + OpReplacement::SingleOp( + ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), []) + .unwrap() + .into(), + ), + ); + lw.lower_parametric_op(read.as_ref(), Box::new(lowered_read)); + } +} From 616835319ca14ccb3f19269f95627bd433c30be4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 15:36:06 +0000 Subject: [PATCH 024/123] First test --- hugr-passes/src/lower_types.rs | 113 +++++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 19 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index f44c0814f6..1ea221ac5a 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -135,7 +135,7 @@ impl TypeTransformer for LowerTypes { } } -#[derive(Clone, Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq, Eq)] #[non_exhaustive] pub enum ChangeTypeError { #[error(transparent)] @@ -311,27 +311,26 @@ impl LowerTypes { #[cfg(test)] mod test { - use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{ - prelude::{bool_t, option_type, UnwrapBuilder}, - TypeDefBound, Version, - }, - hugr::IdentList, - ops::ExtensionOp, - std_extensions::{ - arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, - collections::array::{array_type, ArrayOpDef}, - }, - types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}, - Extension, + use std::{collections::HashMap, sync::Arc}; + + use hugr_core::builder::{ + Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, + ModuleBuilder, SubContainer, + }; + use hugr_core::extension::prelude::{bool_t, option_type, UnwrapBuilder}; + use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::ops::{ExtensionOp, OpType, Tag}; + use hugr_core::std_extensions::{ + arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, + collections::array::{array_type, ArrayOpDef}, }; + use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; + use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use super::{LowerTypes, OpReplacement}; - #[test] - fn lower() { - let ext = Extension::new_arc( + fn ext() -> Arc { + Extension::new_arc( IdentList::new("TestExt").unwrap(), Version::new(0, 0, 1), |ext, w| { @@ -367,7 +366,10 @@ mod test { ) .unwrap(); }, - ); + ) + } + + fn lower_types(ext: &Extension) -> LowerTypes { fn lowered_read(args: &[TypeArg]) -> Option { let [TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs") @@ -418,5 +420,78 @@ mod test { ), ); lw.lower_parametric_op(read.as_ref(), Box::new(lowered_read)); + lw + } + + #[test] + fn module_func_dfg_cfg() { + let ext = ext(); + let coln = ext.get_type("PackedVec").unwrap(); + let read = ext.get_op("read").unwrap(); + let i64 = || INT_TYPES[6].to_owned(); + let c_int = Type::from(coln.instantiate([INT_TYPES[6].to_owned().into()]).unwrap()); + let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); + let mut mb = ModuleBuilder::new(); + let mut fb = mb + .define_function( + "foo", + Signature::new(vec![i64(), c_int.clone(), c_bool.clone()], bool_t()), + ) + .unwrap(); + let [idx, indices, bools] = fb.input_wires_arr(); + let mut dfb = fb + .dfg_builder(Signature::new(vec![i64(), c_int], i64()), [idx, indices]) + .unwrap(); + let [idx, indices] = dfb.input_wires_arr(); + let int_read_op = dfb + .add_dataflow_op( + ExtensionOp::new(read.clone(), [i64().into()]).unwrap(), + [indices, idx], + ) + .unwrap(); + let [idx2] = dfb + .finish_with_outputs(int_read_op.outputs()) + .unwrap() + .outputs_arr(); + let mut cfg = fb + .cfg_builder([(i64(), idx2), (c_bool, bools)], bool_t().into()) + .unwrap(); + let mut entry = cfg.entry_builder([bool_t().into()], type_row![]).unwrap(); + let [idx2, bools] = entry.input_wires_arr(); + let bool_read_op = entry + .add_dataflow_op( + ExtensionOp::new(read.clone(), [bool_t().into()]).unwrap(), + [bools, idx2], + ) + .unwrap(); + let [tagged] = entry + .add_dataflow_op( + OpType::Tag(Tag::new(0, vec![bool_t().into()])), + bool_read_op.outputs(), + ) + .unwrap() + .outputs_arr(); + let entry = entry.finish_with_outputs(tagged, []).unwrap(); + cfg.branch(&entry, 0, &cfg.exit_block()).unwrap(); + let cfg = cfg.finish_sub_container().unwrap(); + fb.finish_with_outputs(cfg.outputs()).unwrap(); + let mut h = mb.finish_hugr().unwrap(); + + assert!(lower_types(&ext).run_no_validate(&mut h).unwrap()); + + h.validate().unwrap(); + let mut counts: HashMap<_, usize> = HashMap::new(); + for ext_op in h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()) { + *(counts.entry(ext_op.def().name().as_str()).or_default()) += 1; + } + assert_eq!( + counts, + HashMap::from([ + ("lowered_read_bool", 1), + ("itousize", 1), + ("get", 1), + ("panic", 1) + ]) + ); } } From fcb85a26f6974268ea943ec50ceb4c1a05f13816 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 17:53:36 +0000 Subject: [PATCH 025/123] read only makes sense for Copyables --- hugr-passes/src/lower_types.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 1ea221ac5a..d2b3aed516 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -343,13 +343,13 @@ mod test { w, ) .unwrap() - .instantiate(vec![Type::new_var_use(0, TypeBound::Any).into()]) + .instantiate(vec![Type::new_var_use(0, TypeBound::Copyable).into()]) .unwrap(); ext.add_op( "read".into(), "".into(), PolyFuncType::new( - vec![TypeBound::Any.into()], + vec![TypeBound::Copyable.into()], Signature::new( vec![pv_of_var.into(), INT_TYPES[6].to_owned()], Type::new_var_use(0, TypeBound::Any), From 0f9aa1709631e7de5d5e00e74bb1abf74abe7996 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 18:02:59 +0000 Subject: [PATCH 026/123] Extend test to Calls of polyfunc; comments, monomorphize first --- hugr-passes/src/lower_types.rs | 68 ++++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index d2b3aed516..6abe706a2e 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -61,6 +61,11 @@ impl From<&OpDef> for ParametricOp { #[derive(Clone, Debug, PartialEq)] pub enum OpReplacement { SingleOp(OpType), + /// Defines a sub-Hugr to splice in place of the op. + /// Note this will be of limited use before [monomorphization](super::monomorphize) because + /// the sub-Hugrwill not be able to use type variables present in the op. + // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s + // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), } @@ -143,6 +148,10 @@ pub enum ChangeTypeError { } impl LowerTypes { + /// Configures this instance to change occurrences of `src` to `dest`. + /// Note that if `src` is an instance of a *parametrized* Type, this should only + /// be used on *[monomorphize](super::monomorphize)d* Hugrs, because substitution + /// (parametric polymorphism) happening later will not respect the lowering(s). pub fn lower_type(&mut self, src: CustomType, dest: Type) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't // check that for parametric types, we'll be consistent and not check here either. @@ -327,6 +336,8 @@ mod test { use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use crate::{MonomorphizePass, RemoveDeadFuncsPass}; + use super::{LowerTypes, OpReplacement}; fn ext() -> Arc { @@ -424,17 +435,40 @@ mod test { } #[test] - fn module_func_dfg_cfg() { + fn module_func_dfg_cfg_call() { let ext = ext(); let coln = ext.get_type("PackedVec").unwrap(); - let read = ext.get_op("read").unwrap(); + let i64 = || INT_TYPES[6].to_owned(); let c_int = Type::from(coln.instantiate([INT_TYPES[6].to_owned().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); + let read = { + let read_op = ext.get_op("read").unwrap(); + let tv = Type::new_var_use(0, TypeBound::Copyable); + let mut read_fn = mb + .define_function( + "reader", + PolyFuncType::new( + [TypeBound::Copyable.into()], + Signature::new( + vec![coln.instantiate([tv.clone().into()]).unwrap().into(), i64()], + tv.clone(), + ), + ), + ) + .unwrap(); + let res = read_fn + .add_dataflow_op( + ExtensionOp::new(read_op.clone(), [tv.into()]).unwrap(), + read_fn.input_wires(), + ) + .unwrap(); + read_fn.finish_with_outputs(res.outputs()).unwrap() + }; let mut fb = mb .define_function( - "foo", + "main", Signature::new(vec![i64(), c_int.clone(), c_bool.clone()], bool_t()), ) .unwrap(); @@ -444,10 +478,7 @@ mod test { .unwrap(); let [idx, indices] = dfb.input_wires_arr(); let int_read_op = dfb - .add_dataflow_op( - ExtensionOp::new(read.clone(), [i64().into()]).unwrap(), - [indices, idx], - ) + .call(read.handle(), &[i64().into()], [indices, idx]) .unwrap(); let [idx2] = dfb .finish_with_outputs(int_read_op.outputs()) @@ -459,10 +490,7 @@ mod test { let mut entry = cfg.entry_builder([bool_t().into()], type_row![]).unwrap(); let [idx2, bools] = entry.input_wires_arr(); let bool_read_op = entry - .add_dataflow_op( - ExtensionOp::new(read.clone(), [bool_t().into()]).unwrap(), - [bools, idx2], - ) + .call(read.handle(), &[bool_t().into()], [bools, idx2]) .unwrap(); let [tagged] = entry .add_dataflow_op( @@ -476,10 +504,26 @@ mod test { let cfg = cfg.finish_sub_container().unwrap(); fb.finish_with_outputs(cfg.outputs()).unwrap(); let mut h = mb.finish_hugr().unwrap(); + // Since we treat collection differently, we must monomorphize to catch all instantiations + MonomorphizePass::default().run(&mut h).unwrap(); + RemoveDeadFuncsPass::default() + .with_module_entry_points(h.children(h.root()).filter(|n| { + h.get_optype(*n) + .as_func_defn() + .is_some_and(|fd| fd.name == "main") + })) + .run(&mut h) + .unwrap(); + assert_eq!( + h.nodes() + .filter(|n| h.get_optype(*n).is_func_defn()) + .count(), + 3 + ); assert!(lower_types(&ext).run_no_validate(&mut h).unwrap()); - h.validate().unwrap(); + let mut counts: HashMap<_, usize> = HashMap::new(); for ext_op in h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()) { *(counts.entry(ext_op.def().name().as_str()).or_default()) += 1; From 74c67754eae0d92701e9dd5234daba6c895d8910 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 18:55:33 +0000 Subject: [PATCH 027/123] no, instantiate the calls with types being lowered --- hugr-passes/src/lower_types.rs | 79 ++++++++++++++-------------------- 1 file changed, 33 insertions(+), 46 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 6abe706a2e..fe621a3592 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -336,8 +336,6 @@ mod test { use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; - use crate::{MonomorphizePass, RemoveDeadFuncsPass}; - use super::{LowerTypes, OpReplacement}; fn ext() -> Arc { @@ -438,34 +436,22 @@ mod test { fn module_func_dfg_cfg_call() { let ext = ext(); let coln = ext.get_type("PackedVec").unwrap(); - + let read = ext.get_op("read").unwrap(); let i64 = || INT_TYPES[6].to_owned(); let c_int = Type::from(coln.instantiate([INT_TYPES[6].to_owned().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); - let read = { - let read_op = ext.get_op("read").unwrap(); - let tv = Type::new_var_use(0, TypeBound::Copyable); - let mut read_fn = mb - .define_function( - "reader", - PolyFuncType::new( - [TypeBound::Copyable.into()], - Signature::new( - vec![coln.instantiate([tv.clone().into()]).unwrap().into(), i64()], - tv.clone(), - ), - ), - ) - .unwrap(); - let res = read_fn - .add_dataflow_op( - ExtensionOp::new(read_op.clone(), [tv.into()]).unwrap(), - read_fn.input_wires(), - ) - .unwrap(); - read_fn.finish_with_outputs(res.outputs()).unwrap() - }; + let fb = mb + .define_function( + "id", + PolyFuncType::new( + [TypeBound::Any.into()], + Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), + ), + ) + .unwrap(); + let inps = fb.input_wires(); + let id = fb.finish_with_outputs(inps).unwrap(); let mut fb = mb .define_function( "main", @@ -474,23 +460,40 @@ mod test { .unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let mut dfb = fb - .dfg_builder(Signature::new(vec![i64(), c_int], i64()), [idx, indices]) + .dfg_builder( + Signature::new(vec![i64(), c_int.clone()], i64()), + [idx, indices], + ) .unwrap(); let [idx, indices] = dfb.input_wires_arr(); + let [indices] = dfb + .call(id.handle(), &[c_int.into()], [indices]) + .unwrap() + .outputs_arr(); let int_read_op = dfb - .call(read.handle(), &[i64().into()], [indices, idx]) + .add_dataflow_op( + ExtensionOp::new(read.clone(), [i64().into()]).unwrap(), + [indices, idx], + ) .unwrap(); let [idx2] = dfb .finish_with_outputs(int_read_op.outputs()) .unwrap() .outputs_arr(); let mut cfg = fb - .cfg_builder([(i64(), idx2), (c_bool, bools)], bool_t().into()) + .cfg_builder([(i64(), idx2), (c_bool.clone(), bools)], bool_t().into()) .unwrap(); let mut entry = cfg.entry_builder([bool_t().into()], type_row![]).unwrap(); let [idx2, bools] = entry.input_wires_arr(); + let [bools] = entry + .call(id.handle(), &[c_bool.into()], [bools]) + .unwrap() + .outputs_arr(); let bool_read_op = entry - .call(read.handle(), &[bool_t().into()], [bools, idx2]) + .add_dataflow_op( + ExtensionOp::new(read.clone(), [bool_t().into()]).unwrap(), + [bools, idx2], + ) .unwrap(); let [tagged] = entry .add_dataflow_op( @@ -504,22 +507,6 @@ mod test { let cfg = cfg.finish_sub_container().unwrap(); fb.finish_with_outputs(cfg.outputs()).unwrap(); let mut h = mb.finish_hugr().unwrap(); - // Since we treat collection differently, we must monomorphize to catch all instantiations - MonomorphizePass::default().run(&mut h).unwrap(); - RemoveDeadFuncsPass::default() - .with_module_entry_points(h.children(h.root()).filter(|n| { - h.get_optype(*n) - .as_func_defn() - .is_some_and(|fd| fd.name == "main") - })) - .run(&mut h) - .unwrap(); - assert_eq!( - h.nodes() - .filter(|n| h.get_optype(*n).is_func_defn()) - .count(), - 3 - ); assert!(lower_types(&ext).run_no_validate(&mut h).unwrap()); h.validate().unwrap(); From 00fd2841850ed871beadcbe825063a46dd98e6e2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 21:40:19 +0000 Subject: [PATCH 028/123] Consts: HashMap keyed by either, add lower_ methods. Test TailLoop and Const. --- hugr-passes/src/lower_types.rs | 140 ++++++++++++++++++++++++--------- 1 file changed, 104 insertions(+), 36 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index fe621a3592..4f6bd8cc7d 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -3,17 +3,18 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; +use itertools::Either; use thiserror::Error; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; -use hugr_core::ops::constant::{CustomConst, Sum}; +use hugr_core::ops::constant::{OpaqueValue, Sum}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeTransformer}; +use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; use hugr_core::{Hugr, Node}; #[derive(Clone, Hash, PartialEq, Eq)] @@ -88,7 +89,7 @@ impl OpReplacement { } } -#[derive(Clone)] +#[derive(Clone, Default)] pub struct LowerTypes { /// Handles simple cases like T1 -> T2. /// If T1 is Copyable and T2 Linear, then error will be raised if we find e.g. @@ -100,26 +101,10 @@ pub struct LowerTypes { op_map: HashMap, // Called after lowering typeargs; return None to use original OpDef param_ops: HashMap Option>>, - // TODO should probably have a map, or two, here - from CustomType and from ParametricType. - // Whereupon the closure should be given a callback to self.change_value, too, in case of nested - // values for collections. - const_fn: Arc Option>, + consts: HashMap, Arc Option>>, check_sig: bool, } -impl Default for LowerTypes { - fn default() -> Self { - Self { - type_map: Default::default(), - param_types: Default::default(), - op_map: Default::default(), - param_ops: Default::default(), - const_fn: Arc::new(|_| None), - check_sig: false, - } - } -} - impl TypeTransformer for LowerTypes { type Err = ChangeTypeError; @@ -183,6 +168,24 @@ impl LowerTypes { self.param_ops.insert(src.into(), Arc::from(dest_fn)); } + pub fn lower_consts( + &mut self, + src_ty: &CustomType, + const_fn: Box Option>, + ) { + self.consts + .insert(Either::Left(src_ty.clone()), Arc::from(const_fn)); + } + + pub fn lower_consts_parametric( + &mut self, + src_ty: &TypeDef, + const_fn: Box Option>, + ) { + self.consts + .insert(Either::Right(src_ty.into()), Arc::from(const_fn)); + } + pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { @@ -305,14 +308,21 @@ impl LowerTypes { any_change |= sum_type.transform(self)?; Ok(any_change) } - Value::Extension { e } => { - if let Some(new_const) = (self.const_fn)(e.value()) { - *value = new_const; - Ok(true) - } else { - Ok(false) + Value::Extension { e } => Ok('changed: { + if let TypeEnum::Extension(exty) = e.get_type().as_type_enum() { + if let Some(const_fn) = self + .consts + .get(&Either::Left(exty.clone())) + .or(self.consts.get(&Either::Right(exty.into()))) + { + if let Some(new_const) = const_fn(e) { + *value = new_const; + break 'changed true; + } + } } - } + false + }), Value::Function { hugr } => self.run_no_validate(&mut **hugr), } } @@ -324,16 +334,15 @@ mod test { use hugr_core::builder::{ Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - ModuleBuilder, SubContainer, + ModuleBuilder, SubContainer, TailLoopBuilder, }; - use hugr_core::extension::prelude::{bool_t, option_type, UnwrapBuilder}; + use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; use hugr_core::extension::{TypeDefBound, Version}; - use hugr_core::ops::{ExtensionOp, OpType, Tag}; - use hugr_core::std_extensions::{ - arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, - collections::array::{array_type, ArrayOpDef}, - }; - use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; + use hugr_core::ops::{ExtensionOp, OpType, Tag, Value}; + use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef, ArrayValue}; + use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use super::{LowerTypes, OpReplacement}; @@ -438,7 +447,7 @@ mod test { let coln = ext.get_type("PackedVec").unwrap(); let read = ext.get_op("read").unwrap(); let i64 = || INT_TYPES[6].to_owned(); - let c_int = Type::from(coln.instantiate([INT_TYPES[6].to_owned().into()]).unwrap()); + let c_int = Type::from(coln.instantiate([i64().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); let fb = mb @@ -525,4 +534,63 @@ mod test { ]) ); } + + #[test] + fn loop_const() { + let cu = |u| ConstUsize::new(u).into(); + let mut tl = TailLoopBuilder::new( + list_type(usize_t()), + list_type(bool_t()), + list_type(usize_t()), + ) + .unwrap(); + let [_, bools] = tl.input_wires_arr(); + let st = SumType::new(vec![list_type(usize_t()); 2]); + let pred = tl.add_load_value( + Value::sum( + 0, + [ListValue::new(usize_t(), [cu(1), cu(3), cu(3), cu(7)]).into()], + st, + ) + .unwrap(), + ); + tl.set_outputs(pred, [bools]).unwrap(); + let mut h = tl.finish_hugr().unwrap(); + + // Lower List to Array<4,T> so we can use List's handy CustomConst + let mut lowerer = LowerTypes::default(); + lowerer.lower_parametric_type( + list_type_def(), + Box::new(|args: &[TypeArg]| { + let [TypeArg::Type { ty }] = args else { + panic!("Expected elem type") + }; + array_type(4, ty.clone()) + }), + ); + lowerer.lower_consts_parametric( + list_type_def(), + Box::new(|opaq| { + let lv = opaq + .value() + .downcast_ref::() + .expect("Only one constant in test"); + Some( + ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()) + .into(), + ) + }), + ); + lowerer.run_no_validate(&mut h).unwrap(); + h.validate().unwrap(); + assert_eq!( + h.get_optype(pred.node()) + .as_load_constant() + .map(|lc| lc.constant_type()), + Some(&Type::new_sum(vec![ + Type::from(array_type(4, usize_t())); + 2 + ])) + ); + } } From 9f02acf9e7164cc5730e8eefd7b35fdb13099fba Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 18 Mar 2025 21:44:09 +0000 Subject: [PATCH 029/123] clippy, turn off type-complexity --- hugr-passes/src/lower_types.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 4f6bd8cc7d..24055ec304 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,6 +1,6 @@ +#![allow(clippy::type_complexity)] use std::borrow::Cow; use std::collections::HashMap; -use std::ops::Deref; use std::sync::Arc; use itertools::Either; @@ -203,7 +203,7 @@ impl LowerTypes { // (If check_sig) then verify that the Signature still has the same arity/wires, // with only the expected changes to types within. if let Some(expected_sig) = expected_dfsig { - assert_eq!(new_dfsig.as_ref().map(Cow::deref), expected_sig.as_ref()); + assert_eq!(new_dfsig.as_deref(), expected_sig.as_ref()); } } Ok(changed) @@ -232,7 +232,7 @@ impl LowerTypes { let change = func_sig.body_mut().transform(self)? | type_args.transform(self)?; if change { let new_inst = func_sig - .instantiate(&type_args) + .instantiate(type_args) .map_err(ChangeTypeError::SignatureError)?; *instantiation = new_inst; } From a0ac6d6f7c03f79ac1c44c6346fb8d3117053de7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 10:02:27 +0000 Subject: [PATCH 030/123] Actual Error for check_sig, add setter method --- hugr-passes/src/lower_types.rs | 58 ++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 24055ec304..a7ab2ef501 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -14,7 +14,9 @@ use hugr_core::ops::{ FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; +use hugr_core::types::{ + CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, +}; use hugr_core::{Hugr, Node}; #[derive(Clone, Hash, PartialEq, Eq)] @@ -125,11 +127,18 @@ impl TypeTransformer for LowerTypes { } } -#[derive(Clone, Debug, Error, PartialEq, Eq)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] pub enum ChangeTypeError { #[error(transparent)] SignatureError(#[from] SignatureError), + #[error("Lowering op {op} with original signature {old:?}\nExpected signature: {expected:?}\nBut got: {actual:?}")] + SignatureMismatch { + op: OpType, + old: Option, + expected: Option, + actual: Option, + }, } impl LowerTypes { @@ -186,15 +195,29 @@ impl LowerTypes { .insert(Either::Right(src_ty.into()), Arc::from(const_fn)); } + /// Configures this instance to check signatures of ops lowered following [Self::lower_op] + /// and [Self::lower_parametric_op] are as expected, i.e. match the signatures of the + /// original op modulo the required type substitutions. (If signatures are incorrect, + /// it is likely that the wires in the Hugr will be invalid, so this gives an early warning + /// by instead raising [ChangeTypeError::SignatureMismatch].) + pub fn check_signatures(&mut self, check_sig: bool) { + self.check_sig = check_sig; + } + pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { - let expected_dfsig = if self.check_sig { - let mut dfsig = hugr.get_optype(n).dataflow_signature().map(Cow::into_owned); - if let Some(sig) = dfsig.as_mut() { - sig.transform(self)?; - } - Some(dfsig) + let maybe_check_sig = if self.check_sig { + Some( + if let Some(old_sig) = hugr.get_optype(n).dataflow_signature() { + let old_sig = old_sig.into_owned(); + let mut expected_sig = old_sig.clone(); + expected_sig.transform(self)?; + (Some(old_sig), Some(expected_sig)) + } else { + (None, None) + }, + ) } else { None }; @@ -202,8 +225,23 @@ impl LowerTypes { let new_dfsig = hugr.get_optype(n).dataflow_signature(); // (If check_sig) then verify that the Signature still has the same arity/wires, // with only the expected changes to types within. - if let Some(expected_sig) = expected_dfsig { - assert_eq!(new_dfsig.as_deref(), expected_sig.as_ref()); + if let Some((old, expected)) = maybe_check_sig { + match (&expected, &new_dfsig) { + (None, None) => (), + (Some(exp), Some(act)) + if exp.input == act.input && exp.output == act.output => + { + () + } + _ => { + return Err(ChangeTypeError::SignatureMismatch { + op: hugr.get_optype(n).clone(), + old, + expected, + actual: new_dfsig.map(Cow::into_owned), + }) + } + }; } } Ok(changed) From 6ac9efca059a9b1798a865c2d4c6fe19217e054a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 10:05:40 +0000 Subject: [PATCH 031/123] docs --- hugr-passes/src/lower_types.rs | 40 ++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index a7ab2ef501..8f432cb910 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -66,10 +66,11 @@ pub enum OpReplacement { SingleOp(OpType), /// Defines a sub-Hugr to splice in place of the op. /// Note this will be of limited use before [monomorphization](super::monomorphize) because - /// the sub-Hugrwill not be able to use type variables present in the op. + /// the sub-Hugr will not be able to use type variables present in the op. // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), + // TODO add Call to...a Node in the existing Hugr (?) } impl OpReplacement { @@ -144,18 +145,24 @@ pub enum ChangeTypeError { impl LowerTypes { /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* Type, this should only - /// be used on *[monomorphize](super::monomorphize)d* Hugrs, because substitution + /// be used on already-*[monomorphize](super::monomorphize)d* Hugrs, as substitution /// (parametric polymorphism) happening later will not respect the lowering(s). + /// + /// This takes precedence over [Self::lower_parametric_type] where the `src`s overlap. pub fn lower_type(&mut self, src: CustomType, dest: Type) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't // check that for parametric types, we'll be consistent and not check here either. self.type_map.insert(src, dest); } + /// Configures this instance to change occurrences of a parametrized type `src` + /// via a callback that builds the replacement type given the [TypeArg]s. + /// Note that the TypeArgs will already have been lowered (e.g. they may not + /// fit the bounds of the original type). pub fn lower_parametric_type( &mut self, src: &TypeDef, - dest_fn: Box Type>, + dest_fn: Box Type>, // TODO should we return Option ? ) { // No way to check that dest_fn never produces a linear type. // We could require copy/discard-generators if src is Copyable, or *might be* @@ -165,27 +172,46 @@ impl LowerTypes { self.param_types.insert(src.into(), Arc::from(dest_fn)); } - pub fn lower_op(&mut self, src: &ExtensionOp, tgt: OpReplacement) { - self.op_map.insert(OpHashWrapper::from(src), tgt); + /// Configures this instance to change occurrences of `src` to `dest`. + /// Note that if `src` is an instance of a *parametrized* [OpDef], this should only + /// be used on already-*[monomorphize](super::monomorphize)d* Hugrs, as substitution + /// (parametric polymorphism) happening later will not respect the lowering(s). + /// + /// This takes precedence over [Self::lower_parametric_op] where the `src`s overlap. + pub fn lower_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { + self.op_map.insert(OpHashWrapper::from(src), dest); } + /// Configures this instance to change occurrences of a parametrized op `src` + /// via a callback that builds the replacement type given the [TypeArg]s. + /// Note that the TypeArgs will already have been lowered (e.g. they may not + /// fit the bounds of the original op). + /// + /// If the Callback returns None, the new typeargs will be applied to the original op. pub fn lower_parametric_op( &mut self, src: &OpDef, - dest_fn: Box Option>, + dest_fn: Box Option>, // TODO or just OpReplacement? ) { self.param_ops.insert(src.into(), Arc::from(dest_fn)); } + /// Configures this instance to change occurrences consts of type `src_ty`, using + /// a callback given the value of the constant (of that type). + /// Note that if `src_ty` is an instance of a *parametrized* [TypeDef], this + /// takes precedence over [Self::lower_consts_parametric] where the `src_ty`s overlap. pub fn lower_consts( &mut self, src_ty: &CustomType, - const_fn: Box Option>, + const_fn: Box Option>, // TODO should we return Value?? ) { self.consts .insert(Either::Left(src_ty.clone()), Arc::from(const_fn)); } + /// Configures this instance to change occurrences consts of all types that + /// are instances of a parametric typedef `src_ty`, using a callback given + /// the value of the constant (the [OpaqueValue] contains the [TypeArg]s). pub fn lower_consts_parametric( &mut self, src_ty: &TypeDef, From 044ff32fa8d0cbc4296058aab81f0bce2744c56f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 11:53:36 +0000 Subject: [PATCH 032/123] Test variable, boundednat; use list_type --- hugr-core/src/types.rs | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index e46f485c40..bcf2cda07e 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -784,6 +784,8 @@ pub(crate) mod test { use super::*; use crate::extension::prelude::{qb_t, usize_t}; use crate::extension::TypeDefBound; + use crate::std_extensions::collections::array::{array_type, array_type_parametric}; + use crate::std_extensions::collections::list::list_type; use crate::types::type_param::TypeArgError; use crate::{hugr::IdentList, type_row, Extension}; @@ -854,33 +856,35 @@ pub(crate) mod test { #[test] fn transform() { const LIN: SmolStr = SmolStr::new_inline("MyLinear"); - const COLN: SmolStr = SmolStr::new_inline("ColnOfAny"); - let e = Extension::new_test_arc(IdentList::new("TestExt").unwrap(), |e, w| { e.add_type(LIN, vec![], String::new(), TypeDefBound::any(), w) .unwrap(); - e.add_type( - COLN, - vec![TypeBound::Any.into()], - String::new(), - TypeDefBound::from_params(vec![0]), - w, - ) - .unwrap(); }); let lin = e.get_type(&LIN).unwrap().instantiate([]).unwrap(); - let coln = e.get_type(&COLN).unwrap(); let lin_to_usize = FnTransformer(|ct: &CustomType| (*ct == lin).then_some(usize_t())); let mut t = Type::new_extension(lin.clone()); assert_eq!(t.transform(&lin_to_usize), Ok(true)); assert_eq!(t, usize_t()); - let mut t = - Type::new_extension(coln.instantiate([Type::from(lin.clone()).into()]).unwrap()); - assert_eq!(t.transform(&lin_to_usize), Ok(true)); - let expected = Type::new_extension(coln.instantiate([usize_t().into()]).unwrap()); - assert_eq!(t, expected); - assert_eq!(t.transform(&lin_to_usize), Ok(false)); + + for coln in [ + list_type, + |t| array_type(10, t), + |t| { + array_type_parametric( + TypeArg::new_var_use(0, TypeParam::bounded_nat(3.try_into().unwrap())), + t, + ) + .unwrap() + }, + ] { + let mut t = coln(lin.clone().into()); + assert_eq!(t.transform(&lin_to_usize), Ok(true)); + let expected = coln(usize_t()); + assert_eq!(t, expected); + assert_eq!(t.transform(&lin_to_usize), Ok(false)); + assert_eq!(t, expected); + } } #[test] From b539e2f37291b9c699501f3ef77afa73b2a6196d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 12:12:15 +0000 Subject: [PATCH 033/123] Yet Another ValidationLevel interface --- hugr-passes/src/lower_types.rs | 28 ++++++++++++++++++++++++---- hugr-passes/src/validation.rs | 2 +- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 8f432cb910..d74aa900de 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -19,6 +19,8 @@ use hugr_core::types::{ }; use hugr_core::{Hugr, Node}; +use crate::validation::{ValidatePassError, ValidationLevel}; + #[derive(Clone, Hash, PartialEq, Eq)] struct OpHashWrapper { ext_name: ExtensionId, @@ -106,6 +108,7 @@ pub struct LowerTypes { param_ops: HashMap Option>>, consts: HashMap, Arc Option>>, check_sig: bool, + validation: ValidationLevel, } impl TypeTransformer for LowerTypes { @@ -128,7 +131,7 @@ impl TypeTransformer for LowerTypes { } } -#[derive(Clone, Debug, Error, PartialEq)] +#[derive(Debug, Error, PartialEq)] #[non_exhaustive] pub enum ChangeTypeError { #[error(transparent)] @@ -140,9 +143,20 @@ pub enum ChangeTypeError { expected: Option, actual: Option, }, + #[error(transparent)] + #[allow(missing_docs)] + ValidationError(#[from] ValidatePassError), } impl LowerTypes { + /// Sets the validation level used before and after the pass is run. + // Note the self -> Self style is consistent with other passes, but not the other methods here. + // TODO change the others? But we are planning to drop validation_level in https://github.com/CQCL/hugr/pull/1895 + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* Type, this should only /// be used on already-*[monomorphize](super::monomorphize)d* Hugrs, as substitution @@ -230,7 +244,13 @@ impl LowerTypes { self.check_sig = check_sig; } - pub fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { + /// Run the pass using specified configuration. + pub fn run(&self, hugr: &mut H) -> Result { + self.validation + .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) + } + + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { let maybe_check_sig = if self.check_sig { @@ -581,7 +601,7 @@ mod test { fb.finish_with_outputs(cfg.outputs()).unwrap(); let mut h = mb.finish_hugr().unwrap(); - assert!(lower_types(&ext).run_no_validate(&mut h).unwrap()); + assert!(lower_types(&ext).run(&mut h).unwrap()); h.validate().unwrap(); let mut counts: HashMap<_, usize> = HashMap::new(); @@ -645,7 +665,7 @@ mod test { ) }), ); - lowerer.run_no_validate(&mut h).unwrap(); + lowerer.run(&mut h).unwrap(); h.validate().unwrap(); assert_eq!( h.get_optype(pred.node()) diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index baf3b86d83..5f53f403c7 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -23,7 +23,7 @@ pub enum ValidationLevel { WithExtensions, } -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] From bd24d1a631d82f7aacfa9ec990ca51e155c95fb7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 12:22:48 +0000 Subject: [PATCH 034/123] clippy --- hugr-passes/src/lower_types.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index d74aa900de..2189fe130a 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -275,10 +275,7 @@ impl LowerTypes { match (&expected, &new_dfsig) { (None, None) => (), (Some(exp), Some(act)) - if exp.input == act.input && exp.output == act.output => - { - () - } + if exp.input == act.input && exp.output == act.output => {} _ => { return Err(ChangeTypeError::SignatureMismatch { op: hugr.get_optype(n).clone(), From a5d8b65ff43a8d4e294b5d329abd582aac149d40 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 12:53:50 +0000 Subject: [PATCH 035/123] doclinxs --- hugr-core/src/hugr/internal.rs | 4 ++-- hugr-passes/src/lower_types.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 77e46a4a35..db7095c2c7 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -315,8 +315,8 @@ pub trait HugrMutInternals: RootTagged { /// Changing this may invalidate the ports, which may need to be resized to /// match the OpType signature. /// - /// Will panic for the root node unless [Self::RootHandle] is [OpTag::Any], - /// as mutation could invalidate the bound. + /// Will panic for the root node unless [`Self::RootHandle`](RootTagged::RootHandle) + /// is [OpTag::Any], as mutation could invalidate the bound. fn optype_mut(&mut self, node: Node) -> &mut OpType { if Self::RootHandle::TAG.is_superset(OpTag::Any) { panic_invalid_node(self, node); diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 2189fe130a..85231795a4 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -67,7 +67,7 @@ impl From<&OpDef> for ParametricOp { pub enum OpReplacement { SingleOp(OpType), /// Defines a sub-Hugr to splice in place of the op. - /// Note this will be of limited use before [monomorphization](super::monomorphize) because + /// Note this will be of limited use before [monomorphization](super::monomorphize()) because /// the sub-Hugr will not be able to use type variables present in the op. // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 @@ -159,7 +159,7 @@ impl LowerTypes { /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* Type, this should only - /// be used on already-*[monomorphize](super::monomorphize)d* Hugrs, as substitution + /// be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as substitution /// (parametric polymorphism) happening later will not respect the lowering(s). /// /// This takes precedence over [Self::lower_parametric_type] where the `src`s overlap. @@ -188,7 +188,7 @@ impl LowerTypes { /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [OpDef], this should only - /// be used on already-*[monomorphize](super::monomorphize)d* Hugrs, as substitution + /// be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as substitution /// (parametric polymorphism) happening later will not respect the lowering(s). /// /// This takes precedence over [Self::lower_parametric_op] where the `src`s overlap. From 6b1438c4e82ece95bc219a189df780dae27df7ff Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 13:17:29 +0000 Subject: [PATCH 036/123] pub re-export --- hugr-passes/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index d5d765e367..e7b643bcc9 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -29,6 +29,7 @@ pub use monomorphize::remove_polyfuncs; pub use monomorphize::monomorphize; pub use monomorphize::{MonomorphizeError, MonomorphizePass}; pub mod lower_types; +pub use lower_types::LowerTypes; pub mod nest_cfgs; pub mod non_local; pub mod validation; From d1036bc2f3b6d5f16448c43a0305e3818bd33d8c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 13:26:41 +0000 Subject: [PATCH 037/123] fix const_loop for extensions --- hugr-passes/src/lower_types.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 85231795a4..14924e47da 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -419,9 +419,10 @@ mod test { }; use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; use hugr_core::extension::{TypeDefBound, Version}; - use hugr_core::ops::{ExtensionOp, OpType, Tag, Value}; + use hugr_core::hugr::internal::HugrMutInternals; + use hugr_core::ops::{ExtensionOp, OpType, Tag, TailLoop, Value}; use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; - use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef, ArrayValue}; + use hugr_core::std_extensions::collections::array::{self, array_type, ArrayOpDef, ArrayValue}; use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; @@ -663,6 +664,14 @@ mod test { }), ); lowerer.run(&mut h).unwrap(); + if cfg!(feature = "extension_inference") { + match h.optype_mut(h.root()) { + OpType::TailLoop(TailLoop { + extension_delta, .. + }) => extension_delta.insert(array::EXTENSION_ID), + _ => panic!(), + } + } h.validate().unwrap(); assert_eq!( h.get_optype(pred.node()) From d19dc5a77047fc66a0511a6a68c257d86dc96378 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 19 Mar 2025 14:26:06 +0000 Subject: [PATCH 038/123] fix other test for extensions, but turn off extension validation after lowering --- hugr-passes/src/lower_types.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 14924e47da..9044aa03a6 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -414,13 +414,14 @@ mod test { use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - ModuleBuilder, SubContainer, TailLoopBuilder, + inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; use hugr_core::extension::{TypeDefBound, Version}; use hugr_core::hugr::internal::HugrMutInternals; use hugr_core::ops::{ExtensionOp, OpType, Tag, TailLoop, Value}; + use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayOpDef, ArrayValue}; use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; @@ -474,7 +475,7 @@ mod test { let [TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs") }; - let mut dfb = DFGBuilder::new(Signature::new( + let mut dfb = DFGBuilder::new(inout_sig( vec![array_type(64, ty.clone()), INT_TYPES[6].to_owned()], ty.clone(), )) @@ -546,15 +547,13 @@ mod test { let mut fb = mb .define_function( "main", - Signature::new(vec![i64(), c_int.clone(), c_bool.clone()], bool_t()), + Signature::new(vec![i64(), c_int.clone(), c_bool.clone()], bool_t()) + .with_extension_delta(ext.name.clone()), ) .unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let mut dfb = fb - .dfg_builder( - Signature::new(vec![i64(), c_int.clone()], i64()), - [idx, indices], - ) + .dfg_builder(inout_sig(vec![i64(), c_int.clone()], i64()), [idx, indices]) .unwrap(); let [idx, indices] = dfb.input_wires_arr(); let [indices] = dfb @@ -600,7 +599,8 @@ mod test { let mut h = mb.finish_hugr().unwrap(); assert!(lower_types(&ext).run(&mut h).unwrap()); - h.validate().unwrap(); + // We do not update the extension delta + h.validate_no_extensions().unwrap(); let mut counts: HashMap<_, usize> = HashMap::new(); for ext_op in h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()) { From d0fddde4bb8dd2f1202b0d6b7061fe1c8b704531 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 12:00:55 +0000 Subject: [PATCH 039/123] Add another test of Conditional + Case --- hugr-passes/src/lower_types.rs | 112 +++++++++++++++++++++++++-------- 1 file changed, 85 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 9044aa03a6..b8c2b9a8a2 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -418,18 +418,27 @@ mod test { HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; + use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{TypeDefBound, Version}; use hugr_core::hugr::internal::HugrMutInternals; use hugr_core::ops::{ExtensionOp, OpType, Tag, TailLoop, Value}; use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; - use hugr_core::std_extensions::collections::array::{self, array_type, ArrayOpDef, ArrayValue}; + use hugr_core::std_extensions::collections::array::{ + self, array_type, ArrayOp, ArrayOpDef, ArrayValue, + }; use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use itertools::Itertools; use super::{LowerTypes, OpReplacement}; + const PACKED_VEC: &str = "PackedVec"; + fn i64_t() -> Type { + INT_TYPES[6].clone() + } + fn ext() -> Arc { Extension::new_arc( IdentList::new("TestExt").unwrap(), @@ -437,7 +446,7 @@ mod test { |ext, w| { let pv_of_var = ext .add_type( - "PackedVec".into(), + PACKED_VEC.into(), vec![TypeBound::Any.into()], String::new(), TypeDefBound::from_params(vec![0]), @@ -452,7 +461,7 @@ mod test { PolyFuncType::new( vec![TypeBound::Copyable.into()], Signature::new( - vec![pv_of_var.into(), INT_TYPES[6].to_owned()], + vec![pv_of_var.into(), i64_t()], Type::new_var_use(0, TypeBound::Any), ), ), @@ -462,7 +471,7 @@ mod test { ext.add_op( "lowered_read_bool".into(), "".into(), - Signature::new(vec![INT_TYPES[6].to_owned(); 2], bool_t()), + Signature::new(vec![i64_t(); 2], bool_t()), w, ) .unwrap(); @@ -476,7 +485,7 @@ mod test { panic!("Illegal TypeArgs") }; let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), INT_TYPES[6].to_owned()], + vec![array_type(64, ty.clone()), i64_t()], ty.clone(), )) .unwrap(); @@ -496,13 +505,10 @@ mod test { dfb.finish_hugr_with_outputs([res]).unwrap(), ))) } - let pv = ext.get_type("PackedVec").unwrap(); + let pv = ext.get_type(PACKED_VEC).unwrap(); let read = ext.get_op("read").unwrap(); let mut lw = LowerTypes::default(); - lw.lower_type( - pv.instantiate([bool_t().into()]).unwrap(), - INT_TYPES[6].to_owned(), - ); + lw.lower_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); lw.lower_parametric_type( pv, Box::new(|args: &[TypeArg]| { @@ -525,12 +531,11 @@ mod test { } #[test] - fn module_func_dfg_cfg_call() { + fn module_func_cfg_call() { let ext = ext(); - let coln = ext.get_type("PackedVec").unwrap(); + let coln = ext.get_type(PACKED_VEC).unwrap(); let read = ext.get_op("read").unwrap(); - let i64 = || INT_TYPES[6].to_owned(); - let c_int = Type::from(coln.instantiate([i64().into()]).unwrap()); + let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); let fb = mb @@ -547,31 +552,24 @@ mod test { let mut fb = mb .define_function( "main", - Signature::new(vec![i64(), c_int.clone(), c_bool.clone()], bool_t()) + Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()) .with_extension_delta(ext.name.clone()), ) .unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); - let mut dfb = fb - .dfg_builder(inout_sig(vec![i64(), c_int.clone()], i64()), [idx, indices]) - .unwrap(); - let [idx, indices] = dfb.input_wires_arr(); - let [indices] = dfb + let [indices] = fb .call(id.handle(), &[c_int.into()], [indices]) .unwrap() .outputs_arr(); - let int_read_op = dfb + let int_read_op = fb .add_dataflow_op( - ExtensionOp::new(read.clone(), [i64().into()]).unwrap(), + ExtensionOp::new(read.clone(), [i64_t().into()]).unwrap(), [indices, idx], ) .unwrap(); - let [idx2] = dfb - .finish_with_outputs(int_read_op.outputs()) - .unwrap() - .outputs_arr(); + let [idx2] = int_read_op.outputs_arr(); let mut cfg = fb - .cfg_builder([(i64(), idx2), (c_bool.clone(), bools)], bool_t().into()) + .cfg_builder([(i64_t(), idx2), (c_bool.clone(), bools)], bool_t().into()) .unwrap(); let mut entry = cfg.entry_builder([bool_t().into()], type_row![]).unwrap(); let [idx2, bools] = entry.input_wires_arr(); @@ -617,6 +615,66 @@ mod test { ); } + #[test] + fn dfg_conditional_case() { + let ext = ext(); + let coln = ext.get_type(PACKED_VEC).unwrap(); + let read = ext.get_op("read").unwrap(); + let pv = |t: Type| Type::new_extension(coln.instantiate([t.into()]).unwrap()); + let sum_rows = [vec![pv(pv(bool_t())), i64_t()].into(), pv(i64_t()).into()]; + let mut dfb = DFGBuilder::new(inout_sig( + vec![Type::new_sum(sum_rows.clone()), pv(bool_t()), pv(i64_t())], + vec![pv(bool_t()), pv(i64_t())], + )) + .unwrap(); + let [sum, vb, vi] = dfb.input_wires_arr(); + let mut cb = dfb + .conditional_builder( + (sum_rows, sum), + [(pv(bool_t()), vb), (pv(i64_t()), vi)], + vec![pv(bool_t()), pv(i64_t())].into(), + ) + .unwrap(); + let mut case0 = cb.case_builder(0).unwrap(); + let [vvb, i, _, vi0] = case0.input_wires_arr(); + let [vb0] = case0 + .add_dataflow_op( + ExtensionOp::new(read.clone(), [pv(bool_t()).into()]).unwrap(), + [vvb, i], + ) + .unwrap() + .outputs_arr(); + case0.finish_with_outputs([vb0, vi0]).unwrap(); + + let case1 = cb.case_builder(1).unwrap(); + let [vi, vb1, _vi1] = case1.input_wires_arr(); + case1.finish_with_outputs([vb1, vi]).unwrap(); + let cond = cb.finish_sub_container().unwrap(); + let mut h = dfb.finish_hugr_with_outputs(cond.outputs()).unwrap(); + + lower_types(&ext).run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); + let ext_ops = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .collect_vec(); + assert_eq!( + ext_ops + .iter() + .map(|e| e.def().name()) + .sorted() + .collect_vec(), + ["get", "itousize", "panic"] + ); + // The PackedVec> becomes an array + let [array_get] = ext_ops + .into_iter() + .filter_map(|e| ArrayOp::from_extension_op(e).ok()) + .collect_array() + .unwrap(); + assert_eq!(array_get, ArrayOpDef::get.to_concrete(i64_t(), 64)); + } + #[test] fn loop_const() { let cu = |u| ConstUsize::new(u).into(); From 3494887e2e14e381ea46b31f6b1b03be9216447c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 12:10:03 +0000 Subject: [PATCH 040/123] common up read_op --- hugr-passes/src/lower_types.rs | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index b8c2b9a8a2..7928d08f2f 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -439,6 +439,10 @@ mod test { INT_TYPES[6].clone() } + fn read_op(ext: &Arc, t: Type) -> ExtensionOp { + ExtensionOp::new(ext.get_op("read").unwrap().clone(), [t.into()]).unwrap() + } + fn ext() -> Arc { Extension::new_arc( IdentList::new("TestExt").unwrap(), @@ -479,7 +483,7 @@ mod test { ) } - fn lower_types(ext: &Extension) -> LowerTypes { + fn lower_types(ext: &Arc) -> LowerTypes { fn lowered_read(args: &[TypeArg]) -> Option { let [TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs") @@ -506,7 +510,6 @@ mod test { ))) } let pv = ext.get_type(PACKED_VEC).unwrap(); - let read = ext.get_op("read").unwrap(); let mut lw = LowerTypes::default(); lw.lower_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); lw.lower_parametric_type( @@ -519,14 +522,14 @@ mod test { }), ); lw.lower_op( - &ExtensionOp::new(read.clone(), [bool_t().into()]).unwrap(), + &read_op(ext, bool_t()), OpReplacement::SingleOp( ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), []) .unwrap() .into(), ), ); - lw.lower_parametric_op(read.as_ref(), Box::new(lowered_read)); + lw.lower_parametric_op(ext.get_op("read").unwrap().as_ref(), Box::new(lowered_read)); lw } @@ -534,7 +537,6 @@ mod test { fn module_func_cfg_call() { let ext = ext(); let coln = ext.get_type(PACKED_VEC).unwrap(); - let read = ext.get_op("read").unwrap(); let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); @@ -562,10 +564,7 @@ mod test { .unwrap() .outputs_arr(); let int_read_op = fb - .add_dataflow_op( - ExtensionOp::new(read.clone(), [i64_t().into()]).unwrap(), - [indices, idx], - ) + .add_dataflow_op(read_op(&ext, i64_t()), [indices, idx]) .unwrap(); let [idx2] = int_read_op.outputs_arr(); let mut cfg = fb @@ -578,10 +577,7 @@ mod test { .unwrap() .outputs_arr(); let bool_read_op = entry - .add_dataflow_op( - ExtensionOp::new(read.clone(), [bool_t().into()]).unwrap(), - [bools, idx2], - ) + .add_dataflow_op(read_op(&ext, bool_t()), [bools, idx2]) .unwrap(); let [tagged] = entry .add_dataflow_op( @@ -619,7 +615,6 @@ mod test { fn dfg_conditional_case() { let ext = ext(); let coln = ext.get_type(PACKED_VEC).unwrap(); - let read = ext.get_op("read").unwrap(); let pv = |t: Type| Type::new_extension(coln.instantiate([t.into()]).unwrap()); let sum_rows = [vec![pv(pv(bool_t())), i64_t()].into(), pv(i64_t()).into()]; let mut dfb = DFGBuilder::new(inout_sig( @@ -638,10 +633,7 @@ mod test { let mut case0 = cb.case_builder(0).unwrap(); let [vvb, i, _, vi0] = case0.input_wires_arr(); let [vb0] = case0 - .add_dataflow_op( - ExtensionOp::new(read.clone(), [pv(bool_t()).into()]).unwrap(), - [vvb, i], - ) + .add_dataflow_op(read_op(&ext, pv(bool_t())), [vvb, i]) .unwrap() .outputs_arr(); case0.finish_with_outputs([vb0, vi0]).unwrap(); From ffdcaf28e060fed135e6a8096672f5eb7b3f206b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 12:22:09 +0000 Subject: [PATCH 041/123] test tidies --- hugr-passes/src/lower_types.rs | 42 +++++++++++----------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 7928d08f2f..3f74755b90 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -411,7 +411,7 @@ impl LowerTypes { #[cfg(test)] mod test { - use std::{collections::HashMap, sync::Arc}; + use std::sync::Arc; use hugr_core::builder::{ inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, @@ -540,33 +540,25 @@ mod test { let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap()); let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); let mut mb = ModuleBuilder::new(); + let sig = Signature::new_endo(Type::new_var_use(0, TypeBound::Any)); let fb = mb - .define_function( - "id", - PolyFuncType::new( - [TypeBound::Any.into()], - Signature::new_endo(Type::new_var_use(0, TypeBound::Any)), - ), - ) + .define_function("id", PolyFuncType::new([TypeBound::Any.into()], sig)) .unwrap(); let inps = fb.input_wires(); let id = fb.finish_with_outputs(inps).unwrap(); - let mut fb = mb - .define_function( - "main", - Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()) - .with_extension_delta(ext.name.clone()), - ) - .unwrap(); + + let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()) + .with_extension_delta(ext.name.clone()); + let mut fb = mb.define_function("main", sig).unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let [indices] = fb .call(id.handle(), &[c_int.into()], [indices]) .unwrap() .outputs_arr(); - let int_read_op = fb + let [idx2] = fb .add_dataflow_op(read_op(&ext, i64_t()), [indices, idx]) - .unwrap(); - let [idx2] = int_read_op.outputs_arr(); + .unwrap() + .outputs_arr(); let mut cfg = fb .cfg_builder([(i64_t(), idx2), (c_bool.clone(), bools)], bool_t().into()) .unwrap(); @@ -596,18 +588,10 @@ mod test { // We do not update the extension delta h.validate_no_extensions().unwrap(); - let mut counts: HashMap<_, usize> = HashMap::new(); - for ext_op in h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()) { - *(counts.entry(ext_op.def().name().as_str()).or_default()) += 1; - } + let ext_ops = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); assert_eq!( - counts, - HashMap::from([ - ("lowered_read_bool", 1), - ("itousize", 1), - ("get", 1), - ("panic", 1) - ]) + ext_ops.map(|e| e.def().name()).sorted().collect_vec(), + ["get", "itousize", "lowered_read_bool", "panic",] ); } From 6f8f43c31646670e86231434336fb82c034650ab Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 15:57:00 +0000 Subject: [PATCH 042/123] No need to validate, run() does it for us --- hugr-passes/src/lower_types.rs | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 3f74755b90..406f34632c 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -420,12 +420,12 @@ mod test { use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{TypeDefBound, Version}; - use hugr_core::hugr::internal::HugrMutInternals; - use hugr_core::ops::{ExtensionOp, OpType, Tag, TailLoop, Value}; + + use hugr_core::ops::{ExtensionOp, OpType, Tag, Value}; use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; use hugr_core::std_extensions::collections::array::{ - self, array_type, ArrayOp, ArrayOpDef, ArrayValue, + array_type, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound}; @@ -585,8 +585,6 @@ mod test { let mut h = mb.finish_hugr().unwrap(); assert!(lower_types(&ext).run(&mut h).unwrap()); - // We do not update the extension delta - h.validate_no_extensions().unwrap(); let ext_ops = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); assert_eq!( @@ -629,7 +627,7 @@ mod test { let mut h = dfb.finish_hugr_with_outputs(cond.outputs()).unwrap(); lower_types(&ext).run(&mut h).unwrap(); - h.validate_no_extensions().unwrap(); + let ext_ops = h .nodes() .filter_map(|n| h.get_optype(n).as_extension_op()) @@ -698,15 +696,7 @@ mod test { }), ); lowerer.run(&mut h).unwrap(); - if cfg!(feature = "extension_inference") { - match h.optype_mut(h.root()) { - OpType::TailLoop(TailLoop { - extension_delta, .. - }) => extension_delta.insert(array::EXTENSION_ID), - _ => panic!(), - } - } - h.validate().unwrap(); + assert_eq!( h.get_optype(pred.node()) .as_load_constant() From e3da25998e941f09ada05300877e309d2e1783f9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 16:05:03 +0000 Subject: [PATCH 043/123] check_sig: use Option::unzip to tuple-ize --- hugr-passes/src/lower_types.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 406f34632c..0d5df4fb22 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -259,9 +259,9 @@ impl LowerTypes { let old_sig = old_sig.into_owned(); let mut expected_sig = old_sig.clone(); expected_sig.transform(self)?; - (Some(old_sig), Some(expected_sig)) + Some((old_sig, expected_sig)) } else { - (None, None) + None }, ) } else { @@ -271,18 +271,19 @@ impl LowerTypes { let new_dfsig = hugr.get_optype(n).dataflow_signature(); // (If check_sig) then verify that the Signature still has the same arity/wires, // with only the expected changes to types within. - if let Some((old, expected)) = maybe_check_sig { - match (&expected, &new_dfsig) { + if let Some(old_and_expected) = maybe_check_sig { + match (&old_and_expected, &new_dfsig) { (None, None) => (), - (Some(exp), Some(act)) + (Some((_, exp)), Some(act)) if exp.input == act.input && exp.output == act.output => {} _ => { + let (old, expected) = old_and_expected.unzip(); return Err(ChangeTypeError::SignatureMismatch { op: hugr.get_optype(n).clone(), old, expected, actual: new_dfsig.map(Cow::into_owned), - }) + }); } }; } From 28f24701448433cbcdf4f863d887e68f7bdceb8d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 12:55:34 +0000 Subject: [PATCH 044/123] Move private utility classes below the pub ones; add comments --- hugr-passes/src/lower_types.rs | 97 ++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 0d5df4fb22..c69e271add 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -21,58 +21,23 @@ use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; -#[derive(Clone, Hash, PartialEq, Eq)] -struct OpHashWrapper { - ext_name: ExtensionId, - op_name: String, // Only because SmolStr not in hugr-passes yet - args: Vec, -} - -impl From<&ExtensionOp> for OpHashWrapper { - fn from(op: &ExtensionOp) -> Self { - Self { - ext_name: op.def().extension_id().clone(), - op_name: op.def().name().to_string(), - args: op.args().to_vec(), - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -struct ParametricType(ExtensionId, String); - -impl From<&TypeDef> for ParametricType { - fn from(value: &TypeDef) -> Self { - Self(value.extension_id().clone(), value.name().to_string()) - } -} - -impl From<&CustomType> for ParametricType { - fn from(value: &CustomType) -> Self { - Self(value.extension().clone(), value.name().to_string()) - } -} - -// Separate from above for clarity -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -struct ParametricOp(ExtensionId, String); - -impl From<&OpDef> for ParametricOp { - fn from(value: &OpDef) -> Self { - Self(value.extension_id().clone(), value.name().to_string()) - } -} - +/// A thing to which an Op can be lowered, i.e. with which a node can be replaced. #[derive(Clone, Debug, PartialEq)] pub enum OpReplacement { + /// Keep the same node (inputs/outputs, modulo lowering of types therein), change only the op SingleOp(OpType), - /// Defines a sub-Hugr to splice in place of the op. + /// Defines a sub-Hugr to splice in place of the op - a [CFG](OpType::CFG), + /// [Conditional](OpType::Conditional) or [DFG](OpType::DFG), which must have + /// the same (lowered) inputs and outputs as the original op. + // Not a FuncDefn, nor Case/DataflowBlock /// Note this will be of limited use before [monomorphization](super::monomorphize()) because /// the sub-Hugr will not be able to use type variables present in the op. // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO add Call to...a Node in the existing Hugr (?) + // TODO allow also Call to a Node in the existing Hugr + // (can't see any other way to achieve multiple calls to the same decl. + // So client should add the functions before lowering, then remove unused ones afterwards.) } impl OpReplacement { @@ -182,7 +147,7 @@ impl LowerTypes { // We could require copy/discard-generators if src is Copyable, or *might be* // (depending on arguments - i.e. if src's TypeDefBound is anything other than // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying - // overapproximation. + // overapproximation. Moreover, these depend upon the *return type* of the Fn. self.param_types.insert(src.into(), Arc::from(dest_fn)); } @@ -410,6 +375,48 @@ impl LowerTypes { } } +#[derive(Clone, Hash, PartialEq, Eq)] +struct OpHashWrapper { + ext_name: ExtensionId, + op_name: String, // Only because SmolStr not in hugr-passes yet + args: Vec, +} + +impl From<&ExtensionOp> for OpHashWrapper { + fn from(op: &ExtensionOp) -> Self { + Self { + ext_name: op.def().extension_id().clone(), + op_name: op.def().name().to_string(), + args: op.args().to_vec(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct ParametricType(ExtensionId, String); + +impl From<&TypeDef> for ParametricType { + fn from(value: &TypeDef) -> Self { + Self(value.extension_id().clone(), value.name().to_string()) + } +} + +impl From<&CustomType> for ParametricType { + fn from(value: &CustomType) -> Self { + Self(value.extension().clone(), value.name().to_string()) + } +} + +// Separate from above for clarity +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct ParametricOp(ExtensionId, String); + +impl From<&OpDef> for ParametricOp { + fn from(value: &OpDef) -> Self { + Self(value.extension_id().clone(), value.name().to_string()) + } +} + #[cfg(test)] mod test { use std::sync::Arc; From bef90b047378397869c4a1847f718d8af13ef393 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Mar 2025 19:22:16 +0000 Subject: [PATCH 045/123] Comments - all callbacks return Option --- hugr-passes/src/lower_types.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index c69e271add..87a13bd271 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -66,7 +66,7 @@ pub struct LowerTypes { /// ArrayOfCopyables(T1). This would require an additional entry for that. type_map: HashMap, /// Parametric types are handled by a function which receives the lowered typeargs. - param_types: HashMap Type>>, + param_types: HashMap Option>>, // Handles simple cases Op1 -> Op2. op_map: HashMap, // Called after lowering typeargs; return None to use original OpDef @@ -89,7 +89,7 @@ impl TypeTransformer for LowerTypes { nargs .iter_mut() .try_for_each(|ta| ta.transform(self).map(|_ch| ()))?; - Some(dest_fn(&nargs)) + dest_fn(&nargs) } else { None }) @@ -141,7 +141,7 @@ impl LowerTypes { pub fn lower_parametric_type( &mut self, src: &TypeDef, - dest_fn: Box Type>, // TODO should we return Option ? + dest_fn: Box Option>, ) { // No way to check that dest_fn never produces a linear type. // We could require copy/discard-generators if src is Copyable, or *might be* @@ -170,19 +170,22 @@ impl LowerTypes { pub fn lower_parametric_op( &mut self, src: &OpDef, - dest_fn: Box Option>, // TODO or just OpReplacement? + dest_fn: Box Option>, ) { self.param_ops.insert(src.into(), Arc::from(dest_fn)); } /// Configures this instance to change occurrences consts of type `src_ty`, using - /// a callback given the value of the constant (of that type). + /// a callback given the value of the constant (of that type). (The callback may + /// return `None` to indicate nothing has changed; we assume `Some` means something + /// has changed when evaluating the `bool` result of [Self::run].) + /// /// Note that if `src_ty` is an instance of a *parametrized* [TypeDef], this /// takes precedence over [Self::lower_consts_parametric] where the `src_ty`s overlap. pub fn lower_consts( &mut self, src_ty: &CustomType, - const_fn: Box Option>, // TODO should we return Value?? + const_fn: Box Option>, ) { self.consts .insert(Either::Left(src_ty.clone()), Arc::from(const_fn)); @@ -526,7 +529,7 @@ mod test { let [TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs") }; - array_type(64, ty.clone()) + Some(array_type(64, ty.clone())) }), ); lw.lower_op( @@ -687,7 +690,7 @@ mod test { let [TypeArg::Type { ty }] = args else { panic!("Expected elem type") }; - array_type(4, ty.clone()) + Some(array_type(4, ty.clone())) }), ); lowerer.lower_consts_parametric( From 85df7ff684c28382e295b06b20d3e80e137f937d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 12:10:45 +0000 Subject: [PATCH 046/123] test: Rename lower_types to lowerer --- hugr-passes/src/lower_types.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 87a13bd271..4a695e1f46 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -494,7 +494,7 @@ mod test { ) } - fn lower_types(ext: &Arc) -> LowerTypes { + fn lowerer(ext: &Arc) -> LowerTypes { fn lowered_read(args: &[TypeArg]) -> Option { let [TypeArg::Type { ty }] = args else { panic!("Illegal TypeArgs") @@ -595,7 +595,7 @@ mod test { fb.finish_with_outputs(cfg.outputs()).unwrap(); let mut h = mb.finish_hugr().unwrap(); - assert!(lower_types(&ext).run(&mut h).unwrap()); + assert!(lowerer(&ext).run(&mut h).unwrap()); let ext_ops = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); assert_eq!( @@ -637,7 +637,7 @@ mod test { let cond = cb.finish_sub_container().unwrap(); let mut h = dfb.finish_hugr_with_outputs(cond.outputs()).unwrap(); - lower_types(&ext).run(&mut h).unwrap(); + lowerer(&ext).run(&mut h).unwrap(); let ext_ops = h .nodes() From 507f02787bd87d915cba261215ca2fa805817e7d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 12:11:04 +0000 Subject: [PATCH 047/123] Extend loop_const test --- hugr-passes/src/lower_types.rs | 39 ++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 4a695e1f46..535857c1fa 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -439,7 +439,7 @@ mod test { array_type, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; - use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound}; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; @@ -682,7 +682,42 @@ mod test { tl.set_outputs(pred, [bools]).unwrap(); let mut h = tl.finish_hugr().unwrap(); - // Lower List to Array<4,T> so we can use List's handy CustomConst + // 1. Lower List to Array<10, T> UNLESS T is usize_t() or bool_t - this should have no effect + let mut lowerer = LowerTypes::default(); + lowerer.lower_parametric_type( + list_type_def(), + Box::new(|args| { + let [TypeArg::Type { ty }] = args else { + panic!("Expected elem type") + }; + (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) + }), + ); + let backup = h.clone(); + assert!(!lowerer.run(&mut h).unwrap()); + assert_eq!(h, backup); + + //2. Lower List to Array<10, T> UNLESS T is usize_t() - this leaves the Const unchanged + let mut lowerer = LowerTypes::default(); + lowerer.lower_parametric_type( + list_type_def(), + Box::new(|args| { + let [TypeArg::Type { ty }] = args else { + panic!("Expected elem type") + }; + (usize_t() != *ty).then_some(array_type(10, ty.clone())) + }), + ); + assert!(lowerer.run(&mut h).unwrap()); + let sig = h.signature(h.root()).unwrap(); + assert_eq!( + sig.input(), + &TypeRow::from(vec![list_type(usize_t()), array_type(10, bool_t())]) + ); + assert_eq!(sig.input(), sig.output()); + + // 3. Lower all List to Array<4,T> so we can use List's handy CustomConst + let mut h = backup; let mut lowerer = LowerTypes::default(); lowerer.lower_parametric_type( list_type_def(), From bae0ebefc7b9711664569ba2d957fe8d50e6da33 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 12:20:11 +0000 Subject: [PATCH 048/123] Add copy/dup linearization stuff --- hugr-passes/src/lower_types.rs | 109 ++++++++++++++++++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 535857c1fa..489c9efdee 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -17,7 +17,7 @@ use hugr_core::ops::{ use hugr_core::types::{ CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, }; -use hugr_core::{Hugr, Node}; +use hugr_core::{Hugr, IncomingPort, Node, OutgoingPort}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -41,6 +41,13 @@ pub enum OpReplacement { } impl OpReplacement { + fn add(&self, hugr: &mut impl HugrMut, parent: Node) -> Node { + match self.clone() { + OpReplacement::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), + OpReplacement::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + } + } + fn replace(&self, hugr: &mut impl HugrMut, n: Node) { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { @@ -67,6 +74,22 @@ pub struct LowerTypes { type_map: HashMap, /// Parametric types are handled by a function which receives the lowered typeargs. param_types: HashMap Option>>, + // Keyed by lowered type, as only needed when there is an op outputting such + copy_discard: HashMap, + // Copy/discard of parametric types handled by a function that receives the new/lowered type. + // We do not allow linearization to "parametrized" non-extension types, at least not + // in one step. We could do that using a trait, but it seems enough of a corner case. + // Instead that can be achieved by *firstly* lowering to a custom linear type, with copy/dup + // inserted; *secondly* by lowering that to the desired non-extension linear type, + // including lowering of the copy/dup operations to...whatever. + copy_discard_parametric: HashMap< + ParametricType, + // TODO should pass &LowerTypes, or at least some way to call copy_op / discard_op, to these + ( + Arc OpReplacement>, + Arc OpReplacement>, + ), + >, // Handles simple cases Op1 -> Op2. op_map: HashMap, // Called after lowering typeargs; return None to use original OpDef @@ -151,6 +174,20 @@ impl LowerTypes { self.param_types.insert(src.into(), Arc::from(dest_fn)); } + pub fn linearize(&mut self, src: Type, copy: OpReplacement, discard: OpReplacement) { + self.copy_discard.insert(src, (copy, discard)); + } + + pub fn linearize_parametric( + &mut self, + src: TypeDef, + copy_fn: Box OpReplacement>, + discard_fn: Box OpReplacement>, + ) { + self.copy_discard_parametric + .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); + } + /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [OpDef], this should only /// be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as substitution @@ -255,10 +292,80 @@ impl LowerTypes { } }; } + let Some(new_sig) = changed.then_some(new_dfsig).flatten().map(Cow::into_owned) else { + continue; + }; + for outp in new_sig.output_ports() { + if new_sig.out_port_type(outp).unwrap().copyable() { + continue; + }; + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() == 1 { + continue; + }; + hugr.disconnect(n, outp); + let typ = new_sig.out_port_type(outp).unwrap(); + if targets.len() == 0 { + let discard = self + .discard_op(typ) + .expect("Don't know how to discard {typ:?}"); // TODO return error + + let disc = discard.add(hugr, hugr.get_parent(n).unwrap()); + hugr.connect(n, outp, disc, 0); + } else { + // TODO return error + let copy = self.copy_op(typ).expect("Don't know how to copy {typ:?}"); + self.do_copy_chain(hugr, n, outp, copy, &targets) + } + } } Ok(changed) } + fn do_copy_chain( + &self, + hugr: &mut impl HugrMut, + mut src_node: Node, + mut src_port: OutgoingPort, + copy: OpReplacement, + inps: &[(Node, IncomingPort)], + ) { + assert!(inps.len() > 1); + // Could sanity-check signature here? + for (tgt_node, tgt_port) in &inps[..inps.len() - 1] { + let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); + hugr.connect(src_node, src_port, n, 0); + hugr.connect(n, 0, *tgt_node, *tgt_port); + (src_node, src_port) = (n, 1.into()); + } + let (tgt_node, tgt_port) = inps.last().unwrap(); + hugr.connect(src_node, src_port, *tgt_node, *tgt_port) + } + + pub fn copy_op(&self, typ: &Type) -> Option { + if let Some((copy, _)) = self.copy_discard.get(typ) { + return Some(copy.clone()); + } + let TypeEnum::Extension(exty) = typ.as_type_enum() else { + return None; + }; + self.copy_discard_parametric + .get(&exty.into()) + .map(|(copy_fn, _)| copy_fn(exty.args())) + } + + pub fn discard_op(&self, typ: &Type) -> Option { + if let Some((_, discard)) = self.copy_discard.get(typ) { + return Some(discard.clone()); + } + let TypeEnum::Extension(exty) = typ.as_type_enum() else { + return None; + }; + self.copy_discard_parametric + .get(&exty.into()) + .map(|(_, discard_fn)| discard_fn(exty.args())) + } + fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) From 8fa79bd53ffbdc823628f3a04287105a027f4034 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 13:18:04 +0000 Subject: [PATCH 049/123] do_copy_chain => insert_copy_discard --- hugr-passes/src/lower_types.rs | 63 +++++++++++++++++----------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 489c9efdee..830133d2e5 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -296,50 +296,49 @@ impl LowerTypes { continue; }; for outp in new_sig.output_ports() { - if new_sig.out_port_type(outp).unwrap().copyable() { - continue; - }; - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() == 1 { - continue; - }; - hugr.disconnect(n, outp); - let typ = new_sig.out_port_type(outp).unwrap(); - if targets.len() == 0 { - let discard = self - .discard_op(typ) - .expect("Don't know how to discard {typ:?}"); // TODO return error - - let disc = discard.add(hugr, hugr.get_parent(n).unwrap()); - hugr.connect(n, outp, disc, 0); - } else { - // TODO return error - let copy = self.copy_op(typ).expect("Don't know how to copy {typ:?}"); - self.do_copy_chain(hugr, n, outp, copy, &targets) + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let typ = new_sig.out_port_type(outp).unwrap(); + self.insert_copy_discard(hugr, n, outp, typ, &targets) + } } } } Ok(changed) } - fn do_copy_chain( + fn insert_copy_discard( &self, hugr: &mut impl HugrMut, mut src_node: Node, mut src_port: OutgoingPort, - copy: OpReplacement, - inps: &[(Node, IncomingPort)], + typ: &Type, + targets: &[(Node, IncomingPort)], ) { - assert!(inps.len() > 1); - // Could sanity-check signature here? - for (tgt_node, tgt_port) in &inps[..inps.len() - 1] { - let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); - hugr.connect(src_node, src_port, n, 0); - hugr.connect(n, 0, *tgt_node, *tgt_port); - (src_node, src_port) = (n, 1.into()); + let (last_node, last_inport) = match targets.last() { + None => { + let discard = self + .discard_op(typ) + .expect("Don't know how to discard {typ:?}"); // TODO return error + + let disc = discard.add(hugr, hugr.get_parent(src_node).unwrap()); + (disc, 0.into()) + } + Some(last) => *last, + }; + if targets.len() > 1 { + let copy = self.copy_op(typ).expect("Don't know how copy {typ:?"); // TODO return error + // Could sanity-check signature here? + for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { + let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); + hugr.connect(src_node, src_port, n, 0); + hugr.connect(n, 0, *tgt_node, *tgt_port); + (src_node, src_port) = (n, 1.into()); + } } - let (tgt_node, tgt_port) = inps.last().unwrap(); - hugr.connect(src_node, src_port, *tgt_node, *tgt_port) + hugr.connect(src_node, src_port, last_node, last_inport); } pub fn copy_op(&self, typ: &Type) -> Option { From 6a186a54e49ec5d972f99786fb80fa060dbf3b30 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 12:21:31 +0000 Subject: [PATCH 050/123] Move insert_copy_discard into Linearizer --- hugr-passes/src/lower_types.rs | 85 ++------------------- hugr-passes/src/lower_types/linearize.rs | 94 ++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 77 deletions(-) create mode 100644 hugr-passes/src/lower_types/linearize.rs diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 830133d2e5..4763907d4f 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -17,10 +17,13 @@ use hugr_core::ops::{ use hugr_core::types::{ CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, }; -use hugr_core::{Hugr, IncomingPort, Node, OutgoingPort}; +use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; +mod linearize; +pub use linearize::Linearizer; + /// A thing to which an Op can be lowered, i.e. with which a node can be replaced. #[derive(Clone, Debug, PartialEq)] pub enum OpReplacement { @@ -74,22 +77,7 @@ pub struct LowerTypes { type_map: HashMap, /// Parametric types are handled by a function which receives the lowered typeargs. param_types: HashMap Option>>, - // Keyed by lowered type, as only needed when there is an op outputting such - copy_discard: HashMap, - // Copy/discard of parametric types handled by a function that receives the new/lowered type. - // We do not allow linearization to "parametrized" non-extension types, at least not - // in one step. We could do that using a trait, but it seems enough of a corner case. - // Instead that can be achieved by *firstly* lowering to a custom linear type, with copy/dup - // inserted; *secondly* by lowering that to the desired non-extension linear type, - // including lowering of the copy/dup operations to...whatever. - copy_discard_parametric: HashMap< - ParametricType, - // TODO should pass &LowerTypes, or at least some way to call copy_op / discard_op, to these - ( - Arc OpReplacement>, - Arc OpReplacement>, - ), - >, + linearize: Linearizer, // Handles simple cases Op1 -> Op2. op_map: HashMap, // Called after lowering typeargs; return None to use original OpDef @@ -175,7 +163,7 @@ impl LowerTypes { } pub fn linearize(&mut self, src: Type, copy: OpReplacement, discard: OpReplacement) { - self.copy_discard.insert(src, (copy, discard)); + self.linearize.register(src, copy, discard) } pub fn linearize_parametric( @@ -184,8 +172,7 @@ impl LowerTypes { copy_fn: Box OpReplacement>, discard_fn: Box OpReplacement>, ) { - self.copy_discard_parametric - .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); + self.linearize.register_parametric(src, copy_fn, discard_fn) } /// Configures this instance to change occurrences of `src` to `dest`. @@ -301,7 +288,7 @@ impl LowerTypes { if targets.len() != 1 { hugr.disconnect(n, outp); let typ = new_sig.out_port_type(outp).unwrap(); - self.insert_copy_discard(hugr, n, outp, typ, &targets) + self.linearize.insert_copy_discard(hugr, n, outp, typ, &targets) } } } @@ -309,62 +296,6 @@ impl LowerTypes { Ok(changed) } - fn insert_copy_discard( - &self, - hugr: &mut impl HugrMut, - mut src_node: Node, - mut src_port: OutgoingPort, - typ: &Type, - targets: &[(Node, IncomingPort)], - ) { - let (last_node, last_inport) = match targets.last() { - None => { - let discard = self - .discard_op(typ) - .expect("Don't know how to discard {typ:?}"); // TODO return error - - let disc = discard.add(hugr, hugr.get_parent(src_node).unwrap()); - (disc, 0.into()) - } - Some(last) => *last, - }; - if targets.len() > 1 { - let copy = self.copy_op(typ).expect("Don't know how copy {typ:?"); // TODO return error - // Could sanity-check signature here? - for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { - let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); - hugr.connect(src_node, src_port, n, 0); - hugr.connect(n, 0, *tgt_node, *tgt_port); - (src_node, src_port) = (n, 1.into()); - } - } - hugr.connect(src_node, src_port, last_node, last_inport); - } - - pub fn copy_op(&self, typ: &Type) -> Option { - if let Some((copy, _)) = self.copy_discard.get(typ) { - return Some(copy.clone()); - } - let TypeEnum::Extension(exty) = typ.as_type_enum() else { - return None; - }; - self.copy_discard_parametric - .get(&exty.into()) - .map(|(copy_fn, _)| copy_fn(exty.args())) - } - - pub fn discard_op(&self, typ: &Type) -> Option { - if let Some((_, discard)) = self.copy_discard.get(typ) { - return Some(discard.clone()); - } - let TypeEnum::Extension(exty) = typ.as_type_enum() else { - return None; - }; - self.copy_discard_parametric - .get(&exty.into()) - .map(|(_, discard_fn)| discard_fn(exty.args())) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs new file mode 100644 index 0000000000..6acc8dfd21 --- /dev/null +++ b/hugr-passes/src/lower_types/linearize.rs @@ -0,0 +1,94 @@ +use std::{collections::HashMap, sync::Arc}; + +use hugr_core::{extension::TypeDef, hugr::hugrmut::HugrMut, types::{Type, TypeArg, TypeEnum}, IncomingPort, Node, OutgoingPort}; + +use super::{OpReplacement, ParametricType}; + +#[derive(Clone, Default)] +pub struct Linearizer { + // Keyed by lowered type, as only needed when there is an op outputting such + copy_discard: HashMap, + // Copy/discard of parametric types handled by a function that receives the new/lowered type. + // We do not allow linearization to "parametrized" non-extension types, at least not + // in one step. We could do that using a trait, but it seems enough of a corner case. + // Instead that can be achieved by *firstly* lowering to a custom linear type, with copy/dup + // inserted; *secondly* by lowering that to the desired non-extension linear type, + // including lowering of the copy/dup operations to...whatever. + copy_discard_parametric: HashMap< + ParametricType, + // TODO should pass &LowerTypes, or at least some way to call copy_op / discard_op, to these + ( + Arc OpReplacement>, + Arc OpReplacement>, + ), + >, +} + +impl Linearizer { + pub fn register(&mut self, typ: Type, copy: OpReplacement, discard: OpReplacement) { + self.copy_discard.insert(typ, (copy, discard)); + } + + pub fn register_parametric(&mut self,src: TypeDef, + copy_fn: Box OpReplacement>, + discard_fn: Box OpReplacement>) { + self.copy_discard_parametric + .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); + } + + pub fn insert_copy_discard( + &self, + hugr: &mut impl HugrMut, + mut src_node: Node, + mut src_port: OutgoingPort, + typ: &Type, + targets: &[(Node, IncomingPort)], + ) { + let (last_node, last_inport) = match targets.last() { + None => { + let discard = self + .discard_op(typ) + .expect("Don't know how to discard {typ:?}"); // TODO return error + + let disc = discard.add(hugr, hugr.get_parent(src_node).unwrap()); + (disc, 0.into()) + } + Some(last) => *last, + }; + if targets.len() > 1 { + let copy = self.copy_op(typ).expect("Don't know how copy {typ:?"); // TODO return error + // Could sanity-check signature here? + for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { + let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); + hugr.connect(src_node, src_port, n, 0); + hugr.connect(n, 0, *tgt_node, *tgt_port); + (src_node, src_port) = (n, 1.into()); + } + } + hugr.connect(src_node, src_port, last_node, last_inport); + } + + fn copy_op(&self, typ: &Type) -> Option { + if let Some((copy, _)) = self.copy_discard.get(typ) { + return Some(copy.clone()); + } + let TypeEnum::Extension(exty) = typ.as_type_enum() else { + return None; + }; + self.copy_discard_parametric + .get(&exty.into()) + .map(|(copy_fn, _)| copy_fn(exty.args())) + } + + fn discard_op(&self, typ: &Type) -> Option { + if let Some((_, discard)) = self.copy_discard.get(typ) { + return Some(discard.clone()); + } + let TypeEnum::Extension(exty) = typ.as_type_enum() else { + return None; + }; + self.copy_discard_parametric + .get(&exty.into()) + .map(|(_, discard_fn)| discard_fn(exty.args())) + } +} \ No newline at end of file From 6d8a89b36b906364ad08effa86b64a654ff0f2e1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 13:29:30 +0000 Subject: [PATCH 051/123] Add LinearizeError --- hugr-passes/src/lower_types.rs | 8 +-- hugr-passes/src/lower_types/linearize.rs | 67 +++++++++++++++--------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 4763907d4f..20ed08721c 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -22,7 +22,7 @@ use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; mod linearize; -pub use linearize::Linearizer; +pub use linearize::{LinearizeError, Linearizer}; /// A thing to which an Op can be lowered, i.e. with which a node can be replaced. #[derive(Clone, Debug, PartialEq)] @@ -120,8 +120,9 @@ pub enum ChangeTypeError { actual: Option, }, #[error(transparent)] - #[allow(missing_docs)] ValidationError(#[from] ValidatePassError), + #[error(transparent)] + LinearizeError(#[from] LinearizeError), } impl LowerTypes { @@ -288,7 +289,8 @@ impl LowerTypes { if targets.len() != 1 { hugr.disconnect(n, outp); let typ = new_sig.out_port_type(outp).unwrap(); - self.linearize.insert_copy_discard(hugr, n, outp, typ, &targets) + self.linearize + .insert_copy_discard(hugr, n, outp, typ, &targets)?; } } } diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 6acc8dfd21..fd4d93f7e4 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -1,6 +1,11 @@ use std::{collections::HashMap, sync::Arc}; -use hugr_core::{extension::TypeDef, hugr::hugrmut::HugrMut, types::{Type, TypeArg, TypeEnum}, IncomingPort, Node, OutgoingPort}; +use hugr_core::{ + extension::TypeDef, + hugr::hugrmut::HugrMut, + types::{Type, TypeArg, TypeEnum}, + IncomingPort, Node, OutgoingPort, +}; use super::{OpReplacement, ParametricType}; @@ -24,14 +29,25 @@ pub struct Linearizer { >, } +#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +pub enum LinearizeError { + #[error("Need copy op for {_0}")] + NeedCopy(Type), + #[error("Need discard op for {_0}")] + NeedDiscard(Type), +} + impl Linearizer { pub fn register(&mut self, typ: Type, copy: OpReplacement, discard: OpReplacement) { self.copy_discard.insert(typ, (copy, discard)); } - pub fn register_parametric(&mut self,src: TypeDef, + pub fn register_parametric( + &mut self, + src: TypeDef, copy_fn: Box OpReplacement>, - discard_fn: Box OpReplacement>) { + discard_fn: Box OpReplacement>, + ) { self.copy_discard_parametric .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); } @@ -41,54 +57,55 @@ impl Linearizer { hugr: &mut impl HugrMut, mut src_node: Node, mut src_port: OutgoingPort, - typ: &Type, + typ: &Type, // Or better to get the signature ourselves?? targets: &[(Node, IncomingPort)], - ) { + ) -> Result<(), LinearizeError> { let (last_node, last_inport) = match targets.last() { None => { - let discard = self - .discard_op(typ) - .expect("Don't know how to discard {typ:?}"); // TODO return error - - let disc = discard.add(hugr, hugr.get_parent(src_node).unwrap()); - (disc, 0.into()) + let parent = hugr.get_parent(src_node).unwrap(); + (self.discard_op(typ)?.add(hugr, parent), 0.into()) } Some(last) => *last, }; if targets.len() > 1 { - let copy = self.copy_op(typ).expect("Don't know how copy {typ:?"); // TODO return error - // Could sanity-check signature here? + let copy_op = self.copy_op(typ)?; + for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { - let n = copy.add(hugr, hugr.get_parent(src_node).unwrap()); + let n = copy_op.add(hugr, hugr.get_parent(src_node).unwrap()); hugr.connect(src_node, src_port, n, 0); hugr.connect(n, 0, *tgt_node, *tgt_port); (src_node, src_port) = (n, 1.into()); } } hugr.connect(src_node, src_port, last_node, last_inport); + Ok(()) } - fn copy_op(&self, typ: &Type) -> Option { + fn copy_op(&self, typ: &Type) -> Result { if let Some((copy, _)) = self.copy_discard.get(typ) { - return Some(copy.clone()); + return Ok(copy.clone()); } let TypeEnum::Extension(exty) = typ.as_type_enum() else { - return None; + todo!() // handle sums, etc.... }; - self.copy_discard_parametric + let (copy_fn, _) = self + .copy_discard_parametric .get(&exty.into()) - .map(|(copy_fn, _)| copy_fn(exty.args())) + .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; + Ok(copy_fn(exty.args())) } - fn discard_op(&self, typ: &Type) -> Option { + fn discard_op(&self, typ: &Type) -> Result { if let Some((_, discard)) = self.copy_discard.get(typ) { - return Some(discard.clone()); + return Ok(discard.clone()); } let TypeEnum::Extension(exty) = typ.as_type_enum() else { - return None; + todo!() // handle sums, etc... }; - self.copy_discard_parametric + let (_, discard_fn) = self + .copy_discard_parametric .get(&exty.into()) - .map(|(_, discard_fn)| discard_fn(exty.args())) + .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; + Ok(discard_fn(exty.args())) } -} \ No newline at end of file +} From 3f0f1e63291ec6aa479100354bc20da6367f3bff Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 13:40:00 +0000 Subject: [PATCH 052/123] pass Linearizer to callbacks --- hugr-passes/src/lower_types.rs | 4 ++-- hugr-passes/src/lower_types/linearize.rs | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 20ed08721c..32e5f4f696 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -170,8 +170,8 @@ impl LowerTypes { pub fn linearize_parametric( &mut self, src: TypeDef, - copy_fn: Box OpReplacement>, - discard_fn: Box OpReplacement>, + copy_fn: Box OpReplacement>, + discard_fn: Box OpReplacement>, ) { self.linearize.register_parametric(src, copy_fn, discard_fn) } diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index fd4d93f7e4..190b501bb0 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -21,10 +21,9 @@ pub struct Linearizer { // including lowering of the copy/dup operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - // TODO should pass &LowerTypes, or at least some way to call copy_op / discard_op, to these ( - Arc OpReplacement>, - Arc OpReplacement>, + Arc OpReplacement>, + Arc OpReplacement>, ), >, } @@ -45,8 +44,8 @@ impl Linearizer { pub fn register_parametric( &mut self, src: TypeDef, - copy_fn: Box OpReplacement>, - discard_fn: Box OpReplacement>, + copy_fn: Box OpReplacement>, + discard_fn: Box OpReplacement>, ) { self.copy_discard_parametric .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); @@ -92,7 +91,7 @@ impl Linearizer { .copy_discard_parametric .get(&exty.into()) .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - Ok(copy_fn(exty.args())) + Ok(copy_fn(exty.args(), self)) } fn discard_op(&self, typ: &Type) -> Result { @@ -106,6 +105,6 @@ impl Linearizer { .copy_discard_parametric .get(&exty.into()) .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; - Ok(discard_fn(exty.args())) + Ok(discard_fn(exty.args(), self)) } } From 974a25b825c37777a3ea33237ff47d3bd00e1271 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Mar 2025 19:14:21 +0000 Subject: [PATCH 053/123] comments --- hugr-passes/src/lower_types.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 32e5f4f696..c26157f1cc 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -160,19 +160,45 @@ impl LowerTypes { // (depending on arguments - i.e. if src's TypeDefBound is anything other than // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying // overapproximation. Moreover, these depend upon the *return type* of the Fn. + // We could take an + // `dyn Fn(&TypeArg) -> (Type, Fn(&Linearizer) -> OpReplacement, Fn(&Linearizer) -> OpReplacement))` + // but that seems too awkward. self.param_types.insert(src.into(), Arc::from(dest_fn)); } + /// Configures this instance that, when an outport of type `src` has other than one connected + /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. + /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; + /// `discard` should have exactly one inport, of type 'src', and no outports.) + /// + /// To clarify, these are used if `src` is not [Copyable], but is (perhaps contained in) the + /// result of lonering a type that was either copied or discarded in the input Hugr. + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable pub fn linearize(&mut self, src: Type, copy: OpReplacement, discard: OpReplacement) { + // We could raise an error if src's bound is Copyable? self.linearize.register(src, copy, discard) } + /// Configures this instance that when lowering produces an outport which + /// * has type an instantiation of the parametric type `src`, and + /// * is not [Copyable](hugr_core::types::TypeBound::Copyable), and + /// * has other than one connected inport, + /// + /// ...then these functions should be used to generate `copy` or `discard` ops. + /// + /// (That is, this is the equivalent of [Self::linearize] but for parametric types.) + /// + /// The [Linearizer] is passed so that the callbacks can use this to generate + /// `copy/`discard` ops for other types (e.g. the elements of a collection), + /// as part of an [OpReplacement::CompoundOp]. pub fn linearize_parametric( &mut self, src: TypeDef, copy_fn: Box OpReplacement>, discard_fn: Box OpReplacement>, ) { + // We could raise an error if src's TypeDefBound is explicit Copyable ? self.linearize.register_parametric(src, copy_fn, discard_fn) } From 2bbf6632de939e2734b08d0b8f49c09c4b881230 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 21 Mar 2025 15:57:50 +0000 Subject: [PATCH 054/123] first test, don't linearize root node outputs, reject nonlocal edges --- hugr-passes/src/lower_types.rs | 6 +- hugr-passes/src/lower_types/linearize.rs | 126 ++++++++++++++++++++++- 2 files changed, 130 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index c26157f1cc..51c7b848d1 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -306,7 +306,11 @@ impl LowerTypes { } }; } - let Some(new_sig) = changed.then_some(new_dfsig).flatten().map(Cow::into_owned) else { + let Some(new_sig) = (changed && n != hugr.root()) + .then_some(new_dfsig) + .flatten() + .map(Cow::into_owned) + else { continue; }; for outp in new_sig.output_ports() { diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 190b501bb0..4a855b9228 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -22,7 +22,7 @@ pub struct Linearizer { copy_discard_parametric: HashMap< ParametricType, ( - Arc OpReplacement>, + Arc OpReplacement>, // TODO return Result<...,LinearizeError> ? Arc OpReplacement>, ), >, @@ -34,6 +34,13 @@ pub enum LinearizeError { NeedCopy(Type), #[error("Need discard op for {_0}")] NeedDiscard(Type), + #[error("Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})")] + NoLinearNonLocalEdges { + src: Node, + src_parent: Node, + tgt: Node, + tgt_parent: Node, + }, } impl Linearizer { @@ -51,6 +58,12 @@ impl Linearizer { .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); } + /// Insert copy or discard operations (as appropriate) enough to wire `src_port` of `src_node` + /// up to all `targets`. + /// + /// # Errors + /// + /// If needed copy or discard ops cannot be found; pub fn insert_copy_discard( &self, hugr: &mut impl HugrMut, @@ -66,7 +79,26 @@ impl Linearizer { } Some(last) => *last, }; + if targets.len() > 1 { + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + let src_parent = hugr + .get_parent(src_node) + .expect("Root node cannot have out edges"); + if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { + let tgt_parent = hugr + .get_parent(*tgt) + .expect("Root node cannot have incoming edges"); + (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) + }) { + return Err(LinearizeError::NoLinearNonLocalEdges { + src: src_node, + src_parent, + tgt, + tgt_parent, + }); + } + let copy_op = self.copy_op(typ)?; for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { @@ -108,3 +140,95 @@ impl Linearizer { Ok(discard_fn(exty.args(), self)) } } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::hugr::IdentList; + use hugr_core::ops::{ExtensionOp, NamedOp, OpName}; + use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; + use hugr_core::types::{Type, TypeEnum}; + use hugr_core::{extension::prelude::usize_t, types::Signature}; + use hugr_core::{Extension, HugrView}; + + use crate::lower_types::OpReplacement; + use crate::LowerTypes; + + #[test] + fn single_values() { + // Extension with a linear type, a copy and discard op + let e = Extension::new_arc( + IdentList::new_unchecked("TestExt"), + Version::new(0, 0, 0), + |e, w| { + let lin = Type::new_extension( + e.add_type("Lin".into(), vec![], String::new(), TypeDefBound::any(), w) + .unwrap() + .instantiate([]) + .unwrap(), + ); + e.add_op( + "copy".into(), + String::new(), + Signature::new(lin.clone(), vec![lin.clone(); 2]), + w, + ) + .unwrap(); + e.add_op( + "discard".into(), + String::new(), + Signature::new(lin, vec![]), + w, + ) + .unwrap(); + }, + ); + let lin_t = Type::new_extension(e.get_type("Lin").unwrap().instantiate([]).unwrap()); + + // Configure to lower usize_t to the linear type above + let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), []).unwrap(); + let discard_op = ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(); + let mut lowerer = LowerTypes::default(); + let TypeEnum::Extension(usize_custom_t) = usize_t().as_type_enum().clone() else { + panic!() + }; + lowerer.lower_type(usize_custom_t, lin_t.clone()); + lowerer.linearize( + lin_t, + OpReplacement::SingleOp(copy_op.into()), + OpReplacement::SingleOp(discard_op.into()), + ); + + // Build Hugr - uses first input three times, discards second input (both usize) + let mut outer = DFGBuilder::new(Signature::new( + vec![usize_t(); 2], + vec![usize_t(), array_type(2, usize_t())], + )) + .unwrap(); + let [inp, _] = outer.input_wires_arr(); + let new_array = outer + .add_dataflow_op(ArrayOpDef::new_array.to_concrete(usize_t(), 2), [inp, inp]) + .unwrap(); + let [arr] = new_array.outputs_arr(); + let mut h = outer.finish_hugr_with_outputs([inp, arr]).unwrap(); + + assert!(lowerer.run(&mut h).unwrap()); + + let ext_ops = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); + let mut counts = HashMap::::new(); + for e in ext_ops { + *counts.entry(e.name()).or_default() += 1; + } + assert_eq!( + counts, + HashMap::from([ + ("TestExt.copy".into(), 2), + ("TestExt.discard".into(), 1), + ("collections.array.new_array".into(), 1) + ]) + ); + } +} From 290a8aeb1bdce6ad609ee114c4a0d856be754263 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sun, 23 Mar 2025 10:42:21 +0000 Subject: [PATCH 055/123] copy_fn/discard_fn return Result --- hugr-passes/src/lower_types.rs | 4 ++-- hugr-passes/src/lower_types/linearize.rs | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 51c7b848d1..c0ea15fa94 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -195,8 +195,8 @@ impl LowerTypes { pub fn linearize_parametric( &mut self, src: TypeDef, - copy_fn: Box OpReplacement>, - discard_fn: Box OpReplacement>, + copy_fn: Box Result>, + discard_fn: Box Result>, ) { // We could raise an error if src's TypeDefBound is explicit Copyable ? self.linearize.register_parametric(src, copy_fn, discard_fn) diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 4a855b9228..406b828111 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -22,8 +22,8 @@ pub struct Linearizer { copy_discard_parametric: HashMap< ParametricType, ( - Arc OpReplacement>, // TODO return Result<...,LinearizeError> ? - Arc OpReplacement>, + Arc Result>, + Arc Result>, ), >, } @@ -51,8 +51,8 @@ impl Linearizer { pub fn register_parametric( &mut self, src: TypeDef, - copy_fn: Box OpReplacement>, - discard_fn: Box OpReplacement>, + copy_fn: Box Result>, + discard_fn: Box Result>, ) { self.copy_discard_parametric .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); @@ -123,7 +123,7 @@ impl Linearizer { .copy_discard_parametric .get(&exty.into()) .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - Ok(copy_fn(exty.args(), self)) + copy_fn(exty.args(), self) } fn discard_op(&self, typ: &Type) -> Result { @@ -137,7 +137,7 @@ impl Linearizer { .copy_discard_parametric .get(&exty.into()) .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; - Ok(discard_fn(exty.args(), self)) + discard_fn(exty.args(), self) } } From 19b14427ed40d69cdb87c3171d62912154db2926 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sun, 23 Mar 2025 18:47:55 +0000 Subject: [PATCH 056/123] linearize takes &TypeDef not TypeDef --- hugr-passes/src/lower_types.rs | 2 +- hugr-passes/src/lower_types/linearize.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index c0ea15fa94..61954f504c 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -194,7 +194,7 @@ impl LowerTypes { /// as part of an [OpReplacement::CompoundOp]. pub fn linearize_parametric( &mut self, - src: TypeDef, + src: &TypeDef, copy_fn: Box Result>, discard_fn: Box Result>, ) { diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 406b828111..b7cddd25af 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -50,12 +50,12 @@ impl Linearizer { pub fn register_parametric( &mut self, - src: TypeDef, + src: &TypeDef, copy_fn: Box Result>, discard_fn: Box Result>, ) { self.copy_discard_parametric - .insert((&src).into(), (Arc::from(copy_fn), Arc::from(discard_fn))); + .insert(src.into(), (Arc::from(copy_fn), Arc::from(discard_fn))); } /// Insert copy or discard operations (as appropriate) enough to wire `src_port` of `src_node` From 269f0f141c7b5e2682134f4bcd63f4aae10c605e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Mar 2025 11:37:54 +0000 Subject: [PATCH 057/123] OpReplacement::add(&self, ...) -> add_hugr(self, ...) --- hugr-passes/src/lower_types.rs | 4 ++-- hugr-passes/src/lower_types/linearize.rs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 61954f504c..6b44c9664b 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -44,8 +44,8 @@ pub enum OpReplacement { } impl OpReplacement { - fn add(&self, hugr: &mut impl HugrMut, parent: Node) -> Node { - match self.clone() { + fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + match self { OpReplacement::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), OpReplacement::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, } diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index b7cddd25af..63d3e0873a 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -75,7 +75,7 @@ impl Linearizer { let (last_node, last_inport) = match targets.last() { None => { let parent = hugr.get_parent(src_node).unwrap(); - (self.discard_op(typ)?.add(hugr, parent), 0.into()) + (self.discard_op(typ)?.add_hugr(hugr, parent), 0.into()) } Some(last) => *last, }; @@ -102,7 +102,9 @@ impl Linearizer { let copy_op = self.copy_op(typ)?; for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { - let n = copy_op.add(hugr, hugr.get_parent(src_node).unwrap()); + let n = copy_op + .clone() + .add_hugr(hugr, hugr.get_parent(src_node).unwrap()); hugr.connect(src_node, src_port, n, 0); hugr.connect(n, 0, *tgt_node, *tgt_port); (src_node, src_port) = (n, 1.into()); From f6a1af5648b092e15fa14185228cf673fc34bd66 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 12:35:02 +0000 Subject: [PATCH 058/123] Drop an else-continue --- hugr-passes/src/lower_types.rs | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 6b44c9664b..1d063626e7 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -306,21 +306,20 @@ impl LowerTypes { } }; } - let Some(new_sig) = (changed && n != hugr.root()) + if let Some(new_sig) = (changed && n != hugr.root()) .then_some(new_dfsig) .flatten() .map(Cow::into_owned) - else { - continue; - }; - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let typ = new_sig.out_port_type(outp).unwrap(); - self.linearize - .insert_copy_discard(hugr, n, outp, typ, &targets)?; + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let typ = new_sig.out_port_type(outp).unwrap(); + self.linearize + .insert_copy_discard(hugr, n, outp, typ, &targets)?; + } } } } From ac9adac257f17998894acf7f3428ea3d88904be4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Mar 2025 11:38:55 +0000 Subject: [PATCH 059/123] Add OpReplacement::add for builder --- hugr-passes/src/lower_types.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 1d063626e7..2da1357807 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -3,6 +3,8 @@ use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; +use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; +use hugr_core::ops::handle::DataflowOpID; use itertools::Either; use thiserror::Error; @@ -17,7 +19,7 @@ use hugr_core::ops::{ use hugr_core::types::{ CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, }; -use hugr_core::{Hugr, Node}; +use hugr_core::{Hugr, Node, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -51,6 +53,17 @@ impl OpReplacement { } } + fn add( + self, + dfb: &mut impl Dataflow, + inputs: impl IntoIterator, + ) -> Result, BuildError> { + match self { + OpReplacement::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), + OpReplacement::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + } + } + fn replace(&self, hugr: &mut impl HugrMut, n: Node) { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { From 65eaf521c1661a74b943ec63417c8f082c00c10a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 13:01:36 +0000 Subject: [PATCH 060/123] Handle Sum Types in copy_op/discard_op --- hugr-passes/src/lower_types/linearize.rs | 103 +++++++++++++++++++---- 1 file changed, 85 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 63d3e0873a..b45c73b3c3 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -1,11 +1,14 @@ use std::{collections::HashMap, sync::Arc}; use hugr_core::{ - extension::TypeDef, + builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, + extension::{SignatureError, TypeDef}, hugr::hugrmut::HugrMut, - types::{Type, TypeArg, TypeEnum}, + ops::Tag, + types::{Type, TypeArg, TypeEnum, TypeRow}, IncomingPort, Node, OutgoingPort, }; +use itertools::Itertools; use super::{OpReplacement, ParametricType}; @@ -41,6 +44,13 @@ pub enum LinearizeError { tgt: Node, tgt_parent: Node, }, + /// SignatureError's can happen when converting nested types e.g. Sums + #[error(transparent)] + SignatureError(#[from] SignatureError), + /// Type variables, Row variables, and Aliases are not supported; + /// nor Function types, as these are always Copyable. + #[error("Cannot linearize type {_0}")] + UnsupportedType(Type), } impl Linearizer { @@ -118,28 +128,85 @@ impl Linearizer { if let Some((copy, _)) = self.copy_discard.get(typ) { return Ok(copy.clone()); } - let TypeEnum::Extension(exty) = typ.as_type_enum() else { - todo!() // handle sums, etc.... - }; - let (copy_fn, _) = self - .copy_discard_parametric - .get(&exty.into()) - .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - copy_fn(exty.args(), self) + match typ.as_type_enum() { + TypeEnum::Sum(sum_type) => { + let variants = sum_type + .variants() + .map(|trv| trv.clone().try_into()) + .collect::, _>>()?; + let mut cb = ConditionalBuilder::new( + variants.clone(), + vec![], + vec![sum_type.clone().into(); 2], + ) + .unwrap(); + for (tag, variant) in variants.iter().enumerate() { + let mut case_b = cb.case_builder(tag).unwrap(); + let mut orig_elems = vec![]; + let mut copy_elems = vec![]; + for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { + let [orig_elem, copy_elem] = self + .copy_op(ty)? + .add(&mut case_b, [inp]) + .unwrap() + .outputs_arr(); + orig_elems.push(orig_elem); + copy_elems.push(copy_elem); + } + let t = Tag::new(tag, variants.clone()); + let [orig] = case_b + .add_dataflow_op(t.clone(), orig_elems) + .unwrap() + .outputs_arr(); + let [copy] = case_b.add_dataflow_op(t, copy_elems).unwrap().outputs_arr(); + case_b.finish_with_outputs([orig, copy]).unwrap(); + } + Ok(OpReplacement::CompoundOp(Box::new( + cb.finish_hugr().unwrap(), + ))) + } + TypeEnum::Extension(cty) => { + let (copy_fn, _) = self + .copy_discard_parametric + .get(&cty.into()) + .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; + copy_fn(cty.args(), self) + } + _ => Err(LinearizeError::UnsupportedType(typ.clone())), + } } fn discard_op(&self, typ: &Type) -> Result { if let Some((_, discard)) = self.copy_discard.get(typ) { return Ok(discard.clone()); } - let TypeEnum::Extension(exty) = typ.as_type_enum() else { - todo!() // handle sums, etc... - }; - let (_, discard_fn) = self - .copy_discard_parametric - .get(&exty.into()) - .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; - discard_fn(exty.args(), self) + match typ.as_type_enum() { + TypeEnum::Sum(sum_type) => { + let variants = sum_type + .variants() + .map(|trv| trv.clone().try_into()) + .collect::, _>>()?; + let mut cb = ConditionalBuilder::new(variants.clone(), vec![], vec![]).unwrap(); + for (idx, variant) in variants.into_iter().enumerate() { + let mut case_b = cb.case_builder(idx).unwrap(); + for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { + self.discard_op(ty)?.add(&mut case_b, [inp]).unwrap(); + } + case_b.finish_with_outputs([]).unwrap(); + } + Ok(OpReplacement::CompoundOp(Box::new( + cb.finish_hugr().unwrap(), + ))) + } + TypeEnum::Extension(cty) => { + let (_, discard_fn) = self + .copy_discard_parametric + .get(&cty.into()) + .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; + discard_fn(cty.args(), self) + } + _ => Err(LinearizeError::UnsupportedType(typ.clone())), + } } } From 354a89c47f473dfcf11bdf9d97b8f8f819125e1b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 14:15:23 +0000 Subject: [PATCH 061/123] drop redundant allow-missing_docs --- hugr-passes/src/lower_types.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 535857c1fa..5fe2ac29b1 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -109,7 +109,6 @@ pub enum ChangeTypeError { actual: Option, }, #[error(transparent)] - #[allow(missing_docs)] ValidationError(#[from] ValidatePassError), } From 1b316310057e907f14631d3d55854f7c6551b95c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sun, 23 Mar 2025 17:53:25 +0000 Subject: [PATCH 062/123] Refactor test code, fix --all-features --- hugr-passes/src/lower_types/linearize.rs | 29 +++++++++++++++--------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index b45c73b3c3..52bfe334db 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -213,28 +213,29 @@ impl Linearizer { #[cfg(test)] mod test { use std::collections::HashMap; + use std::sync::Arc; - use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; - use hugr_core::extension::{TypeDefBound, Version}; - use hugr_core::hugr::IdentList; + use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; + + use hugr_core::extension::{prelude::usize_t, TypeDefBound, Version}; use hugr_core::ops::{ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; - use hugr_core::types::{Type, TypeEnum}; - use hugr_core::{extension::prelude::usize_t, types::Signature}; - use hugr_core::{Extension, HugrView}; + use hugr_core::types::{Signature, Type, TypeEnum}; + use hugr_core::{hugr::IdentList, Extension, HugrView}; use crate::lower_types::OpReplacement; use crate::LowerTypes; - #[test] - fn single_values() { + const LIN_T: &str = "Lin"; + + fn ext_lowerer() -> (Arc, LowerTypes) { // Extension with a linear type, a copy and discard op let e = Extension::new_arc( IdentList::new_unchecked("TestExt"), Version::new(0, 0, 0), |e, w| { let lin = Type::new_extension( - e.add_type("Lin".into(), vec![], String::new(), TypeDefBound::any(), w) + e.add_type(LIN_T.into(), vec![], String::new(), TypeDefBound::any(), w) .unwrap() .instantiate([]) .unwrap(), @@ -255,7 +256,8 @@ mod test { .unwrap(); }, ); - let lin_t = Type::new_extension(e.get_type("Lin").unwrap().instantiate([]).unwrap()); + + let lin_t = Type::new_extension(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); // Configure to lower usize_t to the linear type above let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), []).unwrap(); @@ -270,9 +272,14 @@ mod test { OpReplacement::SingleOp(copy_op.into()), OpReplacement::SingleOp(discard_op.into()), ); + (e, lowerer) + } + #[test] + fn single_values() { + let (_e, lowerer) = ext_lowerer(); // Build Hugr - uses first input three times, discards second input (both usize) - let mut outer = DFGBuilder::new(Signature::new( + let mut outer = DFGBuilder::new(inout_sig( vec![usize_t(); 2], vec![usize_t(), array_type(2, usize_t())], )) From d5cbd58309105f78f8bf2f902950cfac4dc4a8a0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 14:14:14 +0000 Subject: [PATCH 063/123] Test copying+discarding an option of two elements --- hugr-passes/src/lower_types/linearize.rs | 63 +++++++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 52bfe334db..1c4a8389e2 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -215,13 +215,19 @@ mod test { use std::collections::HashMap; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; + use hugr_core::builder::{ + endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + }; + use hugr_core::extension::prelude::option_type; use hugr_core::extension::{prelude::usize_t, TypeDefBound, Version}; + use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::{ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; - use hugr_core::types::{Signature, Type, TypeEnum}; + use hugr_core::type_row; + use hugr_core::types::{Signature, Type, TypeEnum, TypeRow}; use hugr_core::{hugr::IdentList, Extension, HugrView}; + use itertools::Itertools; use crate::lower_types::OpReplacement; use crate::LowerTypes; @@ -307,4 +313,57 @@ mod test { ]) ); } + + #[test] + fn sums() { + let (e, lowerer) = ext_lowerer(); + let sum_ty = Type::from(option_type(vec![usize_t(), usize_t()])); + let mut outer = DFGBuilder::new(endo_sig(sum_ty.clone())).unwrap(); + let [inp] = outer.input_wires_arr(); + let inner = outer + .dfg_builder(inout_sig(sum_ty, vec![]), [inp]) + .unwrap() + .finish_with_outputs([]) + .unwrap(); + let mut h = outer.finish_hugr_with_outputs([inp]).unwrap(); + + assert!(lowerer.run(&mut h).unwrap()); + + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let option_ty = Type::from(option_type(vec![lin_t.clone(); 2])); + let copy_out: TypeRow = vec![option_ty.clone(); 2].into(); + let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); + + // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... + for (dfg, num_tags, out_row, ext_op_name) in [ + (inner.node(), 0, type_row![], "TestExt.discard"), + (h.root(), 2, copy_out, "TestExt.copy"), + ] { + let [cond] = h + .children(dfg) + .filter(|n| h.get_optype(*n).is_conditional()) + .collect_array() + .unwrap(); + let [case0, case1] = h.children(cond).collect_array().unwrap(); + // first is for empty + assert_eq!(h.children(case0).count(), 2 + num_tags); // Input, Output + assert_eq!(count_tags(case0), num_tags); + let case0 = h.get_optype(case0).as_case().unwrap(); + assert_eq!(case0.signature.io(), (&vec![].into(), &out_row)); + + // second is for two elements + assert_eq!(h.children(case1).count(), 4 + num_tags); // Input, Output, two leaf copies/discards: + assert_eq!(count_tags(case1), num_tags); + assert_eq!( + h.children(case1) + .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) + .collect_vec(), + vec![ext_op_name; 2] + ); + assert_eq!( + h.get_optype(case1).as_case().unwrap().signature.io(), + (&vec![lin_t.clone(); 2].into(), &out_row) + ); + } + } } From 63bdd3bcead948246acc8ef3efc49ca1c36991ac Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 15:13:59 +0000 Subject: [PATCH 064/123] tidy imports --- hugr-passes/src/lower_types/linearize.rs | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/lower_types/linearize.rs b/hugr-passes/src/lower_types/linearize.rs index 1c4a8389e2..221e1ef43e 100644 --- a/hugr-passes/src/lower_types/linearize.rs +++ b/hugr-passes/src/lower_types/linearize.rs @@ -1,13 +1,9 @@ use std::{collections::HashMap, sync::Arc}; -use hugr_core::{ - builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, - extension::{SignatureError, TypeDef}, - hugr::hugrmut::HugrMut, - ops::Tag, - types::{Type, TypeArg, TypeEnum, TypeRow}, - IncomingPort, Node, OutgoingPort, -}; +use hugr_core::builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}; +use hugr_core::extension::{SignatureError, TypeDef}; +use hugr_core::types::{Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node, OutgoingPort}; use itertools::Itertools; use super::{OpReplacement, ParametricType}; @@ -219,14 +215,12 @@ mod test { endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, }; - use hugr_core::extension::prelude::option_type; - use hugr_core::extension::{prelude::usize_t, TypeDefBound, Version}; - use hugr_core::ops::handle::NodeHandle; - use hugr_core::ops::{ExtensionOp, NamedOp, OpName}; + use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; - use hugr_core::type_row; use hugr_core::types::{Signature, Type, TypeEnum, TypeRow}; - use hugr_core::{hugr::IdentList, Extension, HugrView}; + use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; use crate::lower_types::OpReplacement; From 1ac3728e5153442f2009bb99bc758457217f8e8a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 16:18:43 +0000 Subject: [PATCH 065/123] Combine lower_consts(_parametric) with lower_(parametric_)type --- hugr-passes/src/lower_types.rs | 93 +++++++++++++++------------------- 1 file changed, 42 insertions(+), 51 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 5fe2ac29b1..4bf85f28d5 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -3,7 +3,6 @@ use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; -use itertools::Either; use thiserror::Error; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; @@ -71,7 +70,8 @@ pub struct LowerTypes { op_map: HashMap, // Called after lowering typeargs; return None to use original OpDef param_ops: HashMap Option>>, - consts: HashMap, Arc Option>>, + consts: HashMap Value>>, + param_consts: HashMap Option>>, check_sig: bool, validation: ValidationLevel, } @@ -121,26 +121,41 @@ impl LowerTypes { self } - /// Configures this instance to change occurrences of `src` to `dest`. - /// Note that if `src` is an instance of a *parametrized* Type, this should only - /// be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as substitution - /// (parametric polymorphism) happening later will not respect the lowering(s). + /// Configures this instance to change occurrences of type `src` to `dest`. + /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes + /// precedence over [Self::lower_parametric_type] where the `src`s overlap. Thus, this + /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as + /// substitution (parametric polymorphism) happening later will not respect the lowering(s). /// - /// This takes precedence over [Self::lower_parametric_type] where the `src`s overlap. - pub fn lower_type(&mut self, src: CustomType, dest: Type) { + /// [Const]s of the specified type are transformed using the provided `const_fn` callback, + /// which is passed the value of the constant. + pub fn lower_type( + &mut self, + src: CustomType, + dest: Type, + const_fn: Box Value>, + ) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't // check that for parametric types, we'll be consistent and not check here either. - self.type_map.insert(src, dest); + self.type_map.insert(src.clone(), dest); + self.consts.insert(src, Arc::from(const_fn)); } /// Configures this instance to change occurrences of a parametrized type `src` - /// via a callback that builds the replacement type given the [TypeArg]s. + /// via a callback `dest_fn` that builds the replacement type given the [TypeArg]s. /// Note that the TypeArgs will already have been lowered (e.g. they may not - /// fit the bounds of the original type). + /// fit the bounds of the original type). The callback may return `None` to indicate + /// no change (in which case the TypeArgs will be lowered recursively but the outer + /// type constructor kept the same). + /// + /// Constants whose types are instantiations of `src` will be converted by the `const_fn` + /// callback, given the value of the constant; the callback may return `None` to + /// leave the constant unchanged. pub fn lower_parametric_type( &mut self, src: &TypeDef, dest_fn: Box Option>, + const_fn: Box Option>, ) { // No way to check that dest_fn never produces a linear type. // We could require copy/discard-generators if src is Copyable, or *might be* @@ -148,6 +163,7 @@ impl LowerTypes { // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying // overapproximation. Moreover, these depend upon the *return type* of the Fn. self.param_types.insert(src.into(), Arc::from(dest_fn)); + self.param_consts.insert(src.into(), Arc::from(const_fn)); } /// Configures this instance to change occurrences of `src` to `dest`. @@ -174,34 +190,6 @@ impl LowerTypes { self.param_ops.insert(src.into(), Arc::from(dest_fn)); } - /// Configures this instance to change occurrences consts of type `src_ty`, using - /// a callback given the value of the constant (of that type). (The callback may - /// return `None` to indicate nothing has changed; we assume `Some` means something - /// has changed when evaluating the `bool` result of [Self::run].) - /// - /// Note that if `src_ty` is an instance of a *parametrized* [TypeDef], this - /// takes precedence over [Self::lower_consts_parametric] where the `src_ty`s overlap. - pub fn lower_consts( - &mut self, - src_ty: &CustomType, - const_fn: Box Option>, - ) { - self.consts - .insert(Either::Left(src_ty.clone()), Arc::from(const_fn)); - } - - /// Configures this instance to change occurrences consts of all types that - /// are instances of a parametric typedef `src_ty`, using a callback given - /// the value of the constant (the [OpaqueValue] contains the [TypeArg]s). - pub fn lower_consts_parametric( - &mut self, - src_ty: &TypeDef, - const_fn: Box Option>, - ) { - self.consts - .insert(Either::Right(src_ty.into()), Arc::from(const_fn)); - } - /// Configures this instance to check signatures of ops lowered following [Self::lower_op] /// and [Self::lower_parametric_op] are as expected, i.e. match the signatures of the /// original op modulo the required type substitutions. (If signatures are incorrect, @@ -359,15 +347,14 @@ impl LowerTypes { } Value::Extension { e } => Ok('changed: { if let TypeEnum::Extension(exty) = e.get_type().as_type_enum() { - if let Some(const_fn) = self - .consts - .get(&Either::Left(exty.clone())) - .or(self.consts.get(&Either::Right(exty.into()))) + if let Some(new_const) = + self.consts.get(exty).map(|const_fn| const_fn(e)).or(self + .param_consts + .get(&exty.into()) + .and_then(|const_fn| const_fn(e))) { - if let Some(new_const) = const_fn(e) { - *value = new_const; - break 'changed true; - } + *value = new_const; + break 'changed true; } } false @@ -521,7 +508,11 @@ mod test { } let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = LowerTypes::default(); - lw.lower_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); + lw.lower_type( + pv.instantiate([bool_t().into()]).unwrap(), + i64_t(), + Box::new(|_| panic!("There are no constants")), + ); lw.lower_parametric_type( pv, Box::new(|args: &[TypeArg]| { @@ -530,6 +521,7 @@ mod test { }; Some(array_type(64, ty.clone())) }), + Box::new(|_| panic!("There are no constants?")), ); lw.lower_op( &read_op(ext, bool_t()), @@ -691,6 +683,7 @@ mod test { }; (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) }), + Box::new(|_| None), // leave the List unchanged ); let backup = h.clone(); assert!(!lowerer.run(&mut h).unwrap()); @@ -706,6 +699,7 @@ mod test { }; (usize_t() != *ty).then_some(array_type(10, ty.clone())) }), + Box::new(|_| None), // leave the List unchanged ); assert!(lowerer.run(&mut h).unwrap()); let sig = h.signature(h.root()).unwrap(); @@ -726,9 +720,6 @@ mod test { }; Some(array_type(4, ty.clone())) }), - ); - lowerer.lower_consts_parametric( - list_type_def(), Box::new(|opaq| { let lv = opaq .value() From e76881422f738e1f3cbd24267530495da41cbb99 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 19:06:47 +0000 Subject: [PATCH 066/123] Test sig checking --- hugr-passes/src/lower_types.rs | 131 +++++++++++++++++++++++++++++++-- 1 file changed, 123 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 4bf85f28d5..f69a301a8c 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -414,22 +414,33 @@ mod test { inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; - use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; + use hugr_core::extension::prelude::{ + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + }; use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - + use hugr_core::extension::{SignatureError, TypeDefBound, Version}; + use hugr_core::hugr::ValidationError; + use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::{ExtensionOp, OpType, Tag, Value}; - - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{ array_type, ArrayOp, ArrayOpDef, ArrayValue, }; - use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; - use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; + use hugr_core::std_extensions::collections::list::{ + self, list_type, list_type_def, ListOp, ListValue, + }; + use hugr_core::types::TypeEnum; + use hugr_core::types::{ + type_param::TypeArgError, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, + TypeRow, + }; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use super::{LowerTypes, OpReplacement}; + use crate::validation::ValidatePassError; + + use super::{ChangeTypeError, LowerTypes, OpReplacement}; const PACKED_VEC: &str = "PackedVec"; fn i64_t() -> Type { @@ -743,4 +754,108 @@ mod test { ])) ); } + + #[test] + fn copyable_used_linearly() { + const READ: &str = "read"; + let e = Extension::new_arc( + IdentList::new_unchecked("CopyableReader"), + Version::new(0, 0, 0), + |e, w| { + let params = vec![TypeBound::Copyable.into()]; + let tv = Type::new_var_use(0, TypeBound::Copyable); + let list_of_var = list_type(tv.clone()); + e.add_op( + READ.into(), + String::new(), + PolyFuncType::new(params, Signature::new(vec![list_of_var, usize_t()], tv)), + w, + ) + .unwrap(); + }, + ); + let read = e.get_op(READ).unwrap(); + let i32_t = || INT_TYPES[5].to_owned(); + let TypeEnum::Extension(i32_custom_t) = i32_t().as_type_enum().clone() else { + panic!() + }; + let mut dfb = DFGBuilder::new(inout_sig(list_type(i32_t()), i32_t())).unwrap(); + let [c_in] = dfb.input_wires_arr(); + let idx = dfb.add_load_value(ConstUsize::new(2)); + let read_op = dfb + .add_dataflow_op( + ExtensionOp::new(read.clone(), [i32_t().into()]).unwrap(), + [c_in, idx], + ) + .unwrap(); + let backup = dfb.finish_hugr_with_outputs(read_op.outputs()).unwrap(); + + let mut lowerer = LowerTypes::default(); + lowerer.lower_type( + i32_custom_t, + qb_t(), + Box::new(|_| panic!("There are no constants")), + ); + // That tries to create a read, which is not a legal instantiation + assert_eq!( + lowerer.run(&mut backup.clone()), + Err(ChangeTypeError::SignatureError( + SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { + param: TypeBound::Copyable.into(), + arg: qb_t().into() + }) + )) + ); + // So lower read to the normal list get...(which returns an option not QB) + fn make_list_get(args: &[TypeArg]) -> Option { + let [TypeArg::Type { ty }] = args else { + panic!("Expected just element type") + }; + Some(OpReplacement::SingleOp( + ListOp::get + .with_type(ty.clone()) + .to_extension_op() + .unwrap() + .into(), + )) + } + lowerer.lower_parametric_op(read.as_ref(), Box::new(make_list_get)); + + let res = lowerer.run(&mut backup.clone()); + assert!( + matches!(res, Err(ChangeTypeError::ValidationError(ValidatePassError::OutputError { + err: ValidationError::IncompatiblePorts { + from, .. + }, .. + })) if from == read_op.node()) + ); + + lowerer.check_signatures(true); + let res = lowerer.run(&mut backup.clone()); + assert_eq!( + res, + Err(ChangeTypeError::SignatureMismatch { + op: ListOp::get + .with_type(qb_t()) + .to_extension_op() + .unwrap() + .into(), + old: Some( + Signature::new(vec![list_type(i32_t()), usize_t()], i32_t()) + .with_extension_delta(e.name.clone()) + ), + expected: Some( + Signature::new(vec![list_type(qb_t()), usize_t()], qb_t()) + .with_extension_delta(e.name.clone()) + ), + actual: Some( + Signature::new( + vec![list_type(qb_t()), usize_t()], + Type::from(option_type(qb_t())) + ) + .with_extension_delta(list::EXTENSION_ID) + ) + }) + ); + } } From 506570bc2bdbbd90fa9c9da87359a9a673f54c0b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 19:07:46 +0000 Subject: [PATCH 067/123] No, remove sig checking --- hugr-passes/src/lower_types.rs | 175 ++------------------------------- 1 file changed, 8 insertions(+), 167 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index f69a301a8c..59000e64c8 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -1,5 +1,4 @@ #![allow(clippy::type_complexity)] -use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; @@ -72,7 +71,6 @@ pub struct LowerTypes { param_ops: HashMap Option>>, consts: HashMap Value>>, param_consts: HashMap Option>>, - check_sig: bool, validation: ValidationLevel, } @@ -190,15 +188,6 @@ impl LowerTypes { self.param_ops.insert(src.into(), Arc::from(dest_fn)); } - /// Configures this instance to check signatures of ops lowered following [Self::lower_op] - /// and [Self::lower_parametric_op] are as expected, i.e. match the signatures of the - /// original op modulo the required type substitutions. (If signatures are incorrect, - /// it is likely that the wires in the Hugr will be invalid, so this gives an early warning - /// by instead raising [ChangeTypeError::SignatureMismatch].) - pub fn check_signatures(&mut self, check_sig: bool) { - self.check_sig = check_sig; - } - /// Run the pass using specified configuration. pub fn run(&self, hugr: &mut H) -> Result { self.validation @@ -208,40 +197,7 @@ impl LowerTypes { fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { - let maybe_check_sig = if self.check_sig { - Some( - if let Some(old_sig) = hugr.get_optype(n).dataflow_signature() { - let old_sig = old_sig.into_owned(); - let mut expected_sig = old_sig.clone(); - expected_sig.transform(self)?; - Some((old_sig, expected_sig)) - } else { - None - }, - ) - } else { - None - }; changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - // (If check_sig) then verify that the Signature still has the same arity/wires, - // with only the expected changes to types within. - if let Some(old_and_expected) = maybe_check_sig { - match (&old_and_expected, &new_dfsig) { - (None, None) => (), - (Some((_, exp)), Some(act)) - if exp.input == act.input && exp.output == act.output => {} - _ => { - let (old, expected) = old_and_expected.unzip(); - return Err(ChangeTypeError::SignatureMismatch { - op: hugr.get_optype(n).clone(), - old, - expected, - actual: new_dfsig.map(Cow::into_owned), - }); - } - }; - } } Ok(changed) } @@ -414,33 +370,22 @@ mod test { inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; - use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, - }; + use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{SignatureError, TypeDefBound, Version}; - use hugr_core::hugr::ValidationError; - use hugr_core::ops::handle::NodeHandle; + use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::ops::{ExtensionOp, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; - use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, ArrayOp, ArrayOpDef, ArrayValue, }; - use hugr_core::std_extensions::collections::list::{ - self, list_type, list_type_def, ListOp, ListValue, - }; - use hugr_core::types::TypeEnum; - use hugr_core::types::{ - type_param::TypeArgError, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, - TypeRow, - }; + use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; + + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use crate::validation::ValidatePassError; - - use super::{ChangeTypeError, LowerTypes, OpReplacement}; + use super::{LowerTypes, OpReplacement}; const PACKED_VEC: &str = "PackedVec"; fn i64_t() -> Type { @@ -754,108 +699,4 @@ mod test { ])) ); } - - #[test] - fn copyable_used_linearly() { - const READ: &str = "read"; - let e = Extension::new_arc( - IdentList::new_unchecked("CopyableReader"), - Version::new(0, 0, 0), - |e, w| { - let params = vec![TypeBound::Copyable.into()]; - let tv = Type::new_var_use(0, TypeBound::Copyable); - let list_of_var = list_type(tv.clone()); - e.add_op( - READ.into(), - String::new(), - PolyFuncType::new(params, Signature::new(vec![list_of_var, usize_t()], tv)), - w, - ) - .unwrap(); - }, - ); - let read = e.get_op(READ).unwrap(); - let i32_t = || INT_TYPES[5].to_owned(); - let TypeEnum::Extension(i32_custom_t) = i32_t().as_type_enum().clone() else { - panic!() - }; - let mut dfb = DFGBuilder::new(inout_sig(list_type(i32_t()), i32_t())).unwrap(); - let [c_in] = dfb.input_wires_arr(); - let idx = dfb.add_load_value(ConstUsize::new(2)); - let read_op = dfb - .add_dataflow_op( - ExtensionOp::new(read.clone(), [i32_t().into()]).unwrap(), - [c_in, idx], - ) - .unwrap(); - let backup = dfb.finish_hugr_with_outputs(read_op.outputs()).unwrap(); - - let mut lowerer = LowerTypes::default(); - lowerer.lower_type( - i32_custom_t, - qb_t(), - Box::new(|_| panic!("There are no constants")), - ); - // That tries to create a read, which is not a legal instantiation - assert_eq!( - lowerer.run(&mut backup.clone()), - Err(ChangeTypeError::SignatureError( - SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: qb_t().into() - }) - )) - ); - // So lower read to the normal list get...(which returns an option not QB) - fn make_list_get(args: &[TypeArg]) -> Option { - let [TypeArg::Type { ty }] = args else { - panic!("Expected just element type") - }; - Some(OpReplacement::SingleOp( - ListOp::get - .with_type(ty.clone()) - .to_extension_op() - .unwrap() - .into(), - )) - } - lowerer.lower_parametric_op(read.as_ref(), Box::new(make_list_get)); - - let res = lowerer.run(&mut backup.clone()); - assert!( - matches!(res, Err(ChangeTypeError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts { - from, .. - }, .. - })) if from == read_op.node()) - ); - - lowerer.check_signatures(true); - let res = lowerer.run(&mut backup.clone()); - assert_eq!( - res, - Err(ChangeTypeError::SignatureMismatch { - op: ListOp::get - .with_type(qb_t()) - .to_extension_op() - .unwrap() - .into(), - old: Some( - Signature::new(vec![list_type(i32_t()), usize_t()], i32_t()) - .with_extension_delta(e.name.clone()) - ), - expected: Some( - Signature::new(vec![list_type(qb_t()), usize_t()], qb_t()) - .with_extension_delta(e.name.clone()) - ), - actual: Some( - Signature::new( - vec![list_type(qb_t()), usize_t()], - Type::from(option_type(qb_t())) - ) - .with_extension_delta(list::EXTENSION_ID) - ) - }) - ); - } } From 0958e2ac0241178951d358782beb3abf4db8a07b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 21:16:12 +0000 Subject: [PATCH 068/123] Test funky partial replace --- hugr-passes/src/lower_types.rs | 122 ++++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 59000e64c8..d655bf1330 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -9,8 +9,8 @@ use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, - FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, - Value, CFG, DFG, + FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value, + CFG, DFG, }; use hugr_core::types::{ CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, @@ -370,30 +370,38 @@ mod test { inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; - use hugr_core::extension::prelude::{bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder}; + use hugr_core::extension::prelude::{ + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + }; use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{TypeDefBound, Version}; - use hugr_core::ops::{ExtensionOp, OpType, Tag, Value}; + use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, ArrayOp, ArrayOpDef, ArrayValue, }; - use hugr_core::std_extensions::collections::list::{list_type, list_type_def, ListValue}; + use hugr_core::std_extensions::collections::list::{ + list_type, list_type_def, ListOp, ListValue, + }; - use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; + use hugr_core::types::{ + PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow, + }; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; use super::{LowerTypes, OpReplacement}; const PACKED_VEC: &str = "PackedVec"; + const READ: &str = "read"; + fn i64_t() -> Type { INT_TYPES[6].clone() } fn read_op(ext: &Arc, t: Type) -> ExtensionOp { - ExtensionOp::new(ext.get_op("read").unwrap().clone(), [t.into()]).unwrap() + ExtensionOp::new(ext.get_op(READ).unwrap().clone(), [t.into()]).unwrap() } fn ext() -> Arc { @@ -413,7 +421,7 @@ mod test { .instantiate(vec![Type::new_var_use(0, TypeBound::Copyable).into()]) .unwrap(); ext.add_op( - "read".into(), + READ.into(), "".into(), PolyFuncType::new( vec![TypeBound::Copyable.into()], @@ -487,7 +495,7 @@ mod test { .into(), ), ); - lw.lower_parametric_op(ext.get_op("read").unwrap().as_ref(), Box::new(lowered_read)); + lw.lower_parametric_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); lw } @@ -699,4 +707,100 @@ mod test { ])) ); } + + #[test] + fn partial_replace() { + let e = Extension::new_arc( + IdentList::new_unchecked("NoBoundsChecking"), + Version::new(0, 0, 0), + |e, w| { + let params = vec![TypeBound::Any.into()]; + let tv = Type::new_var_use(0, TypeBound::Any); + let list_of_var = list_type(tv.clone()); + e.add_op( + READ.into(), + "Like List::get but without the option".to_string(), + PolyFuncType::new(params, Signature::new(vec![list_of_var, usize_t()], tv)), + w, + ) + .unwrap(); + }, + ); + fn option_contents(ty: &Type) -> Option { + let row = ty.as_sum()?.get_variant(1).unwrap().clone(); + let elem = row.into_owned().into_iter().exactly_one().unwrap(); + Some(elem.try_into_type().unwrap()) + } + let i32_t = || INT_TYPES[5].to_owned(); + let opt_i32 = Type::from(option_type(i32_t())); + let TypeEnum::Extension(i32_custom_t) = i32_t().as_type_enum().clone() else { + panic!() + }; + let mut dfb = DFGBuilder::new(inout_sig( + vec![list_type(i32_t()), list_type(opt_i32.clone())], + vec![i32_t(), opt_i32.clone()], + )) + .unwrap(); + let [l_i, l_oi] = dfb.input_wires_arr(); + let idx = dfb.add_load_value(ConstUsize::new(2)); + let [i] = dfb + .add_dataflow_op(read_op(&e, i32_t()), [l_i, idx]) + .unwrap() + .outputs_arr(); + let [oi] = dfb + .add_dataflow_op(read_op(&e, opt_i32.clone()), [l_oi, idx]) + .unwrap() + .outputs_arr(); + let mut h = dfb.finish_hugr_with_outputs([i, oi]).unwrap(); + + let mut lowerer = LowerTypes::default(); + lowerer.lower_type(i32_custom_t, qb_t(), Box::new(|_| panic!("No consts"))); + // Lower list> to list + lowerer.lower_parametric_type( + list_type_def(), + Box::new(|args| { + let [TypeArg::Type { ty }] = args else { + panic!("Expected just elem type") + }; + option_contents(ty).map(list_type) + }), + Box::new(|_| panic!("No consts")), + ); + // and read> to get - the latter has the expected option return type + lowerer.lower_parametric_op( + e.get_op(READ).unwrap().as_ref(), + Box::new(|args: &[TypeArg]| { + let [TypeArg::Type { ty }] = args else { + panic!("Expected just elem type") + }; + option_contents(ty).map(|elem| { + OpReplacement::SingleOp( + ListOp::get + .with_type(elem) + .to_extension_op() + .unwrap() + .into(), + ) + }) + }), + ); + assert!(lowerer.run(&mut h).unwrap()); + // list -> read -> usz just becomes list -> read -> qb + // list> -> read> -> opt becomes list -> get -> opt + assert_eq!( + h.root_type().dataflow_signature().unwrap().io(), + ( + &vec![list_type(qb_t()); 2].into(), + &vec![qb_t(), option_type(qb_t()).into()].into() + ) + ); + assert_eq!( + h.nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(ExtensionOp::name) + .sorted() + .collect_vec(), + ["NoBoundsChecking.read", "collections.list.get"] + ); + } } From 74cfdb12716c6f322d0651f05cc00af6d77d755a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 21:47:02 +0000 Subject: [PATCH 069/123] common up just_elem_type --- hugr-passes/src/lower_types.rs | 45 +++++++++++----------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index d655bf1330..3371a5903d 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -404,6 +404,13 @@ mod test { ExtensionOp::new(ext.get_op(READ).unwrap().clone(), [t.into()]).unwrap() } + fn just_elem_type(args: &[TypeArg]) -> &Type { + let [TypeArg::Type { ty }] = args else { + panic!("Expected just elem type") + }; + ty + } + fn ext() -> Arc { Extension::new_arc( IdentList::new("TestExt").unwrap(), @@ -446,9 +453,7 @@ mod test { fn lowerer(ext: &Arc) -> LowerTypes { fn lowered_read(args: &[TypeArg]) -> Option { - let [TypeArg::Type { ty }] = args else { - panic!("Illegal TypeArgs") - }; + let ty = just_elem_type(args); let mut dfb = DFGBuilder::new(inout_sig( vec![array_type(64, ty.clone()), i64_t()], ty.clone(), @@ -479,12 +484,7 @@ mod test { ); lw.lower_parametric_type( pv, - Box::new(|args: &[TypeArg]| { - let [TypeArg::Type { ty }] = args else { - panic!("Illegal TypeArgs") - }; - Some(array_type(64, ty.clone())) - }), + Box::new(|args: &[TypeArg]| Some(array_type(64, just_elem_type(args).clone()))), Box::new(|_| panic!("There are no constants?")), ); lw.lower_op( @@ -642,9 +642,7 @@ mod test { lowerer.lower_parametric_type( list_type_def(), Box::new(|args| { - let [TypeArg::Type { ty }] = args else { - panic!("Expected elem type") - }; + let ty = just_elem_type(args); (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) }), Box::new(|_| None), // leave the List unchanged @@ -658,9 +656,7 @@ mod test { lowerer.lower_parametric_type( list_type_def(), Box::new(|args| { - let [TypeArg::Type { ty }] = args else { - panic!("Expected elem type") - }; + let ty = just_elem_type(args); (usize_t() != *ty).then_some(array_type(10, ty.clone())) }), Box::new(|_| None), // leave the List unchanged @@ -678,12 +674,7 @@ mod test { let mut lowerer = LowerTypes::default(); lowerer.lower_parametric_type( list_type_def(), - Box::new(|args: &[TypeArg]| { - let [TypeArg::Type { ty }] = args else { - panic!("Expected elem type") - }; - Some(array_type(4, ty.clone())) - }), + Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), Box::new(|opaq| { let lv = opaq .value() @@ -758,22 +749,14 @@ mod test { // Lower list> to list lowerer.lower_parametric_type( list_type_def(), - Box::new(|args| { - let [TypeArg::Type { ty }] = args else { - panic!("Expected just elem type") - }; - option_contents(ty).map(list_type) - }), + Box::new(|args| option_contents(just_elem_type(args)).map(list_type)), Box::new(|_| panic!("No consts")), ); // and read> to get - the latter has the expected option return type lowerer.lower_parametric_op( e.get_op(READ).unwrap().as_ref(), Box::new(|args: &[TypeArg]| { - let [TypeArg::Type { ty }] = args else { - panic!("Expected just elem type") - }; - option_contents(ty).map(|elem| { + option_contents(just_elem_type(args)).map(|elem| { OpReplacement::SingleOp( ListOp::get .with_type(elem) From dc79c936fb9126ebb0137181eff7f07ae25eb600 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 21:52:46 +0000 Subject: [PATCH 070/123] comment spacing --- hugr-passes/src/lower_types.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 3371a5903d..f84be86799 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -702,7 +702,7 @@ mod test { #[test] fn partial_replace() { let e = Extension::new_arc( - IdentList::new_unchecked("NoBoundsChecking"), + IdentList::new_unchecked("NoBoundsCheck"), Version::new(0, 0, 0), |e, w| { let params = vec![TypeBound::Any.into()]; @@ -768,8 +768,8 @@ mod test { }), ); assert!(lowerer.run(&mut h).unwrap()); - // list -> read -> usz just becomes list -> read -> qb - // list> -> read> -> opt becomes list -> get -> opt + // list -> read -> usz just becomes list -> read -> qb + // list> -> read> -> opt becomes list -> get -> opt assert_eq!( h.root_type().dataflow_signature().unwrap().io(), ( @@ -783,7 +783,7 @@ mod test { .map(ExtensionOp::name) .sorted() .collect_vec(), - ["NoBoundsChecking.read", "collections.list.get"] + ["NoBoundsCheck.read", "collections.list.get"] ); } } From 63d24b08e64c0a8c68291117c81ebb0b86c25afe Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 25 Mar 2025 22:48:10 +0000 Subject: [PATCH 071/123] RIP SignatureMismatch --- hugr-passes/src/lower_types.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index f84be86799..e8eafa6498 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -12,9 +12,7 @@ use hugr_core::ops::{ FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{ - CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, -}; +use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -99,13 +97,6 @@ impl TypeTransformer for LowerTypes { pub enum ChangeTypeError { #[error(transparent)] SignatureError(#[from] SignatureError), - #[error("Lowering op {op} with original signature {old:?}\nExpected signature: {expected:?}\nBut got: {actual:?}")] - SignatureMismatch { - op: OpType, - old: Option, - expected: Option, - actual: Option, - }, #[error(transparent)] ValidationError(#[from] ValidatePassError), } From ba63fbef399f40420a9b5787d8fef3cb30958ffd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 09:08:30 +0000 Subject: [PATCH 072/123] Reinstate separate lower_consts/lower_consts_parametric --- hugr-passes/src/lower_types.rs | 76 ++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index e8eafa6498..8c10a495cc 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -114,37 +114,22 @@ impl LowerTypes { /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::lower_parametric_type] where the `src`s overlap. Thus, this /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as - /// substitution (parametric polymorphism) happening later will not respect the lowering(s). - /// - /// [Const]s of the specified type are transformed using the provided `const_fn` callback, - /// which is passed the value of the constant. - pub fn lower_type( - &mut self, - src: CustomType, - dest: Type, - const_fn: Box Value>, - ) { + /// substitution (parametric polymorphism) happening later will not respect this lowering. + pub fn lower_type(&mut self, src: CustomType, dest: Type) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't // check that for parametric types, we'll be consistent and not check here either. - self.type_map.insert(src.clone(), dest); - self.consts.insert(src, Arc::from(const_fn)); + self.type_map.insert(src, dest); } /// Configures this instance to change occurrences of a parametrized type `src` - /// via a callback `dest_fn` that builds the replacement type given the [TypeArg]s. + /// via a callback that builds the replacement type given the [TypeArg]s. /// Note that the TypeArgs will already have been lowered (e.g. they may not /// fit the bounds of the original type). The callback may return `None` to indicate - /// no change (in which case the TypeArgs will be lowered recursively but the outer - /// type constructor kept the same). - /// - /// Constants whose types are instantiations of `src` will be converted by the `const_fn` - /// callback, given the value of the constant; the callback may return `None` to - /// leave the constant unchanged. + /// no change (in which case the supplied/lowered TypeArgs will be given to `src`). pub fn lower_parametric_type( &mut self, src: &TypeDef, dest_fn: Box Option>, - const_fn: Box Option>, ) { // No way to check that dest_fn never produces a linear type. // We could require copy/discard-generators if src is Copyable, or *might be* @@ -152,15 +137,14 @@ impl LowerTypes { // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying // overapproximation. Moreover, these depend upon the *return type* of the Fn. self.param_types.insert(src.into(), Arc::from(dest_fn)); - self.param_consts.insert(src.into(), Arc::from(const_fn)); } /// Configures this instance to change occurrences of `src` to `dest`. - /// Note that if `src` is an instance of a *parametrized* [OpDef], this should only - /// be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as substitution - /// (parametric polymorphism) happening later will not respect the lowering(s). - /// - /// This takes precedence over [Self::lower_parametric_op] where the `src`s overlap. + /// Note that if `src` is an instance of a *parametrized* [OpDef], this takes + /// precedence over [Self::lower_parametric_op] where the `src`s overlap. Thus, this + /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as + /// substitution (parametric polymorphism) happening later will not respect this + /// lowering. pub fn lower_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { self.op_map.insert(OpHashWrapper::from(src), dest); } @@ -179,6 +163,31 @@ impl LowerTypes { self.param_ops.insert(src.into(), Arc::from(dest_fn)); } + /// Configures this instance to change [Const]s of type `src_ty`, using + /// a callback that is passed the value of the constant (of that type). + /// + /// Note that if `src_ty` is an instance of a *parametrized* [TypeDef], + /// this takes precedence over [Self::lower_consts_parametric] where + /// the `src_ty`s overlap. + pub fn lower_consts( + &mut self, + src_ty: CustomType, + const_fn: Box Value>, + ) { + self.consts.insert(src_ty.clone(), Arc::from(const_fn)); + } + + /// Configures this instance to change [Const]s of all types that are instances + /// of a parametric typedef `src_ty`, using a callback that is passed the + /// value of the constant (the [OpaqueValue] contains the [TypeArg]s). + pub fn lower_consts_parametric( + &mut self, + src_ty: &TypeDef, + const_fn: Box Option>, + ) { + self.param_consts.insert(src_ty.into(), Arc::from(const_fn)); + } + /// Run the pass using specified configuration. pub fn run(&self, hugr: &mut H) -> Result { self.validation @@ -468,15 +477,10 @@ mod test { } let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = LowerTypes::default(); - lw.lower_type( - pv.instantiate([bool_t().into()]).unwrap(), - i64_t(), - Box::new(|_| panic!("There are no constants")), - ); + lw.lower_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); lw.lower_parametric_type( pv, Box::new(|args: &[TypeArg]| Some(array_type(64, just_elem_type(args).clone()))), - Box::new(|_| panic!("There are no constants?")), ); lw.lower_op( &read_op(ext, bool_t()), @@ -636,7 +640,6 @@ mod test { let ty = just_elem_type(args); (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) }), - Box::new(|_| None), // leave the List unchanged ); let backup = h.clone(); assert!(!lowerer.run(&mut h).unwrap()); @@ -650,7 +653,6 @@ mod test { let ty = just_elem_type(args); (usize_t() != *ty).then_some(array_type(10, ty.clone())) }), - Box::new(|_| None), // leave the List unchanged ); assert!(lowerer.run(&mut h).unwrap()); let sig = h.signature(h.root()).unwrap(); @@ -666,6 +668,9 @@ mod test { lowerer.lower_parametric_type( list_type_def(), Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), + ); + lowerer.lower_consts_parametric( + list_type_def(), Box::new(|opaq| { let lv = opaq .value() @@ -736,12 +741,11 @@ mod test { let mut h = dfb.finish_hugr_with_outputs([i, oi]).unwrap(); let mut lowerer = LowerTypes::default(); - lowerer.lower_type(i32_custom_t, qb_t(), Box::new(|_| panic!("No consts"))); + lowerer.lower_type(i32_custom_t, qb_t()); // Lower list> to list lowerer.lower_parametric_type( list_type_def(), Box::new(|args| option_contents(just_elem_type(args)).map(list_type)), - Box::new(|_| panic!("No consts")), ); // and read> to get - the latter has the expected option return type lowerer.lower_parametric_op( From 2c7bb802d467d0172b4254e7054962c927a561ea Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 09:27:13 +0000 Subject: [PATCH 073/123] Remove the boxes --- hugr-passes/src/lower_types.rs | 65 ++++++++++++++-------------------- 1 file changed, 26 insertions(+), 39 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index 8c10a495cc..b85d2e19c0 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -129,14 +129,14 @@ impl LowerTypes { pub fn lower_parametric_type( &mut self, src: &TypeDef, - dest_fn: Box Option>, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { // No way to check that dest_fn never produces a linear type. // We could require copy/discard-generators if src is Copyable, or *might be* // (depending on arguments - i.e. if src's TypeDefBound is anything other than // `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying // overapproximation. Moreover, these depend upon the *return type* of the Fn. - self.param_types.insert(src.into(), Arc::from(dest_fn)); + self.param_types.insert(src.into(), Arc::new(dest_fn)); } /// Configures this instance to change occurrences of `src` to `dest`. @@ -158,9 +158,9 @@ impl LowerTypes { pub fn lower_parametric_op( &mut self, src: &OpDef, - dest_fn: Box Option>, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { - self.param_ops.insert(src.into(), Arc::from(dest_fn)); + self.param_ops.insert(src.into(), Arc::new(dest_fn)); } /// Configures this instance to change [Const]s of type `src_ty`, using @@ -172,9 +172,9 @@ impl LowerTypes { pub fn lower_consts( &mut self, src_ty: CustomType, - const_fn: Box Value>, + const_fn: impl Fn(&OpaqueValue) -> Value + 'static, ) { - self.consts.insert(src_ty.clone(), Arc::from(const_fn)); + self.consts.insert(src_ty.clone(), Arc::new(const_fn)); } /// Configures this instance to change [Const]s of all types that are instances @@ -183,9 +183,9 @@ impl LowerTypes { pub fn lower_consts_parametric( &mut self, src_ty: &TypeDef, - const_fn: Box Option>, + const_fn: impl Fn(&OpaqueValue) -> Option + 'static, ) { - self.param_consts.insert(src_ty.into(), Arc::from(const_fn)); + self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } /// Run the pass using specified configuration. @@ -634,26 +634,20 @@ mod test { // 1. Lower List to Array<10, T> UNLESS T is usize_t() or bool_t - this should have no effect let mut lowerer = LowerTypes::default(); - lowerer.lower_parametric_type( - list_type_def(), - Box::new(|args| { - let ty = just_elem_type(args); - (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) - }), - ); + lowerer.lower_parametric_type(list_type_def(), |args| { + let ty = just_elem_type(args); + (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) + }); let backup = h.clone(); assert!(!lowerer.run(&mut h).unwrap()); assert_eq!(h, backup); //2. Lower List to Array<10, T> UNLESS T is usize_t() - this leaves the Const unchanged let mut lowerer = LowerTypes::default(); - lowerer.lower_parametric_type( - list_type_def(), - Box::new(|args| { - let ty = just_elem_type(args); - (usize_t() != *ty).then_some(array_type(10, ty.clone())) - }), - ); + lowerer.lower_parametric_type(list_type_def(), |args| { + let ty = just_elem_type(args); + (usize_t() != *ty).then_some(array_type(10, ty.clone())) + }); assert!(lowerer.run(&mut h).unwrap()); let sig = h.signature(h.root()).unwrap(); assert_eq!( @@ -669,19 +663,13 @@ mod test { list_type_def(), Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), ); - lowerer.lower_consts_parametric( - list_type_def(), - Box::new(|opaq| { - let lv = opaq - .value() - .downcast_ref::() - .expect("Only one constant in test"); - Some( - ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()) - .into(), - ) - }), - ); + lowerer.lower_consts_parametric(list_type_def(), |opaq| { + let lv = opaq + .value() + .downcast_ref::() + .expect("Only one constant in test"); + Some(ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into()) + }); lowerer.run(&mut h).unwrap(); assert_eq!( @@ -743,10 +731,9 @@ mod test { let mut lowerer = LowerTypes::default(); lowerer.lower_type(i32_custom_t, qb_t()); // Lower list> to list - lowerer.lower_parametric_type( - list_type_def(), - Box::new(|args| option_contents(just_elem_type(args)).map(list_type)), - ); + lowerer.lower_parametric_type(list_type_def(), |args| { + option_contents(just_elem_type(args)).map(list_type) + }); // and read> to get - the latter has the expected option return type lowerer.lower_parametric_op( e.get_op(READ).unwrap().as_ref(), From 08d3227758a4a964c2688e4996b71c033471326e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:11:36 +0000 Subject: [PATCH 074/123] drop comment re. validation_level --- hugr-passes/src/lower_types.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/lower_types.rs index b85d2e19c0..e132737bba 100644 --- a/hugr-passes/src/lower_types.rs +++ b/hugr-passes/src/lower_types.rs @@ -103,8 +103,6 @@ pub enum ChangeTypeError { impl LowerTypes { /// Sets the validation level used before and after the pass is run. - // Note the self -> Self style is consistent with other passes, but not the other methods here. - // TODO change the others? But we are planning to drop validation_level in https://github.com/CQCL/hugr/pull/1895 pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; self From dedc3d5b3c3427c29eeb44284afdc54faba6dc5a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:40:37 +0000 Subject: [PATCH 075/123] rename module lower_types -> replace_types --- hugr-passes/src/lib.rs | 4 ++-- hugr-passes/src/{lower_types.rs => replace_types.rs} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename hugr-passes/src/{lower_types.rs => replace_types.rs} (100%) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index e7b643bcc9..0b776ff232 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -28,8 +28,8 @@ pub use monomorphize::remove_polyfuncs; #[allow(deprecated)] pub use monomorphize::monomorphize; pub use monomorphize::{MonomorphizeError, MonomorphizePass}; -pub mod lower_types; -pub use lower_types::LowerTypes; +pub mod replace_types; +pub use replace_types::LowerTypes; pub mod nest_cfgs; pub mod non_local; pub mod validation; diff --git a/hugr-passes/src/lower_types.rs b/hugr-passes/src/replace_types.rs similarity index 100% rename from hugr-passes/src/lower_types.rs rename to hugr-passes/src/replace_types.rs From 99934673cd44406b47c6cbecf80bc3b67e975f07 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:42:50 +0000 Subject: [PATCH 076/123] rename LowerTypes=>ReplaceTypes, ChangeTypeError => ReplaceTypesError --- hugr-passes/src/lib.rs | 2 +- hugr-passes/src/replace_types.rs | 34 ++++++++++++++++---------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 0b776ff232..445f6a3841 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -29,7 +29,7 @@ pub use monomorphize::remove_polyfuncs; pub use monomorphize::monomorphize; pub use monomorphize::{MonomorphizeError, MonomorphizePass}; pub mod replace_types; -pub use replace_types::LowerTypes; +pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; pub mod validation; diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e132737bba..f6c8122f19 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -56,7 +56,7 @@ impl OpReplacement { } #[derive(Clone, Default)] -pub struct LowerTypes { +pub struct ReplaceTypes { /// Handles simple cases like T1 -> T2. /// If T1 is Copyable and T2 Linear, then error will be raised if we find e.g. /// ArrayOfCopyables(T1). This would require an additional entry for that. @@ -72,8 +72,8 @@ pub struct LowerTypes { validation: ValidationLevel, } -impl TypeTransformer for LowerTypes { - type Err = ChangeTypeError; +impl TypeTransformer for ReplaceTypes { + type Err = ReplaceTypesError; fn apply_custom(&self, ct: &CustomType) -> Result, Self::Err> { Ok(if let Some(res) = self.type_map.get(ct) { @@ -94,14 +94,14 @@ impl TypeTransformer for LowerTypes { #[derive(Debug, Error, PartialEq)] #[non_exhaustive] -pub enum ChangeTypeError { +pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] ValidationError(#[from] ValidatePassError), } -impl LowerTypes { +impl ReplaceTypes { /// Sets the validation level used before and after the pass is run. pub fn validation_level(mut self, level: ValidationLevel) -> Self { self.validation = level; @@ -187,12 +187,12 @@ impl LowerTypes { } /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { + pub fn run(&self, hugr: &mut H) -> Result { self.validation .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) } - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { changed |= self.change_node(hugr, n)?; @@ -200,7 +200,7 @@ impl LowerTypes { Ok(changed) } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { + fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) | OpType::FuncDecl(FuncDecl { signature, .. }) => signature.body_mut().transform(self), @@ -224,7 +224,7 @@ impl LowerTypes { if change { let new_inst = func_sig .instantiate(type_args) - .map_err(ChangeTypeError::SignatureError)?; + .map_err(ReplaceTypesError::SignatureError)?; *instantiation = new_inst; } Ok(change) @@ -287,7 +287,7 @@ impl LowerTypes { } } - fn change_value(&self, value: &mut Value) -> Result { + fn change_value(&self, value: &mut Value) -> Result { match value { Value::Sum(Sum { values, sum_type, .. @@ -389,7 +389,7 @@ mod test { use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use super::{LowerTypes, OpReplacement}; + use super::{ReplaceTypes, OpReplacement}; const PACKED_VEC: &str = "PackedVec"; const READ: &str = "read"; @@ -449,7 +449,7 @@ mod test { ) } - fn lowerer(ext: &Arc) -> LowerTypes { + fn lowerer(ext: &Arc) -> ReplaceTypes { fn lowered_read(args: &[TypeArg]) -> Option { let ty = just_elem_type(args); let mut dfb = DFGBuilder::new(inout_sig( @@ -474,7 +474,7 @@ mod test { ))) } let pv = ext.get_type(PACKED_VEC).unwrap(); - let mut lw = LowerTypes::default(); + let mut lw = ReplaceTypes::default(); lw.lower_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); lw.lower_parametric_type( pv, @@ -631,7 +631,7 @@ mod test { let mut h = tl.finish_hugr().unwrap(); // 1. Lower List to Array<10, T> UNLESS T is usize_t() or bool_t - this should have no effect - let mut lowerer = LowerTypes::default(); + let mut lowerer = ReplaceTypes::default(); lowerer.lower_parametric_type(list_type_def(), |args| { let ty = just_elem_type(args); (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) @@ -641,7 +641,7 @@ mod test { assert_eq!(h, backup); //2. Lower List to Array<10, T> UNLESS T is usize_t() - this leaves the Const unchanged - let mut lowerer = LowerTypes::default(); + let mut lowerer = ReplaceTypes::default(); lowerer.lower_parametric_type(list_type_def(), |args| { let ty = just_elem_type(args); (usize_t() != *ty).then_some(array_type(10, ty.clone())) @@ -656,7 +656,7 @@ mod test { // 3. Lower all List to Array<4,T> so we can use List's handy CustomConst let mut h = backup; - let mut lowerer = LowerTypes::default(); + let mut lowerer = ReplaceTypes::default(); lowerer.lower_parametric_type( list_type_def(), Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), @@ -726,7 +726,7 @@ mod test { .outputs_arr(); let mut h = dfb.finish_hugr_with_outputs([i, oi]).unwrap(); - let mut lowerer = LowerTypes::default(); + let mut lowerer = ReplaceTypes::default(); lowerer.lower_type(i32_custom_t, qb_t()); // Lower list> to list lowerer.lower_parametric_type(list_type_def(), |args| { From cc81d874557307c454c7c31c77fe0cc803506b07 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:44:59 +0000 Subject: [PATCH 077/123] rename methods, parametric => parametrized --- hugr-passes/src/replace_types.rs | 40 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index f6c8122f19..5c962d9073 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -110,12 +110,12 @@ impl ReplaceTypes { /// Configures this instance to change occurrences of type `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes - /// precedence over [Self::lower_parametric_type] where the `src`s overlap. Thus, this + /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as /// substitution (parametric polymorphism) happening later will not respect this lowering. - pub fn lower_type(&mut self, src: CustomType, dest: Type) { + pub fn replace_type(&mut self, src: CustomType, dest: Type) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't - // check that for parametric types, we'll be consistent and not check here either. + // check that for parametrized types, we'll be consistent and not check here either. self.type_map.insert(src, dest); } @@ -124,7 +124,7 @@ impl ReplaceTypes { /// Note that the TypeArgs will already have been lowered (e.g. they may not /// fit the bounds of the original type). The callback may return `None` to indicate /// no change (in which case the supplied/lowered TypeArgs will be given to `src`). - pub fn lower_parametric_type( + pub fn replace_parametrized_type( &mut self, src: &TypeDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, @@ -143,7 +143,7 @@ impl ReplaceTypes { /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as /// substitution (parametric polymorphism) happening later will not respect this /// lowering. - pub fn lower_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { + pub fn replace_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { self.op_map.insert(OpHashWrapper::from(src), dest); } @@ -153,7 +153,7 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. - pub fn lower_parametric_op( + pub fn replace_parametrized_op( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, @@ -167,7 +167,7 @@ impl ReplaceTypes { /// Note that if `src_ty` is an instance of a *parametrized* [TypeDef], /// this takes precedence over [Self::lower_consts_parametric] where /// the `src_ty`s overlap. - pub fn lower_consts( + pub fn replace_consts( &mut self, src_ty: CustomType, const_fn: impl Fn(&OpaqueValue) -> Value + 'static, @@ -176,9 +176,9 @@ impl ReplaceTypes { } /// Configures this instance to change [Const]s of all types that are instances - /// of a parametric typedef `src_ty`, using a callback that is passed the + /// of a parametrized typedef `src_ty`, using a callback that is passed the /// value of the constant (the [OpaqueValue] contains the [TypeArg]s). - pub fn lower_consts_parametric( + pub fn replace_consts_parametrized( &mut self, src_ty: &TypeDef, const_fn: impl Fn(&OpaqueValue) -> Option + 'static, @@ -475,12 +475,12 @@ mod test { } let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); - lw.lower_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); - lw.lower_parametric_type( + lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); + lw.replace_parametrized_type( pv, Box::new(|args: &[TypeArg]| Some(array_type(64, just_elem_type(args).clone()))), ); - lw.lower_op( + lw.replace_op( &read_op(ext, bool_t()), OpReplacement::SingleOp( ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), []) @@ -488,7 +488,7 @@ mod test { .into(), ), ); - lw.lower_parametric_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); lw } @@ -632,7 +632,7 @@ mod test { // 1. Lower List to Array<10, T> UNLESS T is usize_t() or bool_t - this should have no effect let mut lowerer = ReplaceTypes::default(); - lowerer.lower_parametric_type(list_type_def(), |args| { + lowerer.replace_parametrized_type(list_type_def(), |args| { let ty = just_elem_type(args); (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) }); @@ -642,7 +642,7 @@ mod test { //2. Lower List to Array<10, T> UNLESS T is usize_t() - this leaves the Const unchanged let mut lowerer = ReplaceTypes::default(); - lowerer.lower_parametric_type(list_type_def(), |args| { + lowerer.replace_parametrized_type(list_type_def(), |args| { let ty = just_elem_type(args); (usize_t() != *ty).then_some(array_type(10, ty.clone())) }); @@ -657,11 +657,11 @@ mod test { // 3. Lower all List to Array<4,T> so we can use List's handy CustomConst let mut h = backup; let mut lowerer = ReplaceTypes::default(); - lowerer.lower_parametric_type( + lowerer.replace_parametrized_type( list_type_def(), Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), ); - lowerer.lower_consts_parametric(list_type_def(), |opaq| { + lowerer.replace_consts_parametrized(list_type_def(), |opaq| { let lv = opaq .value() .downcast_ref::() @@ -727,13 +727,13 @@ mod test { let mut h = dfb.finish_hugr_with_outputs([i, oi]).unwrap(); let mut lowerer = ReplaceTypes::default(); - lowerer.lower_type(i32_custom_t, qb_t()); + lowerer.replace_type(i32_custom_t, qb_t()); // Lower list> to list - lowerer.lower_parametric_type(list_type_def(), |args| { + lowerer.replace_parametrized_type(list_type_def(), |args| { option_contents(just_elem_type(args)).map(list_type) }); // and read> to get - the latter has the expected option return type - lowerer.lower_parametric_op( + lowerer.replace_parametrized_op( e.get_op(READ).unwrap().as_ref(), Box::new(|args: &[TypeArg]| { option_contents(just_elem_type(args)).map(|elem| { From 74c6492bafd10a565d52ad2e5ba872b3e167ea09 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:49:06 +0000 Subject: [PATCH 078/123] doc notes --- hugr-passes/src/replace_types.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 5c962d9073..6d5a677ab5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1,4 +1,9 @@ #![allow(clippy::type_complexity)] +//! Replace types with other types across the Hugr. +//! +//! Parametrized types and ops will be reparametrized taking into account the replacements, +//! but any ops taking/returning the replaced types *not* as a result of parametrization, +//! will also need to be replaced - see [ReplaceTypes::replace_op]. (Similarly [Const]s.) use std::collections::HashMap; use std::sync::Arc; @@ -113,6 +118,10 @@ impl ReplaceTypes { /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as /// substitution (parametric polymorphism) happening later will not respect this lowering. + /// + /// If there are any [LoadConstant]s of this type, callers should also call [Self::replace_consts] + /// (or [Self::replace_consts_parametrized]) as the load-constants will be reparametrized + /// (and this will break the edge from const to loadconstant). pub fn replace_type(&mut self, src: CustomType, dest: Type) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't // check that for parametrized types, we'll be consistent and not check here either. @@ -124,6 +133,11 @@ impl ReplaceTypes { /// Note that the TypeArgs will already have been lowered (e.g. they may not /// fit the bounds of the original type). The callback may return `None` to indicate /// no change (in which case the supplied/lowered TypeArgs will be given to `src`). + /// + /// If there are any [LoadConstant]s of any of these types, callers should also call + /// [Self::replace_consts_parametrized] (or [Self::replace_consts]) as the + /// load-constants will be reparametrized (and this will break the edge from const to + /// loadconstant). pub fn replace_parametrized_type( &mut self, src: &TypeDef, @@ -389,7 +403,7 @@ mod test { use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use super::{ReplaceTypes, OpReplacement}; + use super::{OpReplacement, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; const READ: &str = "read"; From 18316de56cd621f3f39100ee176c20db9cecae7e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:55:44 +0000 Subject: [PATCH 079/123] comment/doc updates --- hugr-passes/src/replace_types.rs | 38 +++++++++++++++----------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 6d5a677ab5..1ee68e84fe 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -22,14 +22,14 @@ use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; -/// A thing to which an Op can be lowered, i.e. with which a node can be replaced. +/// A thing with which an Op (i.e. node) can be replaced #[derive(Clone, Debug, PartialEq)] pub enum OpReplacement { - /// Keep the same node (inputs/outputs, modulo lowering of types therein), change only the op + /// Keep the same node (inputs/outputs, modulo replacement of types therein), change only the op SingleOp(OpType), /// Defines a sub-Hugr to splice in place of the op - a [CFG](OpType::CFG), /// [Conditional](OpType::Conditional) or [DFG](OpType::DFG), which must have - /// the same (lowered) inputs and outputs as the original op. + /// the same inputs and outputs as the original op, modulo replacement. // Not a FuncDefn, nor Case/DataflowBlock /// Note this will be of limited use before [monomorphization](super::monomorphize()) because /// the sub-Hugr will not be able to use type variables present in the op. @@ -38,7 +38,7 @@ pub enum OpReplacement { CompoundOp(Box), // TODO allow also Call to a Node in the existing Hugr // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before lowering, then remove unused ones afterwards.) + // So client should add the functions before replacement, then remove unused ones afterwards.) } impl OpReplacement { @@ -62,15 +62,9 @@ impl OpReplacement { #[derive(Clone, Default)] pub struct ReplaceTypes { - /// Handles simple cases like T1 -> T2. - /// If T1 is Copyable and T2 Linear, then error will be raised if we find e.g. - /// ArrayOfCopyables(T1). This would require an additional entry for that. type_map: HashMap, - /// Parametric types are handled by a function which receives the lowered typeargs. param_types: HashMap Option>>, - // Handles simple cases Op1 -> Op2. op_map: HashMap, - // Called after lowering typeargs; return None to use original OpDef param_ops: HashMap Option>>, consts: HashMap Value>>, param_consts: HashMap Option>>, @@ -113,15 +107,19 @@ impl ReplaceTypes { self } - /// Configures this instance to change occurrences of type `src` to `dest`. + /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as - /// substitution (parametric polymorphism) happening later will not respect this lowering. + /// substitution (parametric polymorphism) happening later will not respect this replacement. /// /// If there are any [LoadConstant]s of this type, callers should also call [Self::replace_consts] /// (or [Self::replace_consts_parametrized]) as the load-constants will be reparametrized /// (and this will break the edge from const to loadconstant). + /// + /// Note that if `src` is Copyable and `dest` is Linear, then (besides linearity violations) + /// [SignatureError] will be raised if this leads to an impossible type e.g. ArrayOfCopyables(src). + /// (This can be overridden by an additional [Self::replace_type].) pub fn replace_type(&mut self, src: CustomType, dest: Type) { // We could check that 'dest' is copyable or 'src' is linear, but since we can't // check that for parametrized types, we'll be consistent and not check here either. @@ -130,9 +128,9 @@ impl ReplaceTypes { /// Configures this instance to change occurrences of a parametrized type `src` /// via a callback that builds the replacement type given the [TypeArg]s. - /// Note that the TypeArgs will already have been lowered (e.g. they may not + /// Note that the TypeArgs will already have been updated (e.g. they may not /// fit the bounds of the original type). The callback may return `None` to indicate - /// no change (in which case the supplied/lowered TypeArgs will be given to `src`). + /// no change (in which case the supplied TypeArgs will be given to `src`). /// /// If there are any [LoadConstant]s of any of these types, callers should also call /// [Self::replace_consts_parametrized] (or [Self::replace_consts]) as the @@ -153,17 +151,17 @@ impl ReplaceTypes { /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [OpDef], this takes - /// precedence over [Self::lower_parametric_op] where the `src`s overlap. Thus, this - /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as - /// substitution (parametric polymorphism) happening later will not respect this - /// lowering. + /// precedence over [Self::replace_parametrized_op] where the `src`s overlap. Thus, + /// this should only be used on already-*[monomorphize](super::monomorphize())d* + /// Hugrs, as substitution (parametric polymorphism) happening later will not respect + /// this replacement. pub fn replace_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { self.op_map.insert(OpHashWrapper::from(src), dest); } /// Configures this instance to change occurrences of a parametrized op `src` /// via a callback that builds the replacement type given the [TypeArg]s. - /// Note that the TypeArgs will already have been lowered (e.g. they may not + /// Note that the TypeArgs will already have been updated (e.g. they may not /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. @@ -179,7 +177,7 @@ impl ReplaceTypes { /// a callback that is passed the value of the constant (of that type). /// /// Note that if `src_ty` is an instance of a *parametrized* [TypeDef], - /// this takes precedence over [Self::lower_consts_parametric] where + /// this takes precedence over [Self::replace_consts_parametrized] where /// the `src_ty`s overlap. pub fn replace_consts( &mut self, From 520e4146777a931a1f80e898dbe47d0d4dab43cb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 10:57:24 +0000 Subject: [PATCH 080/123] fmt --- hugr-passes/src/replace_types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 1ee68e84fe..325f87776a 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -116,7 +116,7 @@ impl ReplaceTypes { /// If there are any [LoadConstant]s of this type, callers should also call [Self::replace_consts] /// (or [Self::replace_consts_parametrized]) as the load-constants will be reparametrized /// (and this will break the edge from const to loadconstant). - /// + /// /// Note that if `src` is Copyable and `dest` is Linear, then (besides linearity violations) /// [SignatureError] will be raised if this leads to an impossible type e.g. ArrayOfCopyables(src). /// (This can be overridden by an additional [Self::replace_type].) From 0814bf28931eed0e798fce229d416731942b2e08 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 12:23:09 +0000 Subject: [PATCH 081/123] callbacks for Const take &ReplaceTypes --- hugr-passes/src/replace_types.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 325f87776a..ce43f06794 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -66,8 +66,9 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, op_map: HashMap, param_ops: HashMap Option>>, - consts: HashMap Value>>, - param_consts: HashMap Option>>, + consts: HashMap Value>>, + param_consts: + HashMap Option>>, validation: ValidationLevel, } @@ -182,7 +183,7 @@ impl ReplaceTypes { pub fn replace_consts( &mut self, src_ty: CustomType, - const_fn: impl Fn(&OpaqueValue) -> Value + 'static, + const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Value + 'static, ) { self.consts.insert(src_ty.clone(), Arc::new(const_fn)); } @@ -193,7 +194,7 @@ impl ReplaceTypes { pub fn replace_consts_parametrized( &mut self, src_ty: &TypeDef, - const_fn: impl Fn(&OpaqueValue) -> Option + 'static, + const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Option + 'static, ) { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } @@ -313,11 +314,14 @@ impl ReplaceTypes { } Value::Extension { e } => Ok('changed: { if let TypeEnum::Extension(exty) = e.get_type().as_type_enum() { - if let Some(new_const) = - self.consts.get(exty).map(|const_fn| const_fn(e)).or(self + if let Some(new_const) = self + .consts + .get(exty) + .map(|const_fn| const_fn(e, self)) + .or(self .param_consts .get(&exty.into()) - .and_then(|const_fn| const_fn(e))) + .and_then(|const_fn| const_fn(e, self))) { *value = new_const; break 'changed true; @@ -673,7 +677,7 @@ mod test { list_type_def(), Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), ); - lowerer.replace_consts_parametrized(list_type_def(), |opaq| { + lowerer.replace_consts_parametrized(list_type_def(), |opaq, _| { let lv = opaq .value() .downcast_ref::() From 0a82fef63e5ec00e1bf0c2e66b93d3ab26bce0ed Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 12:28:08 +0000 Subject: [PATCH 082/123] Replace that break 'changed block with a match, + map_or_else to avoid param cb --- hugr-passes/src/replace_types.rs | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ce43f06794..dba5e4f91e 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -312,22 +312,24 @@ impl ReplaceTypes { any_change |= sum_type.transform(self)?; Ok(any_change) } - Value::Extension { e } => Ok('changed: { - if let TypeEnum::Extension(exty) = e.get_type().as_type_enum() { - if let Some(new_const) = self - .consts - .get(exty) - .map(|const_fn| const_fn(e, self)) - .or(self - .param_consts - .get(&exty.into()) - .and_then(|const_fn| const_fn(e, self))) - { - *value = new_const; - break 'changed true; - } + Value::Extension { e } => Ok({ + let new_const = match e.get_type().as_type_enum() { + TypeEnum::Extension(exty) => self.consts.get(exty).map_or_else( + || { + self.param_consts + .get(&exty.into()) + .and_then(|const_fn| const_fn(e, self)) + }, + |const_fn| Some(const_fn(e, self)), + ), + _ => None, + }; + if let Some(new_const) = new_const { + *value = new_const; + true + } else { + false } - false }), Value::Function { hugr } => self.run_no_validate(&mut **hugr), } From 94c0ff3e883330980badb6a248936d1fce4ecb2d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 12:31:09 +0000 Subject: [PATCH 083/123] Add Type::as_extension --- hugr-core/src/types.rs | 14 ++++++++++++++ hugr-passes/src/replace_types.rs | 19 +++++++------------ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index bcf2cda07e..fd6e3f2f76 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -481,6 +481,14 @@ impl TypeBase { } } + /// Returns the inner [Custom] if the type is from an extension. + pub fn as_extension(&self) -> Option<&CustomType> { + match &self.0 { + TypeEnum::Extension(ct) => Some(ct), + _ => None + } + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { @@ -827,6 +835,12 @@ pub(crate) mod test { assert!(t.as_sum().is_some()); } + #[test] + fn as_extension() { + assert_eq!(Type::new_extension(usize_t().as_extension().unwrap().clone), usize_t()); + assert_eq!(Type::new_unit_sum(0).as_extension(), None); + } + #[test] fn sum_variants() { { diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index dba5e4f91e..d3dbf82913 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -17,7 +17,7 @@ use hugr_core::ops::{ FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; +use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeTransformer}; use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -313,17 +313,16 @@ impl ReplaceTypes { Ok(any_change) } Value::Extension { e } => Ok({ - let new_const = match e.get_type().as_type_enum() { - TypeEnum::Extension(exty) => self.consts.get(exty).map_or_else( + let new_const = e.get_type().as_extension().and_then(|exty| { + self.consts.get(exty).map_or_else( || { self.param_consts .get(&exty.into()) .and_then(|const_fn| const_fn(e, self)) }, |const_fn| Some(const_fn(e, self)), - ), - _ => None, - }; + ) + }); if let Some(new_const) = new_const { *value = new_const; true @@ -401,9 +400,7 @@ mod test { list_type, list_type_def, ListOp, ListValue, }; - use hugr_core::types::{ - PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow, - }; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; @@ -724,9 +721,7 @@ mod test { } let i32_t = || INT_TYPES[5].to_owned(); let opt_i32 = Type::from(option_type(i32_t())); - let TypeEnum::Extension(i32_custom_t) = i32_t().as_type_enum().clone() else { - panic!() - }; + let i32_custom_t = i32_t().as_extension().unwrap().clone(); let mut dfb = DFGBuilder::new(inout_sig( vec![list_type(i32_t()), list_type(opt_i32.clone())], vec![i32_t(), opt_i32.clone()], From e9ff81cbd2a13e0b5cbebbc15290fd60b5055f9f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 12:33:06 +0000 Subject: [PATCH 084/123] wip --- hugr-passes/src/replace_types.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index d3dbf82913..87f56cde86 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -300,7 +300,9 @@ impl ReplaceTypes { } } - fn change_value(&self, value: &mut Value) -> Result { + /// Modifies the specified Value in-place according to current configuration. + /// Returns whether the value has changed (conservative over-approximation). + pub fn change_value(&self, value: &mut Value) -> Result { match value { Value::Sum(Sum { values, sum_type, .. From cf6451f28aff0fc8639cae67d75dc2f1ff9f814c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 13:27:05 +0000 Subject: [PATCH 085/123] add list_const handler, use in test --- hugr-passes/src/replace_types.rs | 112 ++++++++++++++++++++++--------- 1 file changed, 81 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 87f56cde86..65d2f6eb34 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -185,7 +185,7 @@ impl ReplaceTypes { src_ty: CustomType, const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Value + 'static, ) { - self.consts.insert(src_ty.clone(), Arc::new(const_fn)); + self.consts.insert(src_ty, Arc::new(const_fn)); } /// Configures this instance to change [Const]s of all types that are instances @@ -337,6 +337,31 @@ impl ReplaceTypes { } } +pub mod handlers { + use hugr_core::ops::{constant::OpaqueValue, Value}; + use hugr_core::std_extensions::collections::list::ListValue; + use hugr_core::types::Transformable; + + use super::ReplaceTypes; + + pub fn list_const(val: &OpaqueValue, repl: &ReplaceTypes) -> Option { + let lv = val.value().downcast_ref::()?; + let mut vals: Vec = lv.get_contents().to_vec(); + let mut ch = false; + for v in vals.iter_mut() { + ch |= repl.change_value(v).ok()?; // Silently drop errors...? + } + // If none of the values has changed, assume the Type hasn't (Values have a single known type) + if !ch { + return None; + }; + + let mut elem_t = lv.get_element_type().clone(); + elem_t.transform(repl).ok()?; // Silently drop errors + Some(ListValue::new(elem_t, vals).into()) + } +} + #[derive(Clone, Hash, PartialEq, Eq)] struct OpHashWrapper { ext_name: ExtensionId, @@ -394,6 +419,7 @@ mod test { use hugr_core::extension::{TypeDefBound, Version}; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; + use hugr_core::std_extensions::arithmetic::int_types::ConstInt; use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, ArrayOp, ArrayOpDef, ArrayValue, @@ -406,7 +432,7 @@ mod test { use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use super::{OpReplacement, ReplaceTypes}; + use super::{handlers::list_const, OpReplacement, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; const READ: &str = "read"; @@ -645,44 +671,71 @@ mod test { .unwrap(), ); tl.set_outputs(pred, [bools]).unwrap(); - let mut h = tl.finish_hugr().unwrap(); + let backup = tl.finish_hugr().unwrap(); - // 1. Lower List to Array<10, T> UNLESS T is usize_t() or bool_t - this should have no effect let mut lowerer = ReplaceTypes::default(); + // Recursively descend into lists + lowerer.replace_consts_parametrized(list_type_def(), list_const); + + // 1. Lower List to Array<10, T> UNLESS T is usize_t() or i64_t lowerer.replace_parametrized_type(list_type_def(), |args| { let ty = just_elem_type(args); - (![usize_t(), bool_t()].contains(ty)).then_some(array_type(10, ty.clone())) + (![usize_t(), i64_t()].contains(ty)).then_some(array_type(10, ty.clone())) }); - let backup = h.clone(); - assert!(!lowerer.run(&mut h).unwrap()); - assert_eq!(h, backup); + { + let mut h = backup.clone(); + assert_eq!(lowerer.run(&mut h), Ok(true)); + let sig = h.signature(h.root()).unwrap(); + assert_eq!( + sig.input(), + &TypeRow::from(vec![list_type(usize_t()), array_type(10, bool_t())]) + ); + assert_eq!(sig.input(), sig.output()); + } - //2. Lower List to Array<10, T> UNLESS T is usize_t() - this leaves the Const unchanged - let mut lowerer = ReplaceTypes::default(); - lowerer.replace_parametrized_type(list_type_def(), |args| { - let ty = just_elem_type(args); - (usize_t() != *ty).then_some(array_type(10, ty.clone())) + // 2. Now we'll also change usize's to i64_t's + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + lowerer.replace_type(usize_custom_t.clone(), i64_t()); + lowerer.replace_consts(usize_custom_t, |opaq, _| { + ConstInt::new_u( + 6, + opaq.value().downcast_ref::().unwrap().value(), + ) + .unwrap() + .into() }); - assert!(lowerer.run(&mut h).unwrap()); - let sig = h.signature(h.root()).unwrap(); - assert_eq!( - sig.input(), - &TypeRow::from(vec![list_type(usize_t()), array_type(10, bool_t())]) - ); - assert_eq!(sig.input(), sig.output()); + { + let mut h = backup.clone(); + assert_eq!(lowerer.run(&mut h), Ok(true)); + let sig = h.signature(h.root()).unwrap(); + assert_eq!( + sig.input(), + &TypeRow::from(vec![list_type(i64_t()), array_type(10, bool_t())]) + ); + assert_eq!(sig.input(), sig.output()); + // This will have to update inside the Const + let cst = h + .nodes() + .filter_map(|n| h.get_optype(n).as_const()) + .exactly_one() + .ok() + .unwrap(); + assert_eq!(cst.get_type(), Type::new_sum(vec![list_type(i64_t()); 2])); + } - // 3. Lower all List to Array<4,T> so we can use List's handy CustomConst + // 3. Lower all List to Array<4,T> let mut h = backup; - let mut lowerer = ReplaceTypes::default(); lowerer.replace_parametrized_type( list_type_def(), Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), ); - lowerer.replace_consts_parametrized(list_type_def(), |opaq, _| { - let lv = opaq - .value() - .downcast_ref::() - .expect("Only one constant in test"); + lowerer.replace_consts_parametrized(list_type_def(), |opaq, repl| { + // First recursively transform the contents + let Some(Value::Extension { e: opaq }) = list_const(opaq, repl) else { + panic!("Expected list value to stay a list value"); + }; + let lv = opaq.value().downcast_ref::().unwrap(); + Some(ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into()) }); lowerer.run(&mut h).unwrap(); @@ -691,10 +744,7 @@ mod test { h.get_optype(pred.node()) .as_load_constant() .map(|lc| lc.constant_type()), - Some(&Type::new_sum(vec![ - Type::from(array_type(4, usize_t())); - 2 - ])) + Some(&Type::new_sum(vec![Type::from(array_type(4, i64_t())); 2])) ); } From 9a355d55516b95b8fed57ec8f7faec6d25432df2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 13:31:20 +0000 Subject: [PATCH 086/123] Clarify replace_consts_parametrized callback --- hugr-passes/src/replace_types.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 65d2f6eb34..7992fc462b 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -190,7 +190,8 @@ impl ReplaceTypes { /// Configures this instance to change [Const]s of all types that are instances /// of a parametrized typedef `src_ty`, using a callback that is passed the - /// value of the constant (the [OpaqueValue] contains the [TypeArg]s). + /// value of the constant (the [OpaqueValue] contains the [TypeArg]s). The + /// callback may return `None` to indicate no change to the constant. pub fn replace_consts_parametrized( &mut self, src_ty: &TypeDef, From ec2d0ba739c96092d61fa1551f98152c03238a7a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 13:50:54 +0000 Subject: [PATCH 087/123] const callbacks report errors via Result --- hugr-passes/src/replace_types.rs | 69 +++++++++++++++++++------------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 7992fc462b..7b16393cbb 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -17,7 +17,7 @@ use hugr_core::ops::{ FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeTransformer}; +use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; use hugr_core::{Hugr, Node}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -66,9 +66,14 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, op_map: HashMap, param_ops: HashMap Option>>, - consts: HashMap Value>>, - param_consts: - HashMap Option>>, + consts: HashMap< + CustomType, + Arc Result>, + >, + param_consts: HashMap< + ParametricType, + Arc Result, ReplaceTypesError>>, + >, validation: ValidationLevel, } @@ -183,7 +188,7 @@ impl ReplaceTypes { pub fn replace_consts( &mut self, src_ty: CustomType, - const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Value + 'static, + const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Result + 'static, ) { self.consts.insert(src_ty, Arc::new(const_fn)); } @@ -195,7 +200,8 @@ impl ReplaceTypes { pub fn replace_consts_parametrized( &mut self, src_ty: &TypeDef, - const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Option + 'static, + const_fn: impl Fn(&OpaqueValue, &ReplaceTypes) -> Result, ReplaceTypesError> + + 'static, ) { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } @@ -316,18 +322,18 @@ impl ReplaceTypes { Ok(any_change) } Value::Extension { e } => Ok({ - let new_const = e.get_type().as_extension().and_then(|exty| { - self.consts.get(exty).map_or_else( - || { - self.param_consts - .get(&exty.into()) - .and_then(|const_fn| const_fn(e, self)) - }, - |const_fn| Some(const_fn(e, self)), - ) - }); + let new_const = match e.get_type().as_type_enum() { + TypeEnum::Extension(exty) => match self.consts.get(exty) { + Some(const_fn) => Some(const_fn(e, self)), + None => self + .param_consts + .get(&exty.into()) + .and_then(|const_fn| const_fn(e, self).transpose()), + }, + _ => None, + }; if let Some(new_const) = new_const { - *value = new_const; + *value = new_const?; true } else { false @@ -343,23 +349,28 @@ pub mod handlers { use hugr_core::std_extensions::collections::list::ListValue; use hugr_core::types::Transformable; - use super::ReplaceTypes; + use super::{ReplaceTypes, ReplaceTypesError}; - pub fn list_const(val: &OpaqueValue, repl: &ReplaceTypes) -> Option { - let lv = val.value().downcast_ref::()?; + pub fn list_const( + val: &OpaqueValue, + repl: &ReplaceTypes, + ) -> Result, ReplaceTypesError> { + let Some(lv) = val.value().downcast_ref::() else { + return Ok(None); + }; let mut vals: Vec = lv.get_contents().to_vec(); let mut ch = false; for v in vals.iter_mut() { - ch |= repl.change_value(v).ok()?; // Silently drop errors...? + ch |= repl.change_value(v)?; } // If none of the values has changed, assume the Type hasn't (Values have a single known type) if !ch { - return None; + return Ok(None); }; let mut elem_t = lv.get_element_type().clone(); - elem_t.transform(repl).ok()?; // Silently drop errors - Some(ListValue::new(elem_t, vals).into()) + elem_t.transform(repl)?; // Silently drop errors + Ok(Some(ListValue::new(elem_t, vals).into())) } } @@ -698,12 +709,12 @@ mod test { let usize_custom_t = usize_t().as_extension().unwrap().clone(); lowerer.replace_type(usize_custom_t.clone(), i64_t()); lowerer.replace_consts(usize_custom_t, |opaq, _| { - ConstInt::new_u( + Ok(ConstInt::new_u( 6, opaq.value().downcast_ref::().unwrap().value(), ) .unwrap() - .into() + .into()) }); { let mut h = backup.clone(); @@ -732,12 +743,14 @@ mod test { ); lowerer.replace_consts_parametrized(list_type_def(), |opaq, repl| { // First recursively transform the contents - let Some(Value::Extension { e: opaq }) = list_const(opaq, repl) else { + let Some(Value::Extension { e: opaq }) = list_const(opaq, repl)? else { panic!("Expected list value to stay a list value"); }; let lv = opaq.value().downcast_ref::().unwrap(); - Some(ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into()) + Ok(Some( + ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into(), + )) }); lowerer.run(&mut h).unwrap(); From c9d7033491b730092a3a1bd9d29278c440044083 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 14:40:34 +0000 Subject: [PATCH 088/123] fmt+fix --- hugr-core/src/types.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index fd6e3f2f76..b09ae12863 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -481,11 +481,11 @@ impl TypeBase { } } - /// Returns the inner [Custom] if the type is from an extension. + /// Returns the inner [CustomType] if the type is from an extension. pub fn as_extension(&self) -> Option<&CustomType> { match &self.0 { TypeEnum::Extension(ct) => Some(ct), - _ => None + _ => None, } } @@ -837,7 +837,10 @@ pub(crate) mod test { #[test] fn as_extension() { - assert_eq!(Type::new_extension(usize_t().as_extension().unwrap().clone), usize_t()); + assert_eq!( + Type::new_extension(usize_t().as_extension().unwrap().clone()), + usize_t() + ); assert_eq!(Type::new_unit_sum(0).as_extension(), None); } From 812ac897bb79f8e63de7524a1a8f3a1f426662ae Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 15:05:27 +0000 Subject: [PATCH 089/123] comment tweaks --- hugr-passes/src/replace_types.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 7b16393cbb..f96f58f298 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -25,11 +25,11 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// A thing with which an Op (i.e. node) can be replaced #[derive(Clone, Debug, PartialEq)] pub enum OpReplacement { - /// Keep the same node (inputs/outputs, modulo replacement of types therein), change only the op + /// Keep the same node, change only the op (updating types of inputs/outputs) SingleOp(OpType), - /// Defines a sub-Hugr to splice in place of the op - a [CFG](OpType::CFG), - /// [Conditional](OpType::Conditional) or [DFG](OpType::DFG), which must have - /// the same inputs and outputs as the original op, modulo replacement. + /// Defines a sub-Hugr to splice in place of the op - a [CFG], [Conditional], [DFG] + /// or [TailLoop], which must have the same inputs and outputs as the original op, + /// modulo replacement. // Not a FuncDefn, nor Case/DataflowBlock /// Note this will be of limited use before [monomorphization](super::monomorphize()) because /// the sub-Hugr will not be able to use type variables present in the op. @@ -369,7 +369,7 @@ pub mod handlers { }; let mut elem_t = lv.get_element_type().clone(); - elem_t.transform(repl)?; // Silently drop errors + elem_t.transform(repl)?; Ok(Some(ListValue::new(elem_t, vals).into())) } } From e0937eeb2031977a23c09aca3f5aea2b60d9c28d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 15:58:28 +0000 Subject: [PATCH 090/123] docs --- hugr-passes/src/replace_types.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index f96f58f298..fa982115e0 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -60,6 +60,8 @@ impl OpReplacement { } } +/// A configuration of what types, ops, and constants should be replaced with what. +/// May be applied to a Hugr via [Self::run]. #[derive(Clone, Default)] pub struct ReplaceTypes { type_map: HashMap, @@ -97,6 +99,7 @@ impl TypeTransformer for ReplaceTypes { } } +/// An error produced by the [ReplaceTypes] pass #[derive(Debug, Error, PartialEq)] #[non_exhaustive] pub enum ReplaceTypesError { @@ -120,8 +123,8 @@ impl ReplaceTypes { /// substitution (parametric polymorphism) happening later will not respect this replacement. /// /// If there are any [LoadConstant]s of this type, callers should also call [Self::replace_consts] - /// (or [Self::replace_consts_parametrized]) as the load-constants will be reparametrized - /// (and this will break the edge from const to loadconstant). + /// (or [Self::replace_consts_parametrized]) as the [LoadConstant]s will be reparametrized + /// (and this will break the edge from [Const] to [LoadConstant]). /// /// Note that if `src` is Copyable and `dest` is Linear, then (besides linearity violations) /// [SignatureError] will be raised if this leads to an impossible type e.g. ArrayOfCopyables(src). @@ -140,8 +143,8 @@ impl ReplaceTypes { /// /// If there are any [LoadConstant]s of any of these types, callers should also call /// [Self::replace_consts_parametrized] (or [Self::replace_consts]) as the - /// load-constants will be reparametrized (and this will break the edge from const to - /// loadconstant). + /// [LoadConstant]s will be reparametrized (and this will break the edge from [Const] to + /// [LoadConstant]). pub fn replace_parametrized_type( &mut self, src: &TypeDef, From 36df03c903cbdb3d474a1c3116a5e509f24ac25d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 16:01:45 +0000 Subject: [PATCH 091/123] and some more - warn on missing_docs except ReplaceTypesError --- hugr-passes/src/replace_types.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index fa982115e0..fb72b86f82 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1,4 +1,5 @@ #![allow(clippy::type_complexity)] +#![warn(missing_docs)] //! Replace types with other types across the Hugr. //! //! Parametrized types and ops will be reparametrized taking into account the replacements, @@ -102,6 +103,7 @@ impl TypeTransformer for ReplaceTypes { /// An error produced by the [ReplaceTypes] pass #[derive(Debug, Error, PartialEq)] #[non_exhaustive] +#[allow(missing_docs)] pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), @@ -348,12 +350,15 @@ impl ReplaceTypes { } pub mod handlers { + //! Callbacks for use with [ReplaceTypes::replace_consts_parametrized] use hugr_core::ops::{constant::OpaqueValue, Value}; use hugr_core::std_extensions::collections::list::ListValue; use hugr_core::types::Transformable; use super::{ReplaceTypes, ReplaceTypesError}; + /// Handler for [ListValue] constants that recursively [ReplaceTypes::change_value]s + /// the elements of the list pub fn list_const( val: &OpaqueValue, repl: &ReplaceTypes, From 31ba4e0c857de904765f91cc4f1843e7ceefbf18 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 17:56:56 +0000 Subject: [PATCH 092/123] docs, make pub, more errors --- hugr-passes/src/replace_types.rs | 18 ++++--- hugr-passes/src/replace_types/linearize.rs | 61 ++++++++++++++++++---- 2 files changed, 62 insertions(+), 17 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 5f357a612b..e2f758182c 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -49,14 +49,17 @@ pub enum OpReplacement { } impl OpReplacement { - fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + /// Adds this instance to the specified [HugrMut] as a new node or subtree under a + /// given parent, returning the unique new child (of that parent) thus created + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { match self { OpReplacement::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), OpReplacement::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, } } - fn add( + /// Adds this instance to the specified [Dataflow] builder as a new node or subtree + pub fn add( self, dfb: &mut impl Dataflow, inputs: impl IntoIterator, @@ -200,8 +203,12 @@ impl ReplaceTypes { /// result of lonering a type that was either copied or discarded in the input Hugr. /// /// [Copyable]: hugr_core::types::TypeBound::Copyable - pub fn linearize(&mut self, src: Type, copy: OpReplacement, discard: OpReplacement) { - // We could raise an error if src's bound is Copyable? + pub fn linearize( + &mut self, + src: Type, + copy: OpReplacement, + discard: OpReplacement, + ) -> Result<(), LinearizeError> { self.linearize.register(src, copy, discard) } @@ -214,7 +221,7 @@ impl ReplaceTypes { /// /// (That is, this is the equivalent of [Self::linearize] but for parametric types.) /// - /// The [Linearizer] is passed so that the callbacks can use this to generate + /// The [Linearizer] is passed so that the callbacks can use it to generate /// `copy/`discard` ops for other types (e.g. the elements of a collection), /// as part of an [OpReplacement::CompoundOp]. pub fn linearize_parametric( @@ -223,7 +230,6 @@ impl ReplaceTypes { copy_fn: Box Result>, discard_fn: Box Result>, ) { - // We could raise an error if src's TypeDefBound is explicit Copyable ? self.linearize.register_parametric(src, copy_fn, discard_fn) } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index e2bdb89b83..d1e0da5956 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -8,6 +8,8 @@ use itertools::Itertools; use super::{OpReplacement, ParametricType}; +/// Configuration for inserting copy and discard operations for linear types +/// outports of which are sources of multiple or 0 edges. #[derive(Clone, Default)] pub struct Linearizer { // Keyed by lowered type, as only needed when there is an op outputting such @@ -28,6 +30,7 @@ pub struct Linearizer { } #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[allow(missing_docs)] pub enum LinearizeError { #[error("Need copy op for {_0}")] NeedCopy(Type), @@ -43,23 +46,47 @@ pub enum LinearizeError { /// SignatureError's can happen when converting nested types e.g. Sums #[error(transparent)] SignatureError(#[from] SignatureError), - /// Type variables, Row variables, and Aliases are not supported; - /// nor Function types, as these are always Copyable. + /// We cannot linearize (insert copy and discard functions) for + /// [Variable](TypeEnum::Variable)s, [Row variables](TypeEnum::RowVar), + /// or [Alias](TypeEnum::Alias)es. #[error("Cannot linearize type {_0}")] UnsupportedType(Type), + /// Neither does linearization make sense for copyable types + #[error("Type {_0} is copyable")] + CopyableType(Type), } impl Linearizer { - pub fn register(&mut self, typ: Type, copy: OpReplacement, discard: OpReplacement) { - self.copy_discard.insert(typ, (copy, discard)); + /// Registers a type for linearization by providing copy and discard operations. + /// + /// # Errors + /// + /// [LinearizeError::CopyableType] if `typ` is copyable + pub fn register( + &mut self, + typ: Type, + copy: OpReplacement, + discard: OpReplacement, + ) -> Result<(), LinearizeError> { + if typ.copyable() { + Err(LinearizeError::CopyableType(typ)) + } else { + self.copy_discard.insert(typ, (copy, discard)); + Ok(()) + } } + /// Registers that instances of a parametrized [TypeDef] should be linearized + /// by providing functions that generate copy and discard functions given the [TypeArg]s. pub fn register_parametric( &mut self, src: &TypeDef, copy_fn: Box Result>, discard_fn: Box Result>, ) { + // We could look for `src`s TypeDefBound being explicit Copyable, otherwise + // it depends on the arguments. Since there is no method to get the TypeDefBound + // from a TypeDef, leaving this for now. self.copy_discard_parametric .insert(src.into(), (Arc::from(copy_fn), Arc::from(discard_fn))); } @@ -69,7 +96,11 @@ impl Linearizer { /// /// # Errors /// - /// If needed copy or discard ops cannot be found; + /// Most variants of [LinearizeError] can be raised, specifically including + /// [LinearizeError::CopyableType] if the type is [Copyable], in which case the Hugr + /// will be unchanged. + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable pub fn insert_copy_discard( &self, hugr: &mut impl HugrMut, @@ -120,7 +151,9 @@ impl Linearizer { Ok(()) } - fn copy_op(&self, typ: &Type) -> Result { + /// Gets an [OpReplacement] for copying a value of type `typ`, i.e. + /// a recipe for a node with one input of that type and two outputs. + pub fn copy_op(&self, typ: &Type) -> Result { if let Some((copy, _)) = self.copy_discard.get(typ) { return Ok(copy.clone()); } @@ -168,10 +201,13 @@ impl Linearizer { .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; copy_fn(cty.args(), self) } + TypeEnum::Function(_) => Err(LinearizeError::CopyableType(typ.clone())), _ => Err(LinearizeError::UnsupportedType(typ.clone())), } } + /// Gets an [OpReplacement] for discarding a value of type `typ`, i.e. + /// a recipe for a node with one input of that type and no outputs. fn discard_op(&self, typ: &Type) -> Result { if let Some((_, discard)) = self.copy_discard.get(typ) { return Ok(discard.clone()); @@ -201,6 +237,7 @@ impl Linearizer { .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; discard_fn(cty.args(), self) } + TypeEnum::Function(_) => Err(LinearizeError::CopyableType(typ.clone())), _ => Err(LinearizeError::UnsupportedType(typ.clone())), } } @@ -265,11 +302,13 @@ mod test { let mut lowerer = ReplaceTypes::default(); let usize_custom_t = usize_t().as_extension().unwrap().clone(); lowerer.replace_type(usize_custom_t, lin_t.clone()); - lowerer.linearize( - lin_t, - OpReplacement::SingleOp(copy_op.into()), - OpReplacement::SingleOp(discard_op.into()), - ); + lowerer + .linearize( + lin_t, + OpReplacement::SingleOp(copy_op.into()), + OpReplacement::SingleOp(discard_op.into()), + ) + .unwrap(); (e, lowerer) } From 86c73edd77a56ec85772e6ca882590ea98e7dc71 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 18:06:09 +0000 Subject: [PATCH 093/123] Test copyable element inside Sum - breaks test --- hugr-passes/src/replace_types/linearize.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index d1e0da5956..f49f5eb90c 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -252,9 +252,10 @@ mod test { endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, }; - use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::prelude::usize_t; use hugr_core::extension::{TypeDefBound, Version}; use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; @@ -347,8 +348,8 @@ mod test { #[test] fn sums() { - let (e, lowerer) = ext_lowerer(); - let sum_ty = Type::from(option_type(vec![usize_t(), usize_t()])); + let i8_t = || INT_TYPES[3].clone(); + let sum_ty = Type::new_sum([vec![i8_t()], vec![usize_t(); 2]]); let mut outer = DFGBuilder::new(endo_sig(sum_ty.clone())).unwrap(); let [inp] = outer.input_wires_arr(); let inner = outer @@ -358,11 +359,12 @@ mod test { .unwrap(); let mut h = outer.finish_hugr_with_outputs([inp]).unwrap(); + let (e, lowerer) = ext_lowerer(); assert!(lowerer.run(&mut h).unwrap()); let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); - let option_ty = Type::from(option_type(vec![lin_t.clone(); 2])); - let copy_out: TypeRow = vec![option_ty.clone(); 2].into(); + let sum_ty = Type::new_sum([vec![i8_t()], vec![lin_t.clone(); 2]]); + let copy_out: TypeRow = vec![sum_ty.clone(); 2].into(); let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... @@ -376,11 +378,11 @@ mod test { .collect_array() .unwrap(); let [case0, case1] = h.children(cond).collect_array().unwrap(); - // first is for empty + // first is for empty - the only input is Copyable so can be directly wired or ignored assert_eq!(h.children(case0).count(), 2 + num_tags); // Input, Output assert_eq!(count_tags(case0), num_tags); let case0 = h.get_optype(case0).as_case().unwrap(); - assert_eq!(case0.signature.io(), (&vec![].into(), &out_row)); + assert_eq!(case0.signature.io(), (&vec![i8_t()].into(), &out_row)); // second is for two elements assert_eq!(h.children(case1).count(), 4 + num_tags); // Input, Output, two leaf copies/discards: From 92d7a51002d0db0253f27b3f27352de32a68a3bd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 17:58:47 +0000 Subject: [PATCH 094/123] Error on copyable; handle copyable elements of sums - fixes --- hugr-passes/src/replace_types/linearize.rs | 23 ++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index f49f5eb90c..6944391225 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -154,6 +154,9 @@ impl Linearizer { /// Gets an [OpReplacement] for copying a value of type `typ`, i.e. /// a recipe for a node with one input of that type and two outputs. pub fn copy_op(&self, typ: &Type) -> Result { + if typ.copyable() { + return Err(LinearizeError::CopyableType(typ.clone())); + }; if let Some((copy, _)) = self.copy_discard.get(typ) { return Ok(copy.clone()); } @@ -174,11 +177,14 @@ impl Linearizer { let mut orig_elems = vec![]; let mut copy_elems = vec![]; for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { - let [orig_elem, copy_elem] = self - .copy_op(ty)? - .add(&mut case_b, [inp]) - .unwrap() - .outputs_arr(); + let [orig_elem, copy_elem] = if ty.copyable() { + [inp, inp] + } else { + self.copy_op(ty)? + .add(&mut case_b, [inp]) + .unwrap() + .outputs_arr() + }; orig_elems.push(orig_elem); copy_elems.push(copy_elem); } @@ -209,6 +215,9 @@ impl Linearizer { /// Gets an [OpReplacement] for discarding a value of type `typ`, i.e. /// a recipe for a node with one input of that type and no outputs. fn discard_op(&self, typ: &Type) -> Result { + if typ.copyable() { + return Err(LinearizeError::CopyableType(typ.clone())); + }; if let Some((_, discard)) = self.copy_discard.get(typ) { return Ok(discard.clone()); } @@ -222,7 +231,9 @@ impl Linearizer { for (idx, variant) in variants.into_iter().enumerate() { let mut case_b = cb.case_builder(idx).unwrap(); for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { - self.discard_op(ty)?.add(&mut case_b, [inp]).unwrap(); + if !ty.copyable() { + self.discard_op(ty)?.add(&mut case_b, [inp]).unwrap(); + } } case_b.finish_with_outputs([]).unwrap(); } From 582b9f1768a6fa0b1419c85ad96549261acf9ef3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 18:28:40 +0000 Subject: [PATCH 095/123] register errors with type; panic on function as already ruled out --- hugr-passes/src/replace_types.rs | 9 ++++++--- hugr-passes/src/replace_types/linearize.rs | 10 +++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e2f758182c..652a063cea 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -198,9 +198,12 @@ impl ReplaceTypes { /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports.) + /// + /// Also applies if `src` is an element of a Sum or other type. /// - /// To clarify, these are used if `src` is not [Copyable], but is (perhaps contained in) the - /// result of lonering a type that was either copied or discarded in the input Hugr. + /// # Errors + /// + /// If `src` is [Copyable], it is returned as an `Err /// /// [Copyable]: hugr_core::types::TypeBound::Copyable pub fn linearize( @@ -208,7 +211,7 @@ impl ReplaceTypes { src: Type, copy: OpReplacement, discard: OpReplacement, - ) -> Result<(), LinearizeError> { + ) -> Result<(), Type> { self.linearize.register(src, copy, discard) } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 6944391225..e21c528381 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -61,15 +61,15 @@ impl Linearizer { /// /// # Errors /// - /// [LinearizeError::CopyableType] if `typ` is copyable + /// If `typ` is copyable, it is returned as an `Err`. pub fn register( &mut self, typ: Type, copy: OpReplacement, discard: OpReplacement, - ) -> Result<(), LinearizeError> { + ) -> Result<(), Type> { if typ.copyable() { - Err(LinearizeError::CopyableType(typ)) + Err(typ) } else { self.copy_discard.insert(typ, (copy, discard)); Ok(()) @@ -207,7 +207,7 @@ impl Linearizer { .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; copy_fn(cty.args(), self) } - TypeEnum::Function(_) => Err(LinearizeError::CopyableType(typ.clone())), + TypeEnum::Function(_) => panic!("Ruled out above as copyable"), _ => Err(LinearizeError::UnsupportedType(typ.clone())), } } @@ -248,7 +248,7 @@ impl Linearizer { .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; discard_fn(cty.args(), self) } - TypeEnum::Function(_) => Err(LinearizeError::CopyableType(typ.clone())), + TypeEnum::Function(_) => panic!("Ruled out above as copyable"), _ => Err(LinearizeError::UnsupportedType(typ.clone())), } } From 95da97d778e8b469e370d147d309049f02796da9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 26 Mar 2025 22:15:53 +0000 Subject: [PATCH 096/123] more docs - no overriding copy/discard of non-extension types --- hugr-passes/src/replace_types.rs | 4 ++-- hugr-passes/src/replace_types/linearize.rs | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 652a063cea..0b038259d1 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -198,11 +198,11 @@ impl ReplaceTypes { /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports.) - /// + /// /// Also applies if `src` is an element of a Sum or other type. /// /// # Errors - /// + /// /// If `src` is [Copyable], it is returned as an `Err /// /// [Copyable]: hugr_core::types::TypeBound::Copyable diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index e21c528381..8eef2a7f09 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -15,11 +15,10 @@ pub struct Linearizer { // Keyed by lowered type, as only needed when there is an op outputting such copy_discard: HashMap, // Copy/discard of parametric types handled by a function that receives the new/lowered type. - // We do not allow linearization to "parametrized" non-extension types, at least not - // in one step. We could do that using a trait, but it seems enough of a corner case. - // Instead that can be achieved by *firstly* lowering to a custom linear type, with copy/dup + // We do not allow overriding copy/discard of non-extension types, but that + // can be achieved by *firstly* lowering to a custom linear type, with copy/discard // inserted; *secondly* by lowering that to the desired non-extension linear type, - // including lowering of the copy/dup operations to...whatever. + // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, ( From c7e56e6174a5f00d32f94564c624a04fa8f2b8c0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 27 Mar 2025 19:33:11 +0000 Subject: [PATCH 097/123] pub discard_op --- hugr-passes/src/replace_types/linearize.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 8eef2a7f09..31731b8f86 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -213,7 +213,7 @@ impl Linearizer { /// Gets an [OpReplacement] for discarding a value of type `typ`, i.e. /// a recipe for a node with one input of that type and no outputs. - fn discard_op(&self, typ: &Type) -> Result { + pub fn discard_op(&self, typ: &Type) -> Result { if typ.copyable() { return Err(LinearizeError::CopyableType(typ.clone())); }; From a1e63de1ee1d42f4dcc3794e3bbab20db0b2e34c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 13:32:16 +0000 Subject: [PATCH 098/123] rm Boxes --- hugr-passes/src/replace_types.rs | 4 ++-- hugr-passes/src/replace_types/linearize.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 0b038259d1..ff0159d8d5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -230,8 +230,8 @@ impl ReplaceTypes { pub fn linearize_parametric( &mut self, src: &TypeDef, - copy_fn: Box Result>, - discard_fn: Box Result>, + copy_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, + discard_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, ) { self.linearize.register_parametric(src, copy_fn, discard_fn) } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 31731b8f86..16337affc1 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -80,14 +80,14 @@ impl Linearizer { pub fn register_parametric( &mut self, src: &TypeDef, - copy_fn: Box Result>, - discard_fn: Box Result>, + copy_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, + discard_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise // it depends on the arguments. Since there is no method to get the TypeDefBound // from a TypeDef, leaving this for now. self.copy_discard_parametric - .insert(src.into(), (Arc::from(copy_fn), Arc::from(discard_fn))); + .insert(src.into(), (Arc::new(copy_fn), Arc::new(discard_fn))); } /// Insert copy or discard operations (as appropriate) enough to wire `src_port` of `src_node` From 7634e92f2ff69eb5307192347a204e10f5d33052 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 13:55:01 +0000 Subject: [PATCH 099/123] Only allow copy/discard funcs for *Custom*Type's --- hugr-passes/src/replace_types.rs | 6 +-- hugr-passes/src/replace_types/linearize.rs | 55 +++++++++++----------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ff0159d8d5..a8288699a1 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -199,7 +199,7 @@ impl ReplaceTypes { /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports.) /// - /// Also applies if `src` is an element of a Sum or other type. + /// The same [OpReplacement]s are also used in cases where `src` is an element of a [TypeEnum::Sum]. /// /// # Errors /// @@ -208,10 +208,10 @@ impl ReplaceTypes { /// [Copyable]: hugr_core::types::TypeBound::Copyable pub fn linearize( &mut self, - src: Type, + src: CustomType, copy: OpReplacement, discard: OpReplacement, - ) -> Result<(), Type> { + ) -> Result<(), CustomType> { self.linearize.register(src, copy, discard) } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 16337affc1..c250b0f96c 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}; use hugr_core::extension::{SignatureError, TypeDef}; -use hugr_core::types::{Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::types::{CustomType, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node, OutgoingPort}; use itertools::Itertools; @@ -13,7 +13,7 @@ use super::{OpReplacement, ParametricType}; #[derive(Clone, Default)] pub struct Linearizer { // Keyed by lowered type, as only needed when there is an op outputting such - copy_discard: HashMap, + copy_discard: HashMap, // Copy/discard of parametric types handled by a function that receives the new/lowered type. // We do not allow overriding copy/discard of non-extension types, but that // can be achieved by *firstly* lowering to a custom linear type, with copy/discard @@ -63,11 +63,11 @@ impl Linearizer { /// If `typ` is copyable, it is returned as an `Err`. pub fn register( &mut self, - typ: Type, + typ: CustomType, copy: OpReplacement, discard: OpReplacement, - ) -> Result<(), Type> { - if typ.copyable() { + ) -> Result<(), CustomType> { + if typ.bound() == TypeBound::Copyable { Err(typ) } else { self.copy_discard.insert(typ, (copy, discard)); @@ -156,9 +156,6 @@ impl Linearizer { if typ.copyable() { return Err(LinearizeError::CopyableType(typ.clone())); }; - if let Some((copy, _)) = self.copy_discard.get(typ) { - return Ok(copy.clone()); - } match typ.as_type_enum() { TypeEnum::Sum(sum_type) => { let variants = sum_type @@ -199,13 +196,16 @@ impl Linearizer { cb.finish_hugr().unwrap(), ))) } - TypeEnum::Extension(cty) => { - let (copy_fn, _) = self - .copy_discard_parametric - .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - copy_fn(cty.args(), self) - } + TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { + Some((copy, _)) => Ok(copy.clone()), + None => { + let (copy_fn, _) = self + .copy_discard_parametric + .get(&cty.into()) + .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; + copy_fn(cty.args(), self) + } + }, TypeEnum::Function(_) => panic!("Ruled out above as copyable"), _ => Err(LinearizeError::UnsupportedType(typ.clone())), } @@ -217,9 +217,6 @@ impl Linearizer { if typ.copyable() { return Err(LinearizeError::CopyableType(typ.clone())); }; - if let Some((_, discard)) = self.copy_discard.get(typ) { - return Ok(discard.clone()); - } match typ.as_type_enum() { TypeEnum::Sum(sum_type) => { let variants = sum_type @@ -240,13 +237,16 @@ impl Linearizer { cb.finish_hugr().unwrap(), ))) } - TypeEnum::Extension(cty) => { - let (_, discard_fn) = self - .copy_discard_parametric - .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; - discard_fn(cty.args(), self) - } + TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { + Some((_, discard)) => Ok(discard.clone()), + None => { + let (_, discard_fn) = self + .copy_discard_parametric + .get(&cty.into()) + .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; + discard_fn(cty.args(), self) + } + }, TypeEnum::Function(_) => panic!("Ruled out above as copyable"), _ => Err(LinearizeError::UnsupportedType(typ.clone())), } @@ -305,7 +305,8 @@ mod test { }, ); - let lin_t = Type::new_extension(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let lin_custom_t = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t = Type::new_extension(lin_custom_t.clone()); // Configure to lower usize_t to the linear type above let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), []).unwrap(); @@ -315,7 +316,7 @@ mod test { lowerer.replace_type(usize_custom_t, lin_t.clone()); lowerer .linearize( - lin_t, + lin_custom_t, OpReplacement::SingleOp(copy_op.into()), OpReplacement::SingleOp(discard_op.into()), ) From 0cd31e07d6fc323f98fab7084c0185cac4915c6d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 14:29:47 +0000 Subject: [PATCH 100/123] single callback taking num_outports != 1 --- hugr-passes/src/replace_types.rs | 17 +- hugr-passes/src/replace_types/linearize.rs | 216 ++++++++++----------- 2 files changed, 117 insertions(+), 116 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index a8288699a1..9d184fbe1f 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -220,20 +220,23 @@ impl ReplaceTypes { /// * is not [Copyable](hugr_core::types::TypeBound::Copyable), and /// * has other than one connected inport, /// - /// ...then these functions should be used to generate `copy` or `discard` ops. + /// ...then the provided callback should be used to generate a `copy` or `discard` op, + /// passing the desired number of outports (which will never be 1). /// - /// (That is, this is the equivalent of [Self::linearize] but for parametric types.) + /// (That is, this is like [Self::linearize] but for parametric types and/or + /// with a callback that can generate an n-way copy directly, rather than + /// with a 0-way and 2-way copy.) /// - /// The [Linearizer] is passed so that the callbacks can use it to generate - /// `copy/`discard` ops for other types (e.g. the elements of a collection), + /// The [Linearizer] is passed so that the callback can use it to generate + /// `copy`/`discard` ops for other types (e.g. the elements of a collection), /// as part of an [OpReplacement::CompoundOp]. pub fn linearize_parametric( &mut self, src: &TypeDef, - copy_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, - discard_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, + copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + + 'static, ) { - self.linearize.register_parametric(src, copy_fn, discard_fn) + self.linearize.register_parametric(src, copy_discard_fn) } /// Configures this instance to change occurrences of `src` to `dest`. diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index c250b0f96c..9deabf8d40 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,6 +1,10 @@ +use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; -use hugr_core::builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}; +use hugr_core::builder::{ + inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, +}; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::types::{CustomType, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node, OutgoingPort}; @@ -21,10 +25,7 @@ pub struct Linearizer { // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - ( - Arc Result>, - Arc Result>, - ), + Arc Result>, >, } @@ -80,14 +81,14 @@ impl Linearizer { pub fn register_parametric( &mut self, src: &TypeDef, - copy_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, - discard_fn: impl Fn(&[TypeArg], &Linearizer) -> Result + 'static, + copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise // it depends on the arguments. Since there is no method to get the TypeDefBound // from a TypeDef, leaving this for now. self.copy_discard_parametric - .insert(src.into(), (Arc::new(copy_fn), Arc::new(discard_fn))); + .insert(src.into(), Arc::new(copy_discard_fn)); } /// Insert copy or discard operations (as appropriate) enough to wire `src_port` of `src_node` @@ -103,59 +104,68 @@ impl Linearizer { pub fn insert_copy_discard( &self, hugr: &mut impl HugrMut, - mut src_node: Node, - mut src_port: OutgoingPort, + src_node: Node, + src_port: OutgoingPort, typ: &Type, // Or better to get the signature ourselves?? targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { - let (last_node, last_inport) = match targets.last() { - None => { + let (tgt_node, tgt_inport) = match targets.len() { + 0 => { let parent = hugr.get_parent(src_node).unwrap(); - (self.discard_op(typ)?.add_hugr(hugr, parent), 0.into()) - } - Some(last) => *last, - }; - - if targets.len() > 1 { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) - let src_parent = hugr - .get_parent(src_node) - .expect("Root node cannot have out edges"); - if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { - let tgt_parent = hugr - .get_parent(*tgt) - .expect("Root node cannot have incoming edges"); - (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) - }) { - return Err(LinearizeError::NoLinearNonLocalEdges { - src: src_node, - src_parent, - tgt, - tgt_parent, - }); + ( + self.copy_discard_op(typ, 0)?.add_hugr(hugr, parent), + 0.into(), + ) } - - let copy_op = self.copy_op(typ)?; - - for (tgt_node, tgt_port) in &targets[..targets.len() - 1] { - let n = copy_op - .clone() - .add_hugr(hugr, hugr.get_parent(src_node).unwrap()); - hugr.connect(src_node, src_port, n, 0); - hugr.connect(n, 0, *tgt_node, *tgt_port); - (src_node, src_port) = (n, 1.into()); + 1 => *targets.first().unwrap(), + _ => { + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + let src_parent = hugr + .get_parent(src_node) + .expect("Root node cannot have out edges"); + if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { + let tgt_parent = hugr + .get_parent(*tgt) + .expect("Root node cannot have incoming edges"); + (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) + }) { + return Err(LinearizeError::NoLinearNonLocalEdges { + src: src_node, + src_parent, + tgt, + tgt_parent, + }); + } + let copy_op = self + .copy_discard_op(typ, targets.len())? + .add_hugr(hugr, src_parent); + for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { + hugr.connect(copy_op, n, *tgt_node, *tgt_port); + } + (copy_op, 0.into()) } - } - hugr.connect(src_node, src_port, last_node, last_inport); + }; + hugr.connect(src_node, src_port, tgt_node, tgt_inport); Ok(()) } - /// Gets an [OpReplacement] for copying a value of type `typ`, i.e. - /// a recipe for a node with one input of that type and two outputs. - pub fn copy_op(&self, typ: &Type) -> Result { + /// Gets an [OpReplacement] for copying or discarding a value of type `typ`, i.e. + /// a recipe for a node with one input of that type and the specified number of + /// outports. Note that `num_outports` should never be 1 (as no node is required) + /// + /// # Panics + /// + /// if `num_outports == 1` + pub fn copy_discard_op( + &self, + typ: &Type, + num_outports: usize, + ) -> Result { if typ.copyable() { return Err(LinearizeError::CopyableType(typ.clone())); }; + assert!(num_outports != 1); + match typ.as_type_enum() { TypeEnum::Sum(sum_type) => { let variants = sum_type @@ -165,86 +175,71 @@ impl Linearizer { let mut cb = ConditionalBuilder::new( variants.clone(), vec![], - vec![sum_type.clone().into(); 2], + vec![sum_type.clone().into(); num_outports], ) .unwrap(); for (tag, variant) in variants.iter().enumerate() { let mut case_b = cb.case_builder(tag).unwrap(); - let mut orig_elems = vec![]; - let mut copy_elems = vec![]; + let mut elems_per_output = vec![vec![]; num_outports]; for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { - let [orig_elem, copy_elem] = if ty.copyable() { - [inp, inp] + let elems_this_input = if ty.copyable() { + repeat(inp).take(num_outports).collect::>() } else { - self.copy_op(ty)? + self.copy_discard_op(ty, num_outports)? .add(&mut case_b, [inp]) .unwrap() - .outputs_arr() + .outputs() + .collect() }; - orig_elems.push(orig_elem); - copy_elems.push(copy_elem); + for (src, elems_this_output) in elems_this_input + .into_iter() + .zip_eq(elems_per_output.iter_mut()) + { + elems_this_output.push(src) + } } let t = Tag::new(tag, variants.clone()); - let [orig] = case_b - .add_dataflow_op(t.clone(), orig_elems) - .unwrap() - .outputs_arr(); - let [copy] = case_b.add_dataflow_op(t, copy_elems).unwrap().outputs_arr(); - case_b.finish_with_outputs([orig, copy]).unwrap(); + let outputs = elems_per_output + .into_iter() + .map(|elems_this_output| { + let [this_output] = case_b + .add_dataflow_op(t.clone(), elems_this_output) + .unwrap() + .outputs_arr(); + this_output + }) + .collect::>(); // must collect to end borrow of `case_b` by closure + case_b.finish_with_outputs(outputs).unwrap(); } Ok(OpReplacement::CompoundOp(Box::new( cb.finish_hugr().unwrap(), ))) } TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { - Some((copy, _)) => Ok(copy.clone()), - None => { - let (copy_fn, _) = self - .copy_discard_parametric - .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - copy_fn(cty.args(), self) - } - }, - TypeEnum::Function(_) => panic!("Ruled out above as copyable"), - _ => Err(LinearizeError::UnsupportedType(typ.clone())), - } - } - - /// Gets an [OpReplacement] for discarding a value of type `typ`, i.e. - /// a recipe for a node with one input of that type and no outputs. - pub fn discard_op(&self, typ: &Type) -> Result { - if typ.copyable() { - return Err(LinearizeError::CopyableType(typ.clone())); - }; - match typ.as_type_enum() { - TypeEnum::Sum(sum_type) => { - let variants = sum_type - .variants() - .map(|trv| trv.clone().try_into()) - .collect::, _>>()?; - let mut cb = ConditionalBuilder::new(variants.clone(), vec![], vec![]).unwrap(); - for (idx, variant) in variants.into_iter().enumerate() { - let mut case_b = cb.case_builder(idx).unwrap(); - for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { - if !ty.copyable() { - self.discard_op(ty)?.add(&mut case_b, [inp]).unwrap(); - } + Some((copy, discard)) => Ok(if num_outports == 0 { + discard.clone() + } else { + let mut dfb = + DFGBuilder::new(inout_sig(typ.clone(), vec![typ.clone(); num_outports])) + .unwrap(); + let [mut src] = dfb.input_wires_arr(); + let mut outputs = vec![]; + for _ in 0..num_outports - 1 { + let [out0, out1] = copy.clone().add(&mut dfb, [src]).unwrap().outputs_arr(); + outputs.push(out0); + src = out1; } - case_b.finish_with_outputs([]).unwrap(); - } - Ok(OpReplacement::CompoundOp(Box::new( - cb.finish_hugr().unwrap(), - ))) - } - TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { - Some((_, discard)) => Ok(discard.clone()), + outputs.push(src); + OpReplacement::CompoundOp(Box::new( + dfb.finish_hugr_with_outputs(outputs).unwrap(), + )) + }), None => { - let (_, discard_fn) = self + let copy_discard_fn = self .copy_discard_parametric .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedDiscard(typ.clone()))?; - discard_fn(cty.args(), self) + .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; + copy_discard_fn(cty.args(), num_outports, self) } }, TypeEnum::Function(_) => panic!("Ruled out above as copyable"), @@ -264,6 +259,7 @@ mod test { use hugr_core::extension::prelude::usize_t; use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; @@ -399,7 +395,9 @@ mod test { assert_eq!(h.children(case1).count(), 4 + num_tags); // Input, Output, two leaf copies/discards: assert_eq!(count_tags(case1), num_tags); assert_eq!( - h.children(case1) + DescendantsGraph::::try_new(&h, case1) + .unwrap() + .nodes() .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) .collect_vec(), vec![ext_op_name; 2] From 65b899302a81d7829c6f54ceb1e2ed7497991144 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 14:32:30 +0000 Subject: [PATCH 101/123] In insert_copy_discard, a discard really is a 0-way copy --- hugr-passes/src/replace_types/linearize.rs | 58 ++++++++++------------ 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 9deabf8d40..265e4775f9 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -109,41 +109,33 @@ impl Linearizer { typ: &Type, // Or better to get the signature ourselves?? targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { - let (tgt_node, tgt_inport) = match targets.len() { - 0 => { - let parent = hugr.get_parent(src_node).unwrap(); - ( - self.copy_discard_op(typ, 0)?.add_hugr(hugr, parent), - 0.into(), - ) + let (tgt_node, tgt_inport) = if targets.len() == 1 { + *targets.first().unwrap() + } else { + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + let src_parent = hugr + .get_parent(src_node) + .expect("Root node cannot have out edges"); + if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { + let tgt_parent = hugr + .get_parent(*tgt) + .expect("Root node cannot have incoming edges"); + (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) + }) { + return Err(LinearizeError::NoLinearNonLocalEdges { + src: src_node, + src_parent, + tgt, + tgt_parent, + }); } - 1 => *targets.first().unwrap(), - _ => { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) - let src_parent = hugr - .get_parent(src_node) - .expect("Root node cannot have out edges"); - if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { - let tgt_parent = hugr - .get_parent(*tgt) - .expect("Root node cannot have incoming edges"); - (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) - }) { - return Err(LinearizeError::NoLinearNonLocalEdges { - src: src_node, - src_parent, - tgt, - tgt_parent, - }); - } - let copy_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); - for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { - hugr.connect(copy_op, n, *tgt_node, *tgt_port); - } - (copy_op, 0.into()) + let copy_discard_op = self + .copy_discard_op(typ, targets.len())? + .add_hugr(hugr, src_parent); + for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { + hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } + (copy_discard_op, 0.into()) }; hugr.connect(src_node, src_port, tgt_node, tgt_inport); Ok(()) From 198f9907d62585bc973b571d723fb6babd0386a1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 14:37:12 +0000 Subject: [PATCH 102/123] renaming --- hugr-passes/src/replace_types/linearize.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 265e4775f9..8179443075 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -172,9 +172,9 @@ impl Linearizer { .unwrap(); for (tag, variant) in variants.iter().enumerate() { let mut case_b = cb.case_builder(tag).unwrap(); - let mut elems_per_output = vec![vec![]; num_outports]; + let mut elems_for_copy = vec![vec![]; num_outports]; for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { - let elems_this_input = if ty.copyable() { + let inp_copies = if ty.copyable() { repeat(inp).take(num_outports).collect::>() } else { self.copy_discard_op(ty, num_outports)? @@ -183,22 +183,20 @@ impl Linearizer { .outputs() .collect() }; - for (src, elems_this_output) in elems_this_input - .into_iter() - .zip_eq(elems_per_output.iter_mut()) + for (src, elems) in inp_copies.into_iter().zip_eq(elems_for_copy.iter_mut()) { - elems_this_output.push(src) + elems.push(src) } } let t = Tag::new(tag, variants.clone()); - let outputs = elems_per_output + let outputs = elems_for_copy .into_iter() - .map(|elems_this_output| { - let [this_output] = case_b - .add_dataflow_op(t.clone(), elems_this_output) + .map(|elems| { + let [copy] = case_b + .add_dataflow_op(t.clone(), elems) .unwrap() .outputs_arr(); - this_output + copy }) .collect::>(); // must collect to end borrow of `case_b` by closure case_b.finish_with_outputs(outputs).unwrap(); From 517cd0d6eb474f5026c13dac5d814ac255733008 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 14:52:30 +0000 Subject: [PATCH 103/123] Generalize test 2,3,4 copies --- hugr-passes/src/replace_types/linearize.rs | 51 ++++++++++++---------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 8179443075..728b9d9a30 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -243,9 +243,7 @@ mod test { use std::collections::HashMap; use std::sync::Arc; - use hugr_core::builder::{ - endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - }; + use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; use hugr_core::extension::prelude::usize_t; use hugr_core::extension::{TypeDefBound, Version}; @@ -253,9 +251,10 @@ mod test { use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; - use hugr_core::types::{Signature, Type, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::types::{Signature, Type}; + use hugr_core::{hugr::IdentList, Extension, HugrView}; use itertools::Itertools; + use rstest::rstest; use crate::replace_types::OpReplacement; use crate::ReplaceTypes; @@ -343,31 +342,34 @@ mod test { ); } - #[test] - fn sums() { + #[rstest] + fn sums(#[values(2, 3, 4)] num_copies: usize) { + let copy_nodes = num_copies - 1; // 2 binary copy nodes produce 3 outputs, etc. let i8_t = || INT_TYPES[3].clone(); let sum_ty = Type::new_sum([vec![i8_t()], vec![usize_t(); 2]]); - let mut outer = DFGBuilder::new(endo_sig(sum_ty.clone())).unwrap(); + let mut outer = + DFGBuilder::new(inout_sig(sum_ty.clone(), vec![sum_ty.clone(); copy_nodes])).unwrap(); let [inp] = outer.input_wires_arr(); let inner = outer .dfg_builder(inout_sig(sum_ty, vec![]), [inp]) .unwrap() .finish_with_outputs([]) .unwrap(); - let mut h = outer.finish_hugr_with_outputs([inp]).unwrap(); + let mut h = outer + .finish_hugr_with_outputs(vec![inp; copy_nodes]) + .unwrap(); let (e, lowerer) = ext_lowerer(); assert!(lowerer.run(&mut h).unwrap()); let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); let sum_ty = Type::new_sum([vec![i8_t()], vec![lin_t.clone(); 2]]); - let copy_out: TypeRow = vec![sum_ty.clone(); 2].into(); let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... - for (dfg, num_tags, out_row, ext_op_name) in [ - (inner.node(), 0, type_row![], "TestExt.discard"), - (h.root(), 2, copy_out, "TestExt.copy"), + for (dfg, num_tags, expected_ext_ops) in [ + (inner.node(), 0, vec!["TestExt.discard"; 2]), + (h.root(), num_copies, vec!["TestExt.copy"; 2 * copy_nodes]), ] { let [cond] = h .children(dfg) @@ -375,25 +377,26 @@ mod test { .collect_array() .unwrap(); let [case0, case1] = h.children(cond).collect_array().unwrap(); - // first is for empty - the only input is Copyable so can be directly wired or ignored + let out_row = vec![sum_ty.clone(); num_tags].into(); + // first is for empty variant - the only input is Copyable so can be directly wired or ignored assert_eq!(h.children(case0).count(), 2 + num_tags); // Input, Output assert_eq!(count_tags(case0), num_tags); let case0 = h.get_optype(case0).as_case().unwrap(); assert_eq!(case0.signature.io(), (&vec![i8_t()].into(), &out_row)); - // second is for two elements + // second is for variant of two elements assert_eq!(h.children(case1).count(), 4 + num_tags); // Input, Output, two leaf copies/discards: assert_eq!(count_tags(case1), num_tags); + let ext_ops = DescendantsGraph::::try_new(&h, case1) + .unwrap() + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) + .collect_vec(); + assert_eq!(ext_ops, expected_ext_ops); + + let case1 = h.get_optype(case1).as_case().unwrap(); assert_eq!( - DescendantsGraph::::try_new(&h, case1) - .unwrap() - .nodes() - .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) - .collect_vec(), - vec![ext_op_name; 2] - ); - assert_eq!( - h.get_optype(case1).as_case().unwrap().signature.io(), + case1.signature.io(), (&vec![lin_t.clone(); 2].into(), &out_row) ); } From 272f02454c45d9acfd7f5abfaeefdfcb48354a94 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 18:55:52 +0000 Subject: [PATCH 104/123] Simplify test w/binary copy to just option type, add more complex w/ n-way copy --- hugr-passes/src/replace_types/linearize.rs | 153 +++++++++++++++++---- 1 file changed, 130 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 728b9d9a30..b4288e9439 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -245,14 +245,16 @@ mod test { use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; - use hugr_core::extension::prelude::usize_t; - use hugr_core::extension::{TypeDefBound, Version}; + use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::{CustomSignatureFunc, SignatureFunc, TypeDefBound, Version}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; + use hugr_core::ops::DataflowOpTrait; use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; - use hugr_core::types::{Signature, Type}; - use hugr_core::{hugr::IdentList, Extension, HugrView}; + use hugr_core::types::type_param::TypeParam; + use hugr_core::types::{FuncValueType, Signature, Type, TypeRow}; + use hugr_core::{hugr::IdentList, Extension, Hugr, HugrView, Node}; use itertools::Itertools; use rstest::rstest; @@ -342,35 +344,138 @@ mod test { ); } - #[rstest] - fn sums(#[values(2, 3, 4)] num_copies: usize) { - let copy_nodes = num_copies - 1; // 2 binary copy nodes produce 3 outputs, etc. - let i8_t = || INT_TYPES[3].clone(); - let sum_ty = Type::new_sum([vec![i8_t()], vec![usize_t(); 2]]); - let mut outer = - DFGBuilder::new(inout_sig(sum_ty.clone(), vec![sum_ty.clone(); copy_nodes])).unwrap(); + fn copy_n_discard_one(ty: Type, n: usize) -> (Hugr, Node) { + let mut outer = DFGBuilder::new(inout_sig(ty.clone(), vec![ty.clone(); n - 1])).unwrap(); let [inp] = outer.input_wires_arr(); let inner = outer - .dfg_builder(inout_sig(sum_ty, vec![]), [inp]) + .dfg_builder(inout_sig(ty, vec![]), [inp]) .unwrap() .finish_with_outputs([]) .unwrap(); - let mut h = outer - .finish_hugr_with_outputs(vec![inp; copy_nodes]) - .unwrap(); + let h = outer.finish_hugr_with_outputs(vec![inp; n - 1]).unwrap(); + (h, inner.node()) + } + + #[rstest] + fn sums_2way_copy(#[values(2, 3, 4)] num_copies: usize) { + let (mut h, inner) = copy_n_discard_one(option_type(usize_t()).into(), num_copies); let (e, lowerer) = ext_lowerer(); assert!(lowerer.run(&mut h).unwrap()); let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); - let sum_ty = Type::new_sum([vec![i8_t()], vec![lin_t.clone(); 2]]); + let sum_ty: Type = option_type(lin_t.clone()).into(); let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... for (dfg, num_tags, expected_ext_ops) in [ - (inner.node(), 0, vec!["TestExt.discard"; 2]), - (h.root(), num_copies, vec!["TestExt.copy"; 2 * copy_nodes]), + (inner.node(), 0, vec!["TestExt.discard"]), + (h.root(), num_copies, vec!["TestExt.copy"; num_copies - 1]), // 2 copy nodes produce 3 outputs, etc. ] { + let [(cond_node, cond)] = h + .children(dfg) + .filter_map(|n| h.get_optype(n).as_conditional().map(|c| (n, c))) + .collect_array() + .unwrap(); + assert_eq!( + cond.signature().output(), + &TypeRow::from(vec![sum_ty.clone(); num_tags]) + ); + let [case0, case1] = h.children(cond_node).collect_array().unwrap(); + // first is for empty variant + assert_eq!(h.children(case0).count(), 2 + num_tags); // Input, Output + assert_eq!(count_tags(case0), num_tags); + + // second is for variant of a LIN_T + assert_eq!(h.children(case1).count(), 3 + num_tags); // Input, Output, copy/discard + assert_eq!(count_tags(case1), num_tags); + let ext_ops = DescendantsGraph::::try_new(&h, case1) + .unwrap() + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) + .collect_vec(); + assert_eq!(ext_ops, expected_ext_ops); + } + } + + #[rstest] + fn sum_nway_copy(#[values(2, 5, 9)] num_copies: usize) { + use hugr_core::{ + extension::{OpDef, SignatureError}, + types::{PolyFuncTypeRV, TypeArg}, + }; + + let i8_t = || INT_TYPES[3].clone(); + let sum_ty = Type::new_sum([vec![i8_t()], vec![usize_t(); 2]]); + + let (mut h, inner) = copy_n_discard_one(sum_ty, num_copies); + let e = Extension::new_arc( + IdentList::new_unchecked("NWay"), + Version::new(0, 0, 1), + |e, w| { + e.add_type(LIN_T.into(), vec![], String::new(), TypeDefBound::any(), w) + .unwrap(); + struct NWayCopy; + impl CustomSignatureFunc for NWayCopy { + fn compute_signature<'o, 'a: 'o>( + &'a self, + arg_values: &[TypeArg], + def: &'o OpDef, + ) -> Result { + let [TypeArg::BoundedNat { n }] = arg_values else { + panic!() + }; + let lin_t = Type::from( + def.extension() + .upgrade() + .unwrap() + .get_type(LIN_T) + .unwrap() + .instantiate([]) + .unwrap(), + ); + let outs = vec![lin_t.clone(); *n as usize]; + Ok(FuncValueType::new(lin_t, outs).into()) + } + + fn static_params(&self) -> &[TypeParam] { + const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + JUST_NAT + } + } + e.add_op( + "copy_n".into(), + String::new(), + SignatureFunc::CustomFunc(Box::new(NWayCopy)), + w, + ) + .unwrap(); + }, + ); + let mut lowerer = ReplaceTypes::default(); + let lin_t_def = e.get_type(LIN_T).unwrap(); + lowerer.replace_type( + usize_t().as_extension().unwrap().clone(), + lin_t_def.instantiate([]).unwrap().into(), + ); + let opdef = e.get_op("copy_n").unwrap(); + let opdef2 = opdef.clone(); + lowerer.linearize_parametric(lin_t_def, move |args, num_outs, _| { + assert!(args.is_empty()); + Ok(OpReplacement::SingleOp( + ExtensionOp::new(opdef2.clone(), [TypeArg::BoundedNat { n: num_outs as _ }]) + .unwrap() + .into(), + )) + }); + assert!(lowerer.run(&mut h).unwrap()); + + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let sum_ty = Type::new_sum([vec![i8_t()], vec![lin_t.clone(); 2]]); + let count_tags = |n| h.children(n).filter(|n| h.get_optype(*n).is_tag()).count(); + + // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... + for (dfg, num_tags) in [(inner.node(), 0), (h.root(), num_copies)] { let [cond] = h .children(dfg) .filter(|n| h.get_optype(*n).is_conditional()) @@ -387,12 +492,14 @@ mod test { // second is for variant of two elements assert_eq!(h.children(case1).count(), 4 + num_tags); // Input, Output, two leaf copies/discards: assert_eq!(count_tags(case1), num_tags); - let ext_ops = DescendantsGraph::::try_new(&h, case1) - .unwrap() - .nodes() - .filter_map(|n| h.get_optype(n).as_extension_op().map(ExtensionOp::name)) + let ext_ops = h + .children(case1) + .filter_map(|n| h.get_optype(n).as_extension_op()) .collect_vec(); - assert_eq!(ext_ops, expected_ext_ops); + let expected_op = + ExtensionOp::new(opdef.clone(), [TypeArg::BoundedNat { n: num_tags as _ }]) + .unwrap(); + assert_eq!(ext_ops, vec![&expected_op; 2]); let case1 = h.get_optype(case1).as_case().unwrap(); assert_eq!( From e74b12b60d46cab2d2060cb1c6a469a11f5e81ff Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 19:44:35 +0000 Subject: [PATCH 105/123] combine the two test extensions --- hugr-passes/src/replace_types/linearize.rs | 93 ++++++++-------------- 1 file changed, 34 insertions(+), 59 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index b4288e9439..e7f2dbdfc6 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -246,14 +246,16 @@ mod test { use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; use hugr_core::extension::prelude::{option_type, usize_t}; - use hugr_core::extension::{CustomSignatureFunc, SignatureFunc, TypeDefBound, Version}; + use hugr_core::extension::{ + CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, + }; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::DataflowOpTrait; use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; use hugr_core::types::type_param::TypeParam; - use hugr_core::types::{FuncValueType, Signature, Type, TypeRow}; + use hugr_core::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeRow}; use hugr_core::{hugr::IdentList, Extension, Hugr, HugrView, Node}; use itertools::Itertools; use rstest::rstest; @@ -263,8 +265,28 @@ mod test { const LIN_T: &str = "Lin"; + struct NWayCopySigFn(Type); + impl CustomSignatureFunc for NWayCopySigFn { + fn compute_signature<'o, 'a: 'o>( + &'a self, + arg_values: &[TypeArg], + _def: &'o OpDef, + ) -> Result { + let [TypeArg::BoundedNat { n }] = arg_values else { + panic!() + }; + let outs = vec![self.0.clone(); *n as usize]; + Ok(FuncValueType::new(self.0.clone(), outs).into()) + } + + fn static_params(&self) -> &[TypeParam] { + const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + JUST_NAT + } + } + fn ext_lowerer() -> (Arc, ReplaceTypes) { - // Extension with a linear type, a copy and discard op + // Extension with a linear type, an n-way parametric copy op, and a discard op let e = Extension::new_arc( IdentList::new_unchecked("TestExt"), Version::new(0, 0, 0), @@ -276,16 +298,16 @@ mod test { .unwrap(), ); e.add_op( - "copy".into(), + "discard".into(), String::new(), - Signature::new(lin.clone(), vec![lin.clone(); 2]), + Signature::new(lin.clone(), vec![]), w, ) .unwrap(); e.add_op( - "discard".into(), + "copy".into(), String::new(), - Signature::new(lin, vec![]), + SignatureFunc::CustomFunc(Box::new(NWayCopySigFn(lin))), w, ) .unwrap(); @@ -295,8 +317,8 @@ mod test { let lin_custom_t = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); let lin_t = Type::new_extension(lin_custom_t.clone()); - // Configure to lower usize_t to the linear type above - let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), []).unwrap(); + // Configure to lower usize_t to the linear type above, using a 2-way copy only + let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), [2.into()]).unwrap(); let discard_op = ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(); let mut lowerer = ReplaceTypes::default(); let usize_custom_t = usize_t().as_extension().unwrap().clone(); @@ -370,7 +392,7 @@ mod test { // Check we've inserted one Conditional into outer (for copy) and inner (for discard)... for (dfg, num_tags, expected_ext_ops) in [ (inner.node(), 0, vec!["TestExt.discard"]), - (h.root(), num_copies, vec!["TestExt.copy"; num_copies - 1]), // 2 copy nodes produce 3 outputs, etc. + (h.root(), num_copies, vec!["TestExt.copy"; num_copies - 1]), // 2 copy nodes -> 3 outputs, etc. ] { let [(cond_node, cond)] = h .children(dfg) @@ -400,65 +422,18 @@ mod test { #[rstest] fn sum_nway_copy(#[values(2, 5, 9)] num_copies: usize) { - use hugr_core::{ - extension::{OpDef, SignatureError}, - types::{PolyFuncTypeRV, TypeArg}, - }; - let i8_t = || INT_TYPES[3].clone(); let sum_ty = Type::new_sum([vec![i8_t()], vec![usize_t(); 2]]); let (mut h, inner) = copy_n_discard_one(sum_ty, num_copies); - let e = Extension::new_arc( - IdentList::new_unchecked("NWay"), - Version::new(0, 0, 1), - |e, w| { - e.add_type(LIN_T.into(), vec![], String::new(), TypeDefBound::any(), w) - .unwrap(); - struct NWayCopy; - impl CustomSignatureFunc for NWayCopy { - fn compute_signature<'o, 'a: 'o>( - &'a self, - arg_values: &[TypeArg], - def: &'o OpDef, - ) -> Result { - let [TypeArg::BoundedNat { n }] = arg_values else { - panic!() - }; - let lin_t = Type::from( - def.extension() - .upgrade() - .unwrap() - .get_type(LIN_T) - .unwrap() - .instantiate([]) - .unwrap(), - ); - let outs = vec![lin_t.clone(); *n as usize]; - Ok(FuncValueType::new(lin_t, outs).into()) - } - - fn static_params(&self) -> &[TypeParam] { - const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; - JUST_NAT - } - } - e.add_op( - "copy_n".into(), - String::new(), - SignatureFunc::CustomFunc(Box::new(NWayCopy)), - w, - ) - .unwrap(); - }, - ); + let (e, _) = ext_lowerer(); let mut lowerer = ReplaceTypes::default(); let lin_t_def = e.get_type(LIN_T).unwrap(); lowerer.replace_type( usize_t().as_extension().unwrap().clone(), lin_t_def.instantiate([]).unwrap().into(), ); - let opdef = e.get_op("copy_n").unwrap(); + let opdef = e.get_op("copy").unwrap(); let opdef2 = opdef.clone(); lowerer.linearize_parametric(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); From 34bd90d53159804f3189009d4045d62b1ab70d0f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 19:59:36 +0000 Subject: [PATCH 106/123] tweaks --- hugr-passes/src/replace_types/linearize.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index e7f2dbdfc6..db42fb5a2a 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -315,14 +315,13 @@ mod test { ); let lin_custom_t = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); - let lin_t = Type::new_extension(lin_custom_t.clone()); // Configure to lower usize_t to the linear type above, using a 2-way copy only let copy_op = ExtensionOp::new(e.get_op("copy").unwrap().clone(), [2.into()]).unwrap(); let discard_op = ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(); let mut lowerer = ReplaceTypes::default(); let usize_custom_t = usize_t().as_extension().unwrap().clone(); - lowerer.replace_type(usize_custom_t, lin_t.clone()); + lowerer.replace_type(usize_custom_t, Type::new_extension(lin_custom_t.clone())); lowerer .linearize( lin_custom_t, @@ -438,7 +437,7 @@ mod test { lowerer.linearize_parametric(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); Ok(OpReplacement::SingleOp( - ExtensionOp::new(opdef2.clone(), [TypeArg::BoundedNat { n: num_outs as _ }]) + ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) .unwrap() .into(), )) @@ -471,9 +470,7 @@ mod test { .children(case1) .filter_map(|n| h.get_optype(n).as_extension_op()) .collect_vec(); - let expected_op = - ExtensionOp::new(opdef.clone(), [TypeArg::BoundedNat { n: num_tags as _ }]) - .unwrap(); + let expected_op = ExtensionOp::new(opdef.clone(), [(num_tags as u64).into()]).unwrap(); assert_eq!(ext_ops, vec![&expected_op; 2]); let case1 = h.get_optype(case1).as_case().unwrap(); From a2cba09717eaa48ca4e77bf4aa229c66eb8da090 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 20:17:52 +0000 Subject: [PATCH 107/123] filter --- hugr-passes/src/replace_types.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 9d184fbe1f..e7de529961 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -301,9 +301,8 @@ impl ReplaceTypes { for n in hugr.nodes().collect::>() { changed |= self.change_node(hugr, n)?; let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = (changed && n != hugr.root()) - .then_some(new_dfsig) - .flatten() + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) .map(Cow::into_owned) { for outp in new_sig.output_ports() { From 30840043bba2a54c3fe73b06d6865c343dc26b34 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 20:23:57 +0000 Subject: [PATCH 108/123] so that --- hugr-passes/src/replace_types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e7de529961..9600634beb 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -194,7 +194,7 @@ impl ReplaceTypes { self.param_types.insert(src.into(), Arc::new(dest_fn)); } - /// Configures this instance that, when an outport of type `src` has other than one connected + /// Configures this instance so that, when an outport of type `src` has other than one connected /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports.) From cec9162b780da0c117c81ba3ccd33c4f1fba8a73 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 20:40:37 +0000 Subject: [PATCH 109/123] insert_copy_discard takes Wire, no Type --- hugr-passes/src/replace_types.rs | 7 +++---- hugr-passes/src/replace_types/linearize.rs | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 9600634beb..ede174acc6 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -216,7 +216,7 @@ impl ReplaceTypes { } /// Configures this instance that when lowering produces an outport which - /// * has type an instantiation of the parametric type `src`, and + /// * has type which is an instantiation of the parametric type `src`, and /// * is not [Copyable](hugr_core::types::TypeBound::Copyable), and /// * has other than one connected inport, /// @@ -310,9 +310,8 @@ impl ReplaceTypes { let targets = hugr.linked_inputs(n, outp).collect::>(); if targets.len() != 1 { hugr.disconnect(n, outp); - let typ = new_sig.out_port_type(outp).unwrap(); - self.linearize - .insert_copy_discard(hugr, n, outp, typ, &targets)?; + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; } } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index db42fb5a2a..48f160a0ea 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -7,7 +7,8 @@ use hugr_core::builder::{ }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::types::{CustomType, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; -use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node, OutgoingPort}; +use hugr_core::Wire; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; use itertools::Itertools; use super::{OpReplacement, ParametricType}; @@ -101,20 +102,24 @@ impl Linearizer { /// will be unchanged. /// /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// + /// # Panics + /// + /// if `src` is not a valid Wire (does not identify a dataflow out-port) pub fn insert_copy_discard( &self, hugr: &mut impl HugrMut, - src_node: Node, - src_port: OutgoingPort, - typ: &Type, // Or better to get the signature ourselves?? + src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap(); let (tgt_node, tgt_inport) = if targets.len() == 1 { *targets.first().unwrap() } else { // Fail fast if the edges are nonlocal. (TODO transform to local edges!) let src_parent = hugr - .get_parent(src_node) + .get_parent(src.node()) .expect("Root node cannot have out edges"); if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { let tgt_parent = hugr @@ -123,7 +128,7 @@ impl Linearizer { (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) }) { return Err(LinearizeError::NoLinearNonLocalEdges { - src: src_node, + src: src.node(), src_parent, tgt, tgt_parent, @@ -137,7 +142,7 @@ impl Linearizer { } (copy_discard_op, 0.into()) }; - hugr.connect(src_node, src_port, tgt_node, tgt_inport); + hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); Ok(()) } From bf749961020567548260d6da480d91756ba41323 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 21:12:26 +0000 Subject: [PATCH 110/123] Rename OpReplacement -> NodeTemplate --- hugr-passes/src/replace_types.rs | 63 +++++++++++----------- hugr-passes/src/replace_types/linearize.rs | 28 +++++----- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ede174acc6..c62df40921 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -29,17 +29,18 @@ use crate::validation::{ValidatePassError, ValidationLevel}; mod linearize; pub use linearize::{LinearizeError, Linearizer}; -/// A thing with which an Op (i.e. node) can be replaced +/// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] +/// or in order to replace an existing node. #[derive(Clone, Debug, PartialEq)] -pub enum OpReplacement { - /// Keep the same node, change only the op (updating types of inputs/outputs) +pub enum NodeTemplate { + /// A single node - so if replacing an existing node, change only the op SingleOp(OpType), - /// Defines a sub-Hugr to splice in place of the op - a [CFG], [Conditional], [DFG] - /// or [TailLoop], which must have the same inputs and outputs as the original op, - /// modulo replacement. + /// Defines a sub-Hugr to insert, whose root becomes (or replaces) the desired Node. + /// The root must be a [CFG], [Conditional], [DFG] or [TailLoop]. // Not a FuncDefn, nor Case/DataflowBlock - /// Note this will be of limited use before [monomorphization](super::monomorphize()) because - /// the sub-Hugr will not be able to use type variables present in the op. + /// Note this will be of limited use before [monomorphization](super::monomorphize()) + /// because the new subtree will not be able to use type variables present in the + /// parent Hugr or previous op. // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), @@ -48,13 +49,13 @@ pub enum OpReplacement { // So client should add the functions before replacement, then remove unused ones afterwards.) } -impl OpReplacement { +impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { match self { - OpReplacement::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - OpReplacement::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), + NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, } } @@ -65,16 +66,16 @@ impl OpReplacement { inputs: impl IntoIterator, ) -> Result, BuildError> { match self { - OpReplacement::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), - OpReplacement::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), + NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), } } fn replace(&self, hugr: &mut impl HugrMut, n: Node) { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { - OpReplacement::SingleOp(op_type) => op_type, - OpReplacement::CompoundOp(new_h) => { + NodeTemplate::SingleOp(op_type) => op_type, + NodeTemplate::CompoundOp(new_h) => { let new_root = hugr.insert_hugr(n, *new_h).new_root; let children = hugr.children(new_root).collect::>(); let root_opty = hugr.remove_node(new_root); @@ -95,8 +96,8 @@ pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, linearize: Linearizer, - op_map: HashMap, - param_ops: HashMap Option>>, + op_map: HashMap, + param_ops: HashMap Option>>, consts: HashMap< CustomType, Arc Result>, @@ -189,8 +190,8 @@ impl ReplaceTypes { // overapproximation. Moreover, these depend upon the *return type* of the Fn. // It would be too awkward to require: // dest_fn: impl Fn(&TypeArg) -> (Type, - // Fn(&Linearizer) -> OpReplacement, // copy - // Fn(&Linearizer) -> OpReplacement)` // discard + // Fn(&Linearizer) -> NodeTemplate, // copy + // Fn(&Linearizer) -> NodeTemplate)` // discard self.param_types.insert(src.into(), Arc::new(dest_fn)); } @@ -199,7 +200,7 @@ impl ReplaceTypes { /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports.) /// - /// The same [OpReplacement]s are also used in cases where `src` is an element of a [TypeEnum::Sum]. + /// The same [NodeTemplate]s are also used in cases where `src` is an element of a [TypeEnum::Sum]. /// /// # Errors /// @@ -209,8 +210,8 @@ impl ReplaceTypes { pub fn linearize( &mut self, src: CustomType, - copy: OpReplacement, - discard: OpReplacement, + copy: NodeTemplate, + discard: NodeTemplate, ) -> Result<(), CustomType> { self.linearize.register(src, copy, discard) } @@ -229,11 +230,11 @@ impl ReplaceTypes { /// /// The [Linearizer] is passed so that the callback can use it to generate /// `copy`/`discard` ops for other types (e.g. the elements of a collection), - /// as part of an [OpReplacement::CompoundOp]. + /// as part of an [NodeTemplate::CompoundOp]. pub fn linearize_parametric( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + 'static, ) { self.linearize.register_parametric(src, copy_discard_fn) @@ -245,7 +246,7 @@ impl ReplaceTypes { /// this should only be used on already-*[monomorphize](super::monomorphize())d* /// Hugrs, as substitution (parametric polymorphism) happening later will not respect /// this replacement. - pub fn replace_op(&mut self, src: &ExtensionOp, dest: OpReplacement) { + pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { self.op_map.insert(OpHashWrapper::from(src), dest); } @@ -258,7 +259,7 @@ impl ReplaceTypes { pub fn replace_parametrized_op( &mut self, src: &OpDef, - dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -547,7 +548,7 @@ mod test { use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; use itertools::Itertools; - use super::{handlers::list_const, OpReplacement, ReplaceTypes}; + use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; const READ: &str = "read"; @@ -608,7 +609,7 @@ mod test { } fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { + fn lowered_read(args: &[TypeArg]) -> Option { let ty = just_elem_type(args); let mut dfb = DFGBuilder::new(inout_sig( vec![array_type(64, ty.clone()), i64_t()], @@ -627,7 +628,7 @@ mod test { let [res] = dfb .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) .unwrap(); - Some(OpReplacement::CompoundOp(Box::new( + Some(NodeTemplate::CompoundOp(Box::new( dfb.finish_hugr_with_outputs([res]).unwrap(), ))) } @@ -640,7 +641,7 @@ mod test { ); lw.replace_op( &read_op(ext, bool_t()), - OpReplacement::SingleOp( + NodeTemplate::SingleOp( ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), []) .unwrap() .into(), @@ -919,7 +920,7 @@ mod test { e.get_op(READ).unwrap().as_ref(), Box::new(|args: &[TypeArg]| { option_contents(just_elem_type(args)).map(|elem| { - OpReplacement::SingleOp( + NodeTemplate::SingleOp( ListOp::get .with_type(elem) .to_extension_op() diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 48f160a0ea..404bae9d76 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -11,14 +11,14 @@ use hugr_core::Wire; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; use itertools::Itertools; -use super::{OpReplacement, ParametricType}; +use super::{NodeTemplate, ParametricType}; /// Configuration for inserting copy and discard operations for linear types /// outports of which are sources of multiple or 0 edges. #[derive(Clone, Default)] pub struct Linearizer { // Keyed by lowered type, as only needed when there is an op outputting such - copy_discard: HashMap, + copy_discard: HashMap, // Copy/discard of parametric types handled by a function that receives the new/lowered type. // We do not allow overriding copy/discard of non-extension types, but that // can be achieved by *firstly* lowering to a custom linear type, with copy/discard @@ -26,7 +26,7 @@ pub struct Linearizer { // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - Arc Result>, + Arc Result>, >, } @@ -66,8 +66,8 @@ impl Linearizer { pub fn register( &mut self, typ: CustomType, - copy: OpReplacement, - discard: OpReplacement, + copy: NodeTemplate, + discard: NodeTemplate, ) -> Result<(), CustomType> { if typ.bound() == TypeBound::Copyable { Err(typ) @@ -82,7 +82,7 @@ impl Linearizer { pub fn register_parametric( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise @@ -146,7 +146,7 @@ impl Linearizer { Ok(()) } - /// Gets an [OpReplacement] for copying or discarding a value of type `typ`, i.e. + /// Gets an [NodeTemplate] for copying or discarding a value of type `typ`, i.e. /// a recipe for a node with one input of that type and the specified number of /// outports. Note that `num_outports` should never be 1 (as no node is required) /// @@ -157,7 +157,7 @@ impl Linearizer { &self, typ: &Type, num_outports: usize, - ) -> Result { + ) -> Result { if typ.copyable() { return Err(LinearizeError::CopyableType(typ.clone())); }; @@ -206,7 +206,7 @@ impl Linearizer { .collect::>(); // must collect to end borrow of `case_b` by closure case_b.finish_with_outputs(outputs).unwrap(); } - Ok(OpReplacement::CompoundOp(Box::new( + Ok(NodeTemplate::CompoundOp(Box::new( cb.finish_hugr().unwrap(), ))) } @@ -225,7 +225,7 @@ impl Linearizer { src = out1; } outputs.push(src); - OpReplacement::CompoundOp(Box::new( + NodeTemplate::CompoundOp(Box::new( dfb.finish_hugr_with_outputs(outputs).unwrap(), )) }), @@ -265,7 +265,7 @@ mod test { use itertools::Itertools; use rstest::rstest; - use crate::replace_types::OpReplacement; + use crate::replace_types::NodeTemplate; use crate::ReplaceTypes; const LIN_T: &str = "Lin"; @@ -330,8 +330,8 @@ mod test { lowerer .linearize( lin_custom_t, - OpReplacement::SingleOp(copy_op.into()), - OpReplacement::SingleOp(discard_op.into()), + NodeTemplate::SingleOp(copy_op.into()), + NodeTemplate::SingleOp(discard_op.into()), ) .unwrap(); (e, lowerer) @@ -441,7 +441,7 @@ mod test { let opdef2 = opdef.clone(); lowerer.linearize_parametric(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); - Ok(OpReplacement::SingleOp( + Ok(NodeTemplate::SingleOp( ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) .unwrap() .into(), From 221bf7de494ff3fd6170d26311f692963917bfa7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 28 Mar 2025 21:28:05 +0000 Subject: [PATCH 111/123] fix dataflowparent doclink --- hugr-passes/src/replace_types.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index c62df40921..f96bfcc582 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -31,6 +31,8 @@ pub use linearize::{LinearizeError, Linearizer}; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. +/// +/// [DataflowParent]: hugr_core::ops::OpTag::DataflowParent #[derive(Clone, Debug, PartialEq)] pub enum NodeTemplate { /// A single node - so if replacing an existing node, change only the op From f6312bcc68e424de0f9610702a3dba233e6ffa99 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sun, 30 Mar 2025 08:49:07 +0100 Subject: [PATCH 112/123] Remove linearize() proxy methods, add linearizer() getter - docs a mess --- hugr-passes/src/replace_types.rs | 55 +++++----------------- hugr-passes/src/replace_types/linearize.rs | 35 +++++++++++++- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index f96bfcc582..5df267c45f 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -31,7 +31,7 @@ pub use linearize::{LinearizeError, Linearizer}; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. -/// +/// /// [DataflowParent]: hugr_core::ops::OpTag::DataflowParent #[derive(Clone, Debug, PartialEq)] pub enum NodeTemplate { @@ -197,50 +197,19 @@ impl ReplaceTypes { self.param_types.insert(src.into(), Arc::new(dest_fn)); } - /// Configures this instance so that, when an outport of type `src` has other than one connected - /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. - /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; - /// `discard` should have exactly one inport, of type 'src', and no outports.) - /// - /// The same [NodeTemplate]s are also used in cases where `src` is an element of a [TypeEnum::Sum]. - /// - /// # Errors - /// - /// If `src` is [Copyable], it is returned as an `Err - /// - /// [Copyable]: hugr_core::types::TypeBound::Copyable - pub fn linearize( - &mut self, - src: CustomType, - copy: NodeTemplate, - discard: NodeTemplate, - ) -> Result<(), CustomType> { - self.linearize.register(src, copy, discard) - } - - /// Configures this instance that when lowering produces an outport which - /// * has type which is an instantiation of the parametric type `src`, and - /// * is not [Copyable](hugr_core::types::TypeBound::Copyable), and + /// Allows to configure how to deal with types/wires that were [Copyable] + /// but have become linear as a result of type-changing. Specifically, + /// the [Linearizer] is used whenever lowering produces an outport which both + /// * has a non-[Copyable] type - perhaps a direct substitution, or perhaps e.g. + /// as a result of changing the element type of a collection such as an [`array`] /// * has other than one connected inport, /// - /// ...then the provided callback should be used to generate a `copy` or `discard` op, - /// passing the desired number of outports (which will never be 1). - /// - /// (That is, this is like [Self::linearize] but for parametric types and/or - /// with a callback that can generate an n-way copy directly, rather than - /// with a 0-way and 2-way copy.) - /// - /// The [Linearizer] is passed so that the callback can use it to generate - /// `copy`/`discard` ops for other types (e.g. the elements of a collection), - /// as part of an [NodeTemplate::CompoundOp]. - pub fn linearize_parametric( - &mut self, - src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result - + 'static, - ) { - self.linearize.register_parametric(src, copy_discard_fn) - } + /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// [`array`]: hugr_core::std_extensions::collections::array::array_type + pub fn linearizer( + &mut self) -> &mut Linearizer { + &mut self.linearize + } /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [OpDef], this takes diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 404bae9d76..5cfa46c184 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -58,6 +58,19 @@ pub enum LinearizeError { } impl Linearizer { + /// Configures this instance so that, when an outport of type `src` has other than one connected + /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. + /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; + /// `discard` should have exactly one inport, of type 'src', and no outports.) + /// + /// The same [NodeTemplate]s are also used in cases where `src` is an element of a [TypeEnum::Sum]. + /// + /// # Errors + /// + /// If `src` is [Copyable], it is returned as an `Err + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// Registers a type for linearization by providing copy and discard operations. /// /// # Errors @@ -77,6 +90,23 @@ impl Linearizer { } } + + /// Configures this instance that when lowering produces an outport which + /// * has type which is an instantiation of the parametric type `src`, and + /// * is not [Copyable](hugr_core::types::TypeBound::Copyable), and + /// * has other than one connected inport, + /// + /// ...then the provided callback should be used to generate a `copy` or `discard` op, + /// passing the desired number of outports (which will never be 1). + /// + /// (That is, this is like [Self::linearize] but for parametric types and/or + /// with a callback that can generate an n-way copy directly, rather than + /// with a 0-way and 2-way copy.) + /// + /// The [Linearizer] is passed so that the callback can use it to generate + /// `copy`/`discard` ops for other types (e.g. the elements of a collection), + /// as part of an [NodeTemplate::CompoundOp]. + /// Registers that instances of a parametrized [TypeDef] should be linearized /// by providing functions that generate copy and discard functions given the [TypeArg]s. pub fn register_parametric( @@ -328,7 +358,8 @@ mod test { let usize_custom_t = usize_t().as_extension().unwrap().clone(); lowerer.replace_type(usize_custom_t, Type::new_extension(lin_custom_t.clone())); lowerer - .linearize( + .linearizer() + .register( lin_custom_t, NodeTemplate::SingleOp(copy_op.into()), NodeTemplate::SingleOp(discard_op.into()), @@ -439,7 +470,7 @@ mod test { ); let opdef = e.get_op("copy").unwrap(); let opdef2 = opdef.clone(); - lowerer.linearize_parametric(lin_t_def, move |args, num_outs, _| { + lowerer.linearizer().register_parametric(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); Ok(NodeTemplate::SingleOp( ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) From 5fe25c0ba60c04f51d09f47e2135127a56c9e496 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sun, 30 Mar 2025 09:16:21 +0100 Subject: [PATCH 113/123] docs+notes --- hugr-passes/src/replace_types.rs | 39 ++++++++--- hugr-passes/src/replace_types/linearize.rs | 78 ++++++++++------------ 2 files changed, 67 insertions(+), 50 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 5df267c45f..36f37fa004 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -1,10 +1,7 @@ #![allow(clippy::type_complexity)] #![warn(missing_docs)] -//! Replace types with other types across the Hugr. +//! Replace types with other types across the Hugr. See [ReplaceTypes] and [Linearizer]. //! -//! Parametrized types and ops will be reparametrized taking into account the replacements, -//! but any ops taking/returning the replaced types *not* as a result of parametrization, -//! will also need to be replaced - see [ReplaceTypes::replace_op]. (Similarly [Const]s.) use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; @@ -31,7 +28,7 @@ pub use linearize::{LinearizeError, Linearizer}; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. -/// +/// /// [DataflowParent]: hugr_core::ops::OpTag::DataflowParent #[derive(Clone, Debug, PartialEq)] pub enum NodeTemplate { @@ -93,6 +90,31 @@ impl NodeTemplate { /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. +/// +/// Parametrized types and ops will be reparametrized taking into account the +/// replacements, but any ops taking/returning the replaced types *not* as a result of +/// parametrization, will also need to be replaced - see [Self::replace_op]. +/// Similarly [Const]s. +/// +/// Types that are [Copyable](hugr_core::types::TypeBound::Copyable) may also be replaced +/// with types that are not, see [Linearizer]. +/// +/// Note that although this pass may be used before [monomorphization], there are some +/// limitations (that do not apply if done after [monomorphization]): +/// * [NodeTemplate::CompoundOp] only works for operations that do not use type variables +/// * "Overrides" of specific instantiations of polymorphic types will not be detected if +/// the instantiations are created inside polymorphic functions. For example, suppose +/// we [Self::replace_type] type `A` with `X`, [Self::replace_parametrized_type] +/// container `MyList` with `List`, and [Self::replace_type] `MyList` with +/// `SpecialListOfXs`. If a function `foo` polymorphic over a type variable `T` dealing +/// with `MyList`s, that is called with type argument `A`, then `foo` will be +/// updated to deal with `List`s and the call `foo` updated to `foo`, but this +/// will still result in using `List` rather than `SpecialListOfXs`. (However this +/// would be fine *after* [monomorphization]: the monomorphic definition of `foo_A` +/// would use `SpecialListOfXs`.) +/// * See also limitations noted for [Linearizer]. +/// +/// [monomorphization]: super::monomorphize() #[derive(Clone, Default)] pub struct ReplaceTypes { type_map: HashMap, @@ -206,10 +228,9 @@ impl ReplaceTypes { /// /// [Copyable]: hugr_core::types::TypeBound::Copyable /// [`array`]: hugr_core::std_extensions::collections::array::array_type - pub fn linearizer( - &mut self) -> &mut Linearizer { - &mut self.linearize - } + pub fn linearizer(&mut self) -> &mut Linearizer { + &mut self.linearize + } /// Configures this instance to change occurrences of `src` to `dest`. /// Note that if `src` is an instance of a *parametrized* [OpDef], this takes diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5cfa46c184..9aa36902e6 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -13,8 +13,18 @@ use itertools::Itertools; use super::{NodeTemplate, ParametricType}; -/// Configuration for inserting copy and discard operations for linear types -/// outports of which are sources of multiple or 0 edges. +/// Configuration for inserting copy and discard operations for linear types when a +///[ReplaceTypes](super::ReplaceTypes) creates outports of these types (or of types +/// containing them) which are sources of multiple or 0 edges. +/// +/// Note that this is not really effective before [monomorphization]: if a +/// function polymorphic over a [TypeBound::Copyable] becomes called with a +/// non-Copyable type argument, [Linearizer] cannot insert copy/discard operations +/// for such a case. However, following [monomorphization], there would be a +/// specific instantiation of the function for the type-that-becomes-linear, +/// into which copy/discard can be inserted. +/// +/// [monomorphization]: crate::monomorphize() #[derive(Clone, Default)] pub struct Linearizer { // Keyed by lowered type, as only needed when there is an op outputting such @@ -58,24 +68,15 @@ pub enum LinearizeError { } impl Linearizer { - /// Configures this instance so that, when an outport of type `src` has other than one connected - /// inport, the specified `copy` and or `discard` ops should be used to wire it to those inports. - /// (`copy` should have exactly one inport, of type `src`, and two outports, of same type; - /// `discard` should have exactly one inport, of type 'src', and no outports.) - /// - /// The same [NodeTemplate]s are also used in cases where `src` is an element of a [TypeEnum::Sum]. - /// - /// # Errors - /// - /// If `src` is [Copyable], it is returned as an `Err - /// - /// [Copyable]: hugr_core::types::TypeBound::Copyable - - /// Registers a type for linearization by providing copy and discard operations. + /// Configures this instance that the specified monomorphic type can be copied and/or + /// discarded via the provided [NodeTemplate]s - directly or as part of a compound type + /// e.g. [TypeEnum::Sum]. + /// `copy` should have exactly one inport, of type `src`, and two outports, of same type; + /// `discard` should have exactly one inport, of type 'src', and no outports. /// /// # Errors /// - /// If `typ` is copyable, it is returned as an `Err`. + /// If `typ` is [Copyable](TypeBound::Copyable), it is returned as an `Err pub fn register( &mut self, typ: CustomType, @@ -90,25 +91,18 @@ impl Linearizer { } } - - /// Configures this instance that when lowering produces an outport which - /// * has type which is an instantiation of the parametric type `src`, and - /// * is not [Copyable](hugr_core::types::TypeBound::Copyable), and - /// * has other than one connected inport, + /// Configures this instance that instances of the specified [TypeDef] (perhaps + /// polymorphic) can be copied and/or discarded by using the provided callback + /// to generate a [NodeTemplate] for an appropriate copy/discard operation. /// - /// ...then the provided callback should be used to generate a `copy` or `discard` op, - /// passing the desired number of outports (which will never be 1). - /// - /// (That is, this is like [Self::linearize] but for parametric types and/or - /// with a callback that can generate an n-way copy directly, rather than - /// with a 0-way and 2-way copy.) - /// - /// The [Linearizer] is passed so that the callback can use it to generate + /// The callback is given + /// * the type arguments (if any - we do not *require* that [TypeDef] take parameters] + /// * the desired number of outports (this will never be 1) + /// * A handle to the [Linearizer], so that the callback can use it to generate /// `copy`/`discard` ops for other types (e.g. the elements of a collection), /// as part of an [NodeTemplate::CompoundOp]. - - /// Registers that instances of a parametrized [TypeDef] should be linearized - /// by providing functions that generate copy and discard functions given the [TypeArg]s. + /// + /// Note that [Self::register] takes precedence when the `src` types overlap. pub fn register_parametric( &mut self, src: &TypeDef, @@ -470,14 +464,16 @@ mod test { ); let opdef = e.get_op("copy").unwrap(); let opdef2 = opdef.clone(); - lowerer.linearizer().register_parametric(lin_t_def, move |args, num_outs, _| { - assert!(args.is_empty()); - Ok(NodeTemplate::SingleOp( - ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) - .unwrap() - .into(), - )) - }); + lowerer + .linearizer() + .register_parametric(lin_t_def, move |args, num_outs, _| { + assert!(args.is_empty()); + Ok(NodeTemplate::SingleOp( + ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) + .unwrap() + .into(), + )) + }); assert!(lowerer.run(&mut h).unwrap()); let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); From edd21269cf3f39115d807c252095df488b13ccd0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sun, 30 Mar 2025 09:30:28 +0100 Subject: [PATCH 114/123] fix lists in docs --- hugr-passes/src/replace_types.rs | 2 +- hugr-passes/src/replace_types/linearize.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 36f37fa004..d7af1aeba0 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -223,7 +223,7 @@ impl ReplaceTypes { /// but have become linear as a result of type-changing. Specifically, /// the [Linearizer] is used whenever lowering produces an outport which both /// * has a non-[Copyable] type - perhaps a direct substitution, or perhaps e.g. - /// as a result of changing the element type of a collection such as an [`array`] + /// as a result of changing the element type of a collection such as an [`array`] /// * has other than one connected inport, /// /// [Copyable]: hugr_core::types::TypeBound::Copyable diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 9aa36902e6..f19af19cae 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -99,8 +99,8 @@ impl Linearizer { /// * the type arguments (if any - we do not *require* that [TypeDef] take parameters] /// * the desired number of outports (this will never be 1) /// * A handle to the [Linearizer], so that the callback can use it to generate - /// `copy`/`discard` ops for other types (e.g. the elements of a collection), - /// as part of an [NodeTemplate::CompoundOp]. + /// `copy`/`discard` ops for other types (e.g. the elements of a collection), + /// as part of an [NodeTemplate::CompoundOp]. /// /// Note that [Self::register] takes precedence when the `src` types overlap. pub fn register_parametric( From 5ea97633813ae0245401a2f4aa5a464df02ee345 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 31 Mar 2025 19:06:49 +0100 Subject: [PATCH 115/123] Add Linearizer trait, rename to DelegatingLinearizer --- hugr-passes/src/replace_types.rs | 6 +- hugr-passes/src/replace_types/linearize.rs | 162 ++++++++++++--------- 2 files changed, 96 insertions(+), 72 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index d7af1aeba0..af379455d2 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -24,7 +24,7 @@ use hugr_core::{Hugr, Node, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; mod linearize; -pub use linearize::{LinearizeError, Linearizer}; +pub use linearize::{DelegatingLinearizer, LinearizeError, Linearizer}; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. @@ -119,7 +119,7 @@ impl NodeTemplate { pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, - linearize: Linearizer, + linearize: DelegatingLinearizer, op_map: HashMap, param_ops: HashMap Option>>, consts: HashMap< @@ -228,7 +228,7 @@ impl ReplaceTypes { /// /// [Copyable]: hugr_core::types::TypeBound::Copyable /// [`array`]: hugr_core::std_extensions::collections::array::array_type - pub fn linearizer(&mut self) -> &mut Linearizer { + pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { &mut self.linearize } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index f19af19cae..388ce6afe0 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -13,9 +13,10 @@ use itertools::Itertools; use super::{NodeTemplate, ParametricType}; -/// Configuration for inserting copy and discard operations for linear types when a -///[ReplaceTypes](super::ReplaceTypes) creates outports of these types (or of types -/// containing them) which are sources of multiple or 0 edges. +/// Trait for things that know how to wire up linear outports to other than one target. +/// Used to restore Hugr validity a [ReplaceTypes](super::ReplaceTypes) results in types +/// of such outports changing from [Copyable](TypeBound::Copyable) to linear (i.e. +/// [TypeBound::Any]). /// /// Note that this is not really effective before [monomorphization]: if a /// function polymorphic over a [TypeBound::Copyable] becomes called with a @@ -25,8 +26,84 @@ use super::{NodeTemplate, ParametricType}; /// into which copy/discard can be inserted. /// /// [monomorphization]: crate::monomorphize() +pub trait Linearizer { + /// Insert copy or discard operations (as appropriate) enough to wire `src` + /// up to all `targets`. + /// + /// The default implementation + /// * if `targets.len() == 1`, wires `src` to the unique target + /// * otherwise, makes a single call to [Self::copy_discard_op], inserts that op, + /// and wires its outputs 1:1 to each target + /// + /// # Errors + /// + /// Most variants of [LinearizeError] can be raised, specifically including + /// [LinearizeError::CopyableType] if the type is [Copyable], in which case the Hugr + /// will be unchanged. + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// + /// # Panics + /// + /// if `src` is not a valid Wire (does not identify a dataflow out-port) + fn insert_copy_discard( + &self, + hugr: &mut impl HugrMut, + src: Wire, + targets: &[(Node, IncomingPort)], + ) -> Result<(), LinearizeError> { + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap(); + let (tgt_node, tgt_inport) = if targets.len() == 1 { + *targets.first().unwrap() + } else { + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + let src_parent = hugr + .get_parent(src.node()) + .expect("Root node cannot have out edges"); + if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { + let tgt_parent = hugr + .get_parent(*tgt) + .expect("Root node cannot have incoming edges"); + (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) + }) { + return Err(LinearizeError::NoLinearNonLocalEdges { + src: src.node(), + src_parent, + tgt, + tgt_parent, + }); + } + let copy_discard_op = self + .copy_discard_op(typ, targets.len())? + .add_hugr(hugr, src_parent); + for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { + hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); + } + (copy_discard_op, 0.into()) + }; + hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); + Ok(()) + } + + /// Gets an [NodeTemplate] for copying or discarding a value of type `typ`, i.e. + /// a recipe for a node with one input of that type and the specified number of + /// outports. + /// + /// Implementations are free to panic if `num_outports == 1`, such calls should never + /// occur as source/target can be directly wired without any node/op being required. + fn copy_discard_op( + &self, + typ: &Type, + num_outports: usize, + ) -> Result; +} + +/// A configuration for implementing [CopyDiscardInserter] by delegating to +/// type-specific callbacks, and by composing them in order to handle compound types +/// such as [TypeEnum::Sum]s. #[derive(Clone, Default)] -pub struct Linearizer { +pub struct DelegatingLinearizer { // Keyed by lowered type, as only needed when there is an op outputting such copy_discard: HashMap, // Copy/discard of parametric types handled by a function that receives the new/lowered type. @@ -36,7 +113,13 @@ pub struct Linearizer { // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - Arc Result>, + Arc< + dyn Fn( + &[TypeArg], + usize, + &DelegatingLinearizer, + ) -> Result, + >, >, } @@ -67,7 +150,7 @@ pub enum LinearizeError { CopyableType(Type), } -impl Linearizer { +impl DelegatingLinearizer { /// Configures this instance that the specified monomorphic type can be copied and/or /// discarded via the provided [NodeTemplate]s - directly or as part of a compound type /// e.g. [TypeEnum::Sum]. @@ -106,7 +189,7 @@ impl Linearizer { pub fn register_parametric( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &Linearizer) -> Result + copy_discard_fn: impl Fn(&[TypeArg], usize, &DelegatingLinearizer) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise @@ -115,69 +198,10 @@ impl Linearizer { self.copy_discard_parametric .insert(src.into(), Arc::new(copy_discard_fn)); } +} - /// Insert copy or discard operations (as appropriate) enough to wire `src_port` of `src_node` - /// up to all `targets`. - /// - /// # Errors - /// - /// Most variants of [LinearizeError] can be raised, specifically including - /// [LinearizeError::CopyableType] if the type is [Copyable], in which case the Hugr - /// will be unchanged. - /// - /// [Copyable]: hugr_core::types::TypeBound::Copyable - /// - /// # Panics - /// - /// if `src` is not a valid Wire (does not identify a dataflow out-port) - pub fn insert_copy_discard( - &self, - hugr: &mut impl HugrMut, - src: Wire, - targets: &[(Node, IncomingPort)], - ) -> Result<(), LinearizeError> { - let sig = hugr.signature(src.node()).unwrap(); - let typ = sig.port_type(src.source()).unwrap(); - let (tgt_node, tgt_inport) = if targets.len() == 1 { - *targets.first().unwrap() - } else { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) - let src_parent = hugr - .get_parent(src.node()) - .expect("Root node cannot have out edges"); - if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { - let tgt_parent = hugr - .get_parent(*tgt) - .expect("Root node cannot have incoming edges"); - (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) - }) { - return Err(LinearizeError::NoLinearNonLocalEdges { - src: src.node(), - src_parent, - tgt, - tgt_parent, - }); - } - let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); - for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { - hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); - } - (copy_discard_op, 0.into()) - }; - hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); - Ok(()) - } - - /// Gets an [NodeTemplate] for copying or discarding a value of type `typ`, i.e. - /// a recipe for a node with one input of that type and the specified number of - /// outports. Note that `num_outports` should never be 1 (as no node is required) - /// - /// # Panics - /// - /// if `num_outports == 1` - pub fn copy_discard_op( +impl Linearizer for DelegatingLinearizer { + fn copy_discard_op( &self, typ: &Type, num_outports: usize, From cc4fd68470967617779264a96dbf8aefd0107b05 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 31 Mar 2025 19:40:01 +0100 Subject: [PATCH 116/123] Pass &dyn Linearizer - requires making object-safe --- hugr-passes/src/replace_types.rs | 3 ++- hugr-passes/src/replace_types/linearize.rs | 15 ++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index af379455d2..d627a597b1 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -304,7 +304,8 @@ impl ReplaceTypes { if targets.len() != 1 { hugr.disconnect(n, outp); let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; + self.linearize + .insert_copy_discard(hugr.hugr_mut(), src, &targets)?; } } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 388ce6afe0..7def721587 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -7,8 +7,7 @@ use hugr_core::builder::{ }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::types::{CustomType, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; -use hugr_core::Wire; -use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, Hugr, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; use super::{NodeTemplate, ParametricType}; @@ -48,7 +47,7 @@ pub trait Linearizer { /// if `src` is not a valid Wire (does not identify a dataflow out-port) fn insert_copy_discard( &self, - hugr: &mut impl HugrMut, + hugr: &mut Hugr, src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { @@ -113,13 +112,7 @@ pub struct DelegatingLinearizer { // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - Arc< - dyn Fn( - &[TypeArg], - usize, - &DelegatingLinearizer, - ) -> Result, - >, + Arc Result>, >, } @@ -189,7 +182,7 @@ impl DelegatingLinearizer { pub fn register_parametric( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &DelegatingLinearizer) -> Result + copy_discard_fn: impl Fn(&[TypeArg], usize, &dyn Linearizer) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise From 8a1ff0c562c5bec0494c4c65db193c11ac3f3cc1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 31 Mar 2025 19:40:05 +0100 Subject: [PATCH 117/123] No - revert - instead pass ref to new CallbackHandler struct --- hugr-passes/src/replace_types.rs | 3 +-- hugr-passes/src/replace_types/linearize.rs | 27 ++++++++++++++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index d627a597b1..af379455d2 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -304,8 +304,7 @@ impl ReplaceTypes { if targets.len() != 1 { hugr.disconnect(n, outp); let src = Wire::new(n, outp); - self.linearize - .insert_copy_discard(hugr.hugr_mut(), src, &targets)?; + self.linearize.insert_copy_discard(hugr, src, &targets)?; } } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 7def721587..131630be6e 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -7,7 +7,8 @@ use hugr_core::builder::{ }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::types::{CustomType, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; -use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, Hugr, HugrView, IncomingPort, Node, Wire}; +use hugr_core::Wire; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; use itertools::Itertools; use super::{NodeTemplate, ParametricType}; @@ -47,7 +48,7 @@ pub trait Linearizer { /// if `src` is not a valid Wire (does not identify a dataflow out-port) fn insert_copy_discard( &self, - hugr: &mut Hugr, + hugr: &mut impl HugrMut, src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { @@ -112,10 +113,16 @@ pub struct DelegatingLinearizer { // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - Arc Result>, + Arc Result>, >, } +/// Implementation of [Linearizer] passed to callbacks, (e.g.) so that callbacks for +/// handling collection types can use it to generate copy/discards of elements. +// (Note, this is its own type just to give a bit of room for future expansion, +// rather than passing a &DelegatingLinearizer directly) +pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); + #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] pub enum LinearizeError { @@ -182,7 +189,7 @@ impl DelegatingLinearizer { pub fn register_parametric( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &dyn Linearizer) -> Result + copy_discard_fn: impl Fn(&[TypeArg], usize, &CallbackHandler) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise @@ -275,7 +282,7 @@ impl Linearizer for DelegatingLinearizer { .copy_discard_parametric .get(&cty.into()) .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - copy_discard_fn(cty.args(), num_outports, self) + copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self)) } }, TypeEnum::Function(_) => panic!("Ruled out above as copyable"), @@ -284,6 +291,16 @@ impl Linearizer for DelegatingLinearizer { } } +impl<'a> Linearizer for CallbackHandler<'a> { + fn copy_discard_op( + &self, + typ: &Type, + num_outports: usize, + ) -> Result { + self.0.copy_discard_op(typ, num_outports) + } +} + #[cfg(test)] mod test { use std::collections::HashMap; From d5e4ac3c177a80bb41ba0efa23cec3fbc3892422 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 1 Apr 2025 10:17:04 +0100 Subject: [PATCH 118/123] Add LinearizeError::WrongSignature --- hugr-passes/src/replace_types.rs | 14 +++++++++++-- hugr-passes/src/replace_types/linearize.rs | 23 ++++++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index af379455d2..71e5056d7a 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -18,8 +18,10 @@ use hugr_core::ops::{ FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; -use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer}; -use hugr_core::{Hugr, Node, Wire}; +use hugr_core::types::{ + CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, +}; +use hugr_core::{Hugr, HugrView, Node, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; @@ -86,6 +88,14 @@ impl NodeTemplate { }; *hugr.optype_mut(n) = new_optype; } + + fn signature(&self) -> Option> { + match self { + NodeTemplate::SingleOp(op_type) => op_type, + NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + } + .dataflow_signature() + } } /// A configuration of what types, ops, and constants should be replaced with what. diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 131630be6e..5222228950 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; @@ -6,7 +7,7 @@ use hugr_core::builder::{ HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; -use hugr_core::types::{CustomType, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; +use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; use hugr_core::Wire; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; use itertools::Itertools; @@ -130,6 +131,12 @@ pub enum LinearizeError { NeedCopy(Type), #[error("Need discard op for {_0}")] NeedDiscard(Type), + #[error("Callback generated wrong signature for {typ} - requested (1 input and) {num_outports} outputs, got signature {sig:?}")] + WrongSignature { + typ: Type, + num_outports: usize, + sig: Option, + }, #[error("Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})")] NoLinearNonLocalEdges { src: Node, @@ -282,7 +289,19 @@ impl Linearizer for DelegatingLinearizer { .copy_discard_parametric .get(&cty.into()) .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; - copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self)) + let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; + let sig = tmpl.signature(); + if sig.as_ref().is_some_and(|sig| { + sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + }) { + Ok(tmpl) + } else { + Err(LinearizeError::WrongSignature { + typ: typ.clone(), + num_outports, + sig: sig.map(Cow::into_owned), + }) + } } }, TypeEnum::Function(_) => panic!("Ruled out above as copyable"), From b2594268ab1e32dd16146c4e59bf748c7d732bfb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 1 Apr 2025 10:19:41 +0100 Subject: [PATCH 119/123] Remove NeedDiscard, rename NeedCopy --- hugr-passes/src/replace_types/linearize.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5222228950..e19a81b63d 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -127,10 +127,8 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] pub enum LinearizeError { - #[error("Need copy op for {_0}")] - NeedCopy(Type), - #[error("Need discard op for {_0}")] - NeedDiscard(Type), + #[error("Need copy/discard op for {_0}")] + NeedCopyDiscard(Type), #[error("Callback generated wrong signature for {typ} - requested (1 input and) {num_outports} outputs, got signature {sig:?}")] WrongSignature { typ: Type, @@ -288,7 +286,7 @@ impl Linearizer for DelegatingLinearizer { let copy_discard_fn = self .copy_discard_parametric .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedCopy(typ.clone()))?; + .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; let sig = tmpl.signature(); if sig.as_ref().is_some_and(|sig| { From a14fafafc05114402fbb22dd1be7ed9ab9774464 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 1 Apr 2025 11:04:26 +0100 Subject: [PATCH 120/123] test, also check sig for register(CustTy, NodeTempl*2) --- hugr-passes/src/replace_types/linearize.rs | 124 +++++++++++++++++---- 1 file changed, 101 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index e19a81b63d..883a678bb1 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -7,7 +7,7 @@ use hugr_core::builder::{ HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; -use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow}; +use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::Wire; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, IncomingPort, Node}; use itertools::Itertools; @@ -129,7 +129,7 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] NeedCopyDiscard(Type), - #[error("Callback generated wrong signature for {typ} - requested (1 input and) {num_outports} outputs, got signature {sig:?}")] + #[error("Copy/discard op for {typ} with {num_outports} outputs had wrong signature {sig:?}")] WrongSignature { typ: Type, num_outports: usize, @@ -167,16 +167,18 @@ impl DelegatingLinearizer { /// If `typ` is [Copyable](TypeBound::Copyable), it is returned as an `Err pub fn register( &mut self, - typ: CustomType, + cty: CustomType, copy: NodeTemplate, discard: NodeTemplate, - ) -> Result<(), CustomType> { - if typ.bound() == TypeBound::Copyable { - Err(typ) - } else { - self.copy_discard.insert(typ, (copy, discard)); - Ok(()) + ) -> Result<(), LinearizeError> { + let typ = Type::new_extension(cty.clone()); + if typ.copyable() { + return Err(LinearizeError::CopyableType(typ)); } + check_sig(©, &typ, 2)?; + check_sig(&discard, &typ, 0)?; + self.copy_discard.insert(cty, (copy, discard)); + Ok(()) } /// Configures this instance that instances of the specified [TypeDef] (perhaps @@ -205,6 +207,21 @@ impl DelegatingLinearizer { } } +fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { + let sig = tmpl.signature(); + if sig.as_ref().is_some_and(|sig| { + sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + }) { + Ok(()) + } else { + Err(LinearizeError::WrongSignature { + typ: typ.clone(), + num_outports, + sig: sig.map(Cow::into_owned), + }) + } +} + impl Linearizer for DelegatingLinearizer { fn copy_discard_op( &self, @@ -288,18 +305,8 @@ impl Linearizer for DelegatingLinearizer { .get(&cty.into()) .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(tmpl) - } else { - Err(LinearizeError::WrongSignature { - typ: typ.clone(), - num_outports, - sig: sig.map(Cow::into_owned), - }) - } + check_sig(&tmpl, typ, num_outports)?; + Ok(tmpl) } }, TypeEnum::Function(_) => panic!("Ruled out above as copyable"), @@ -330,8 +337,8 @@ mod test { CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, }; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; - use hugr_core::ops::DataflowOpTrait; use hugr_core::ops::{handle::NodeHandle, ExtensionOp, NamedOp, OpName}; + use hugr_core::ops::{DataflowOpTrait, OpType}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::{array_type, ArrayOpDef}; use hugr_core::types::type_param::TypeParam; @@ -340,7 +347,7 @@ mod test { use itertools::Itertools; use rstest::rstest; - use crate::replace_types::NodeTemplate; + use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; use crate::ReplaceTypes; const LIN_T: &str = "Lin"; @@ -563,4 +570,75 @@ mod test { ); } } + + #[test] + fn bad_sig() { + // Change usize to QB_T + let (ext, _) = ext_lowerer(); + let lin_ct = ext.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t = Type::from(lin_ct.clone()); + let copy3 = OpType::from( + ExtensionOp::new(ext.get_op("copy").unwrap().clone(), [3.into()]).unwrap(), + ); + let copy2 = ExtensionOp::new(ext.get_op("copy").unwrap().clone(), [2.into()]).unwrap(); + let discard = ExtensionOp::new(ext.get_op("discard").unwrap().clone(), []).unwrap(); + let mut replacer = ReplaceTypes::default(); + replacer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + + let bad_copy = replacer.linearizer().register( + lin_ct.clone(), + NodeTemplate::SingleOp(copy3.clone()), + NodeTemplate::SingleOp(discard.clone().into()), + ); + let sig3 = Some( + Signature::new(lin_t.clone(), vec![lin_t.clone(); 3]) + .with_extension_delta(ext.name().clone()), + ); + assert_eq!( + bad_copy, + Err(LinearizeError::WrongSignature { + typ: lin_t.clone(), + num_outports: 2, + sig: sig3.clone() + }) + ); + + let bad_discard = replacer.linearizer().register( + lin_ct.clone(), + NodeTemplate::SingleOp(copy2.into()), + NodeTemplate::SingleOp(copy3.clone()), + ); + + assert_eq!( + bad_discard, + Err(LinearizeError::WrongSignature { + typ: lin_t.clone(), + num_outports: 0, + sig: sig3.clone() + }) + ); + + // Try parametrized instead, but this version always returns 3 outports + replacer + .linearizer() + .register_parametric(ext.get_type(LIN_T).unwrap(), move |_args, _, _| { + Ok(NodeTemplate::SingleOp(copy3.clone())) + }); + + // A hugr that copies a usize + let dfb = DFGBuilder::new(inout_sig(usize_t(), vec![usize_t(); 2])).unwrap(); + let [inp] = dfb.input_wires_arr(); + let mut h = dfb.finish_hugr_with_outputs([inp, inp]).unwrap(); + + assert_eq!( + replacer.run(&mut h), + Err(ReplaceTypesError::LinearizeError( + LinearizeError::WrongSignature { + typ: lin_t.clone(), + num_outports: 2, + sig: sig3.clone() + } + )) + ); + } } From a4686d3ada54d9454e342e370da424f010689e4e Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 2 Apr 2025 08:11:48 +0100 Subject: [PATCH 121/123] lint --- hugr-passes/src/replace_types/linearize.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 883a678bb1..2d0e0a3bf5 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -315,7 +315,7 @@ impl Linearizer for DelegatingLinearizer { } } -impl<'a> Linearizer for CallbackHandler<'a> { +impl Linearizer for CallbackHandler<'_> { fn copy_discard_op( &self, typ: &Type, From 1abf1b4b0cc6f24e5f87c9f3fba327795f92d985 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Wed, 2 Apr 2025 08:21:46 +0100 Subject: [PATCH 122/123] docs --- hugr-passes/src/replace_types/linearize.rs | 24 ++++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 2d0e0a3bf5..539b223276 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -14,17 +14,19 @@ use itertools::Itertools; use super::{NodeTemplate, ParametricType}; -/// Trait for things that know how to wire up linear outports to other than one target. -/// Used to restore Hugr validity a [ReplaceTypes](super::ReplaceTypes) results in types -/// of such outports changing from [Copyable](TypeBound::Copyable) to linear (i.e. -/// [TypeBound::Any]). +/// Trait for things that know how to wire up linear outports to other than one +/// target. Used to restore Hugr validity a [ReplaceTypes](super::ReplaceTypes) +/// results in types of such outports changing from +/// [Copyable](hugr_core::types::TypeBound::Copyable) to linear (i.e. +/// [hugr_core::types::TypeBound::Any]). /// /// Note that this is not really effective before [monomorphization]: if a -/// function polymorphic over a [TypeBound::Copyable] becomes called with a -/// non-Copyable type argument, [Linearizer] cannot insert copy/discard operations -/// for such a case. However, following [monomorphization], there would be a -/// specific instantiation of the function for the type-that-becomes-linear, -/// into which copy/discard can be inserted. +/// function polymorphic over a +/// [Copyable](hugr_core::types::TypeBound::Copyable) becomes called with a +/// non-Copyable type argument, [Linearizer] cannot insert copy/discard +/// operations for such a case. However, following [monomorphization], there +/// would be a specific instantiation of the function for the +/// type-that-becomes-linear, into which copy/discard can be inserted. /// /// [monomorphization]: crate::monomorphize() pub trait Linearizer { @@ -100,7 +102,7 @@ pub trait Linearizer { ) -> Result; } -/// A configuration for implementing [CopyDiscardInserter] by delegating to +/// A configuration for implementing [Linearizer] by delegating to /// type-specific callbacks, and by composing them in order to handle compound types /// such as [TypeEnum::Sum]s. #[derive(Clone, Default)] @@ -164,7 +166,7 @@ impl DelegatingLinearizer { /// /// # Errors /// - /// If `typ` is [Copyable](TypeBound::Copyable), it is returned as an `Err + /// If `typ` is [Copyable](hugr_core::types::TypeBound::Copyable), it is returned as an `Err pub fn register( &mut self, cty: CustomType, From 7e3efcc55482a09764d438a487a2b18ce40028bb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Apr 2025 10:14:52 +0100 Subject: [PATCH 123/123] Rename register(=>_simple,_parametric=>_callback), docs --- hugr-passes/src/replace_types.rs | 2 +- hugr-passes/src/replace_types/linearize.rs | 34 ++++++++++++---------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 71e5056d7a..480975fabd 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -26,7 +26,7 @@ use hugr_core::{Hugr, HugrView, Node, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; mod linearize; -pub use linearize::{DelegatingLinearizer, LinearizeError, Linearizer}; +pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 539b223276..371798dceb 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -15,20 +15,19 @@ use itertools::Itertools; use super::{NodeTemplate, ParametricType}; /// Trait for things that know how to wire up linear outports to other than one -/// target. Used to restore Hugr validity a [ReplaceTypes](super::ReplaceTypes) -/// results in types of such outports changing from -/// [Copyable](hugr_core::types::TypeBound::Copyable) to linear (i.e. +/// target. Used to restore Hugr validity when a [ReplaceTypes](super::ReplaceTypes) +/// results in types of such outports changing from [Copyable] to linear (i.e. /// [hugr_core::types::TypeBound::Any]). /// /// Note that this is not really effective before [monomorphization]: if a -/// function polymorphic over a -/// [Copyable](hugr_core::types::TypeBound::Copyable) becomes called with a +/// function polymorphic over a [Copyable] becomes called with a /// non-Copyable type argument, [Linearizer] cannot insert copy/discard /// operations for such a case. However, following [monomorphization], there /// would be a specific instantiation of the function for the /// type-that-becomes-linear, into which copy/discard can be inserted. /// /// [monomorphization]: crate::monomorphize() +/// [Copyable]: hugr_core::types::TypeBound::Copyable pub trait Linearizer { /// Insert copy or discard operations (as appropriate) enough to wire `src` /// up to all `targets`. @@ -166,8 +165,11 @@ impl DelegatingLinearizer { /// /// # Errors /// - /// If `typ` is [Copyable](hugr_core::types::TypeBound::Copyable), it is returned as an `Err - pub fn register( + /// * [LinearizeError::CopyableType] If `typ` is + /// [Copyable](hugr_core::types::TypeBound::Copyable) + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the + /// expected inputs or outputs + pub fn register_simple( &mut self, cty: CustomType, copy: NodeTemplate, @@ -188,14 +190,14 @@ impl DelegatingLinearizer { /// to generate a [NodeTemplate] for an appropriate copy/discard operation. /// /// The callback is given - /// * the type arguments (if any - we do not *require* that [TypeDef] take parameters] + /// * the type arguments (as appropriate for the [TypeDef], so perhaps empty) /// * the desired number of outports (this will never be 1) - /// * A handle to the [Linearizer], so that the callback can use it to generate + /// * A [CallbackHandler] that the callback can use it to generate /// `copy`/`discard` ops for other types (e.g. the elements of a collection), /// as part of an [NodeTemplate::CompoundOp]. /// - /// Note that [Self::register] takes precedence when the `src` types overlap. - pub fn register_parametric( + /// Note that [Self::register_simple] takes precedence when the `src` types overlap. + pub fn register_callback( &mut self, src: &TypeDef, copy_discard_fn: impl Fn(&[TypeArg], usize, &CallbackHandler) -> Result @@ -413,7 +415,7 @@ mod test { lowerer.replace_type(usize_custom_t, Type::new_extension(lin_custom_t.clone())); lowerer .linearizer() - .register( + .register_simple( lin_custom_t, NodeTemplate::SingleOp(copy_op.into()), NodeTemplate::SingleOp(discard_op.into()), @@ -526,7 +528,7 @@ mod test { let opdef2 = opdef.clone(); lowerer .linearizer() - .register_parametric(lin_t_def, move |args, num_outs, _| { + .register_callback(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); Ok(NodeTemplate::SingleOp( ExtensionOp::new(opdef2.clone(), [(num_outs as u64).into()]) @@ -587,7 +589,7 @@ mod test { let mut replacer = ReplaceTypes::default(); replacer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); - let bad_copy = replacer.linearizer().register( + let bad_copy = replacer.linearizer().register_simple( lin_ct.clone(), NodeTemplate::SingleOp(copy3.clone()), NodeTemplate::SingleOp(discard.clone().into()), @@ -605,7 +607,7 @@ mod test { }) ); - let bad_discard = replacer.linearizer().register( + let bad_discard = replacer.linearizer().register_simple( lin_ct.clone(), NodeTemplate::SingleOp(copy2.into()), NodeTemplate::SingleOp(copy3.clone()), @@ -623,7 +625,7 @@ mod test { // Try parametrized instead, but this version always returns 3 outports replacer .linearizer() - .register_parametric(ext.get_type(LIN_T).unwrap(), move |_args, _, _| { + .register_callback(ext.get_type(LIN_T).unwrap(), move |_args, _, _| { Ok(NodeTemplate::SingleOp(copy3.clone())) });