diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 014b8cd977..49562d873e 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -481,8 +481,8 @@ pub(crate) mod test { use crate::metadata::Metadata; use crate::ops::{FuncDecl, FuncDefn, OpParent, OpTag, OpTrait, Value, handle::NodeHandle}; use crate::std_extensions::logic::test::and_op; - use crate::types::type_param::TypeParam; - use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV}; + use crate::types::type_param::{TermTypeError, TypeParam}; + use crate::types::{EdgeKind, FuncValueType, Signature, Term, Type, TypeBound, TypeRowRV}; use crate::utils::test_quantum_extension::h_gate; use crate::{Wire, builder::test::n_identity, type_row}; @@ -926,7 +926,7 @@ pub(crate) mod test { #[test] fn no_outer_row_variables() -> Result<(), BuildError> { let e = crate::hugr::validate::test::extension_with_eval_parallel(); - let tv = TypeRV::new_row_var_use(0, TypeBound::Copyable); + let rv = Term::new_row_var_use(0, TypeBound::Copyable); // Can *declare* a function that takes a function-value of unknown #args FunctionBuilder::new( "bad_eval", @@ -935,7 +935,7 @@ pub(crate) mod test { Signature::new( [Type::new_function(FuncValueType::new( [usize_t()], - [tv.clone()], + TypeRowRV::try_from(rv.clone()).unwrap(), ))], [], ), @@ -943,15 +943,23 @@ pub(crate) mod test { )?; // But cannot eval it... - let ev = e.instantiate_extension_op( - "eval", - [vec![usize_t().into()].into(), vec![tv.into()].into()], + let ev = e.instantiate_extension_op("eval", [vec![usize_t()].into(), rv.clone()]); + assert_eq!( + ev, + Err(SignatureError::TypeArgMismatch( + TermTypeError::InvalidValue(Box::new(rv.clone())) + )) ); + + let ev = e.instantiate_extension_op("eval", [vec![usize_t()].into(), [rv.clone()].into()]); assert_eq!( ev, - Err(SignatureError::RowVarWhereTypeExpected { - var: RowVariable(0, TypeBound::Copyable) - }) + Err(SignatureError::TypeArgMismatch( + TermTypeError::TypeMismatch { + term: Box::new(rv), + type_: Box::new(TypeBound::Linear.into()) + } + )) ); Ok(()) } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 1395c6e065..0a34fb1d79 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -2,7 +2,6 @@ use crate::Visibility; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; -use crate::types::type_param::Term; use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port, extension::{ExtensionId, OpDef, SignatureFunc}, @@ -14,10 +13,8 @@ use crate::{ arithmetic::{float_types::ConstF64, int_types::ConstInt}, collections::array::ArrayValue, }, - types::{ - CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase, - }, + types::type_param::{Term, TermVar}, + types::{CustomType, EdgeKind, FuncValueType, Signature, SumType, Type, TypeBound, TypeRow}, }; use hugr_model::v0::bumpalo; @@ -342,10 +339,12 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let symbol_name = this.export_func_name(node, &mut meta); - let symbol = this.export_poly_func_type( + let sig = func.signature(); + let symbol = this.export_symbol_params( symbol_name, Some(func.visibility().clone().into()), - func.signature(), + sig.params(), + |this| this.export_signature(sig.body()), ); regions = this.bump.alloc_slice_copy(&[this.export_dfg( node, @@ -358,11 +357,12 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let symbol_name = this.export_func_name(node, &mut meta); - - let symbol = this.export_poly_func_type( + let sig = func.signature(); + let symbol = this.export_symbol_params( symbol_name, Some(func.visibility().clone().into()), - func.signature(), + sig.params(), + |this| this.export_signature(sig.body()), ); table::Operation::DeclareFunc(symbol) }), @@ -507,7 +507,7 @@ impl<'a> Context<'a> { Some(signature) => { let num_inputs = signature.input_types().len(); let num_outputs = signature.output_types().len(); - let signature = self.export_func_type(signature); + let signature = self.export_signature(signature); (Some(signature), num_inputs, num_outputs) } None => (None, 0, 0), @@ -559,7 +559,9 @@ impl<'a> Context<'a> { let symbol = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); - this.export_poly_func_type(name, None, poly_func_type) + this.export_symbol_params(name, None, poly_func_type.params(), |this| { + this.export_func_type(poly_func_type.body()) + }) }); let meta = { @@ -816,60 +818,52 @@ impl<'a> Context<'a> { } /// Exports a polymorphic function type. - pub fn export_poly_func_type( + pub fn export_symbol_params( &mut self, name: &'a str, visibility: Option, - t: &PolyFuncTypeBase, + params: &[Term], + export_body: impl FnOnce(&mut Self) -> table::TermId, ) -> &'a table::Symbol<'a> { - let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let mut param_vec = BumpVec::with_capacity_in(params.len(), self.bump); let scope = self .local_scope .expect("exporting poly func type outside of local scope"); let visibility = self.bump.alloc(visibility); - for (i, param) in t.params().iter().enumerate() { + for (i, param) in params.iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); let r#type = self.export_term(param, Some((scope, i as _))); let param = table::Param { name, r#type }; - params.push(param); + param_vec.push(param); } let constraints = self.bump.alloc_slice_copy(&self.local_constraints); - let body = self.export_func_type(t.body()); + let body = export_body(self); self.bump.alloc(table::Symbol { visibility, name, - params: params.into_bump_slice(), + params: param_vec.into_bump_slice(), constraints, signature: body, }) } - pub fn export_type(&mut self, t: &TypeBase) -> table::TermId { - self.export_type_enum(t.as_type_enum()) - } - - pub fn export_type_enum(&mut self, t: &TypeEnum) -> table::TermId { - match t { - TypeEnum::Extension(ext) => self.export_custom_type(ext), - TypeEnum::Alias(alias) => { - let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); - self.make_term(table::Term::Apply(symbol, &[])) - } - TypeEnum::Function(func) => self.export_func_type(func), - TypeEnum::Variable(index, _) => { - let node = self.local_scope.expect("local variable out of scope"); - self.make_term(table::Term::Var(table::VarId(node, *index as _))) - } - TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), - TypeEnum::Sum(sum) => self.export_sum_type(sum), - } + pub fn export_type(&mut self, t: &Type) -> table::TermId { + self.export_term(t, None) } - pub fn export_func_type(&mut self, t: &FuncTypeBase) -> table::TermId { + pub fn export_signature(&mut self, t: &Signature) -> table::TermId { let inputs = self.export_type_row(t.input()); let outputs = self.export_type_row(t.output()); + // Ok to use CORE_FN here: the elements of the row will be exported inside a List + self.make_term_apply(model::CORE_FN, &[inputs, outputs]) + } + + pub fn export_func_type(&mut self, t: &FuncValueType) -> table::TermId { + let inputs = self.export_term(t.input(), None); + let outputs = self.export_term(t.output(), None); + // Ok to use CORE_FN here: the input/output should each be a core List or ListConcat self.make_term_apply(model::CORE_FN, &[inputs, outputs]) } @@ -888,28 +882,12 @@ impl<'a> Context<'a> { self.make_term(table::Term::Var(table::VarId(node, var.index() as _))) } - pub fn export_row_var(&mut self, t: &RowVariable) -> table::TermId { - let node = self.local_scope.expect("local variable out of scope"); - self.make_term(table::Term::Var(table::VarId(node, t.0 as _))) - } - pub fn export_sum_variants(&mut self, t: &SumType) -> table::TermId { - match t { - SumType::Unit { size } => { - let parts = self.bump.alloc_slice_fill_iter( - (0..*size) - .map(|_| table::SeqPart::Item(self.make_term(table::Term::List(&[])))), - ); - self.make_term(table::Term::List(parts)) - } - SumType::General { rows } => { - let parts = self.bump.alloc_slice_fill_iter( - rows.iter() - .map(|row| table::SeqPart::Item(self.export_type_row(row))), - ); - self.make_term(table::Term::List(parts)) - } - } + // Sadly we cannot use alloc_slice_fill_iter because SumType::variants is not an ExactSizeIterator. + let parts = self.bump.alloc_slice_fill_with(t.num_variants(), |i| { + table::SeqPart::Item(self.export_term(t.get_variant(i).unwrap(), None)) + }); + self.make_term(table::Term::List(parts)) } pub fn export_sum_type(&mut self, t: &SumType) -> table::TermId { @@ -918,27 +896,20 @@ impl<'a> Context<'a> { } #[inline] - pub fn export_type_row(&mut self, row: &TypeRowBase) -> table::TermId { + pub fn export_type_row(&mut self, row: &TypeRow) -> table::TermId { self.export_type_row_with_tail(row, None) } - pub fn export_type_row_with_tail( + pub fn export_type_row_with_tail( &mut self, - row: &TypeRowBase, + row: &TypeRow, tail: Option, ) -> table::TermId { let mut parts = BumpVec::with_capacity_in(row.len() + usize::from(tail.is_some()), self.bump); for t in row.iter() { - match t.as_type_enum() { - TypeEnum::RowVar(var) => { - parts.push(table::SeqPart::Splice(self.export_row_var(var.as_rv()))); - } - _ => { - parts.push(table::SeqPart::Item(self.export_type(t))); - } - } + parts.push(table::SeqPart::Item(self.export_type(t))); } if let Some(tail) = tail { @@ -982,7 +953,14 @@ impl<'a> Context<'a> { let item_types = self.export_term(item_types, None); self.make_term_apply(model::CORE_TUPLE_TYPE, &[item_types]) } - Term::Runtime(ty) => self.export_type(ty), + Term::RuntimeExtension(ext) => self.export_custom_type(ext), + /*TypeEnum::Alias(alias) => { + let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); + self.make_term(table::Term::Apply(symbol, &[])) + }*/ + Term::RuntimeFunction(func) => self.export_func_type(func), + Term::RuntimeSum(sum) => self.export_sum_type(sum), + Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()), Term::Float(value) => self.make_term(model::Literal::Float(*value).into()), diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 3ba9650b25..287bb80711 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -125,7 +125,6 @@ use thiserror::Error; use crate::hugr::IdentList; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; -use crate::types::RowVariable; use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; @@ -518,9 +517,10 @@ pub enum SignatureError { /// A type variable that was used has not been declared #[error("Type variable {idx} was not declared ({num_decls} in scope)")] FreeTypeVar { idx: usize, num_decls: usize }, - /// A row variable was found outside of a variable-length row - #[error("Expected a single type, but found row variable {var}")] - RowVarWhereTypeExpected { var: RowVariable }, + // ALAN this is now just another TypeArgMismatch + // A row variable was found outside of a variable-length row + //#[error("Expected a single type, but found row variable {var}")] + //RowVarWhereTypeExpected { var: RowVariable }, /// The result of the type application stored in a [Call] /// is not what we get by applying the type-args to the polymorphic function /// diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 0cd7f508fd..22c53e3828 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -16,7 +16,7 @@ use crate::envelope::serde_with::AsBinaryEnvelope; use crate::ops::{OpName, OpNameRef}; use crate::package::Package; use crate::types::type_param::{TypeArg, TypeParam, check_term_types}; -use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; +use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Substitutable}; mod serialize_signature_func; /// Trait necessary for binary computations of `OpDef` signature @@ -614,7 +614,7 @@ pub(super) mod test { use crate::package::Package; use crate::std_extensions::collections::list; use crate::types::type_param::{TermTypeError, TypeParam}; - use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; + use crate::types::{PolyFuncTypeRV, Signature, Term, Type, TypeArg, TypeBound}; use crate::{Extension, const_extension_ids}; const_extension_ids! { @@ -862,7 +862,7 @@ pub(super) mod test { def.validate_args(&args, &decls).unwrap(); assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo([tv]))); // But not with an external row variable - let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); + let arg: TypeArg = Term::new_row_var_use(0, TypeBound::Copyable); assert_eq!( def.compute_signature(std::slice::from_ref(&arg)), Err(SignatureError::TypeArgMismatch( diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index cfd09b30d1..7921e2845d 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -18,7 +18,7 @@ use crate::ops::{NamedOp, Value}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, - TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, + TypeBound, TypeName, TypeRow, TypeRowRV, }; use crate::utils::sorted_consts; use crate::{Extension, type_row}; @@ -107,23 +107,22 @@ pub static PRELUDE: LazyLock> = LazyLock::new(|| { extension_ref, ) .unwrap(); + let panic_exit_sig = PolyFuncTypeRV::new( + [ + TypeParam::new_list_type(TypeBound::Linear), + TypeParam::new_list_type(TypeBound::Linear), + ], + FuncValueType::new( + TypeRowRV::from([Type::new_extension(error_type.clone())]) + .concat(TypeRowRV::just_row_var(0, TypeBound::Linear)), + TypeRowRV::just_row_var(1, TypeBound::Linear), + ), + ); prelude .add_op( PANIC_OP_ID, "Panic with input error".to_string(), - PolyFuncTypeRV::new( - [ - TypeParam::new_list_type(TypeBound::Linear), - TypeParam::new_list_type(TypeBound::Linear), - ], - FuncValueType::new( - vec![ - TypeRV::new_extension(error_type.clone()), - TypeRV::new_row_var_use(0, TypeBound::Linear), - ], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], - ), - ), + panic_exit_sig.clone(), extension_ref, ) .unwrap(); @@ -131,19 +130,7 @@ pub static PRELUDE: LazyLock> = LazyLock::new(|| { .add_op( EXIT_OP_ID, "Exit with input error".to_string(), - PolyFuncTypeRV::new( - [ - TypeParam::new_list_type(TypeBound::Linear), - TypeParam::new_list_type(TypeBound::Linear), - ], - FuncValueType::new( - vec![ - TypeRV::new_extension(error_type), - TypeRV::new_row_var_use(0, TypeBound::Linear), - ], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], - ), - ), + panic_exit_sig, extension_ref, ) .unwrap(); @@ -334,7 +321,7 @@ pub fn either_type(ty_left: impl Into, ty_right: impl Into /// A constant optional value with a given value. /// -/// See [`option_type`]. +/// See [`SumType::new_option`]. #[must_use] pub fn const_some(value: Value) -> Value { const_some_tuple([value]) @@ -344,7 +331,7 @@ pub fn const_some(value: Value) -> Value { /// /// For single values, use [`const_some`]. /// -/// See [`option_type`]. +/// See [`SumType::new_option`]. pub fn const_some_tuple(values: impl IntoIterator) -> Value { const_right_tuple(TypeRow::new(), values) } @@ -375,11 +362,7 @@ pub fn const_left_tuple( ty_right: impl Into, ) -> Value { let values = values.into_iter().collect_vec(); - let types: TypeRowRV = values - .iter() - .map(|v| TypeRV::from(v.get_type())) - .collect_vec() - .into(); + let types: TypeRowRV = values.iter().map(|v| v.get_type()).collect_vec().into(); let typ = either_type(types, ty_right); Value::sum(0, values, typ).unwrap() } @@ -403,11 +386,7 @@ pub fn const_right_tuple( values: impl IntoIterator, ) -> Value { let values = values.into_iter().collect_vec(); - let types: TypeRowRV = values - .iter() - .map(|v| TypeRV::from(v.get_type())) - .collect_vec() - .into(); + let types: TypeRowRV = values.iter().map(|v| v.get_type()).collect_vec().into(); let typ = either_type(ty_left, types); Value::sum(1, values, typ).unwrap() } @@ -642,16 +621,16 @@ impl MakeOpDef for TupleOpDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - let rv = TypeRV::new_row_var_use(0, TypeBound::Linear); - let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); + let rv = TypeRowRV::just_row_var(0, TypeBound::Linear); + let tuple_type = Type::new_tuple(rv.clone()); let param = TypeParam::new_list_type(TypeBound::Linear); match self { TupleOpDef::MakeTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new([rv], [tuple_type])) + PolyFuncTypeRV::new([param], FuncValueType::new(rv, [tuple_type])) } TupleOpDef::UnpackTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new([tuple_type], [rv])) + PolyFuncTypeRV::new([param], FuncValueType::new([tuple_type], rv)) } } .into() @@ -711,14 +690,8 @@ impl MakeExtensionOp for MakeTuple { let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self(tys?.into())) + let tys = elems.clone().try_into().map_err(SignatureError::from)?; + Ok(Self(tys)) } fn type_args(&self) -> Vec { @@ -766,14 +739,8 @@ impl MakeExtensionOp for UnpackTuple { let [Term::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - Term::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self(tys?.into())) + let tys = elems.clone().try_into().map_err(SignatureError::from)?; + Ok(Self(tys)) } fn type_args(&self) -> Vec { @@ -881,10 +848,10 @@ impl MakeExtensionOp for Noop { Self: Sized, { let _def = NoopDef::from_def(ext_op.def())?; - let [TypeArg::Runtime(ty)] = ext_op.args() else { + let [t] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - Ok(Self(ty.clone())) + Ok(Self(t.clone().try_into().map_err(SignatureError::from)?)) } fn type_args(&self) -> Vec { @@ -929,7 +896,7 @@ impl MakeOpDef for BarrierDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( vec![TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo([TypeRV::new_row_var_use(0, TypeBound::Linear)]), + FuncValueType::new_endo(TypeRowRV::just_row_var(0, TypeBound::Linear)), ) .into() } @@ -990,16 +957,8 @@ impl MakeExtensionOp for Barrier { let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self { - type_row: tys?.into(), - }) + let type_row = elems.clone().try_into().map_err(SignatureError::from)?; + Ok(Self { type_row }) } fn type_args(&self) -> Vec { diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index f73b5ce600..9885be5e31 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -3,7 +3,10 @@ use std::iter; use crate::{ Wire, builder::{BuildError, BuildHandle, Dataflow, DataflowSubContainer, SubContainer}, - extension::prelude::{ConstError, PANIC_OP_ID}, + extension::{ + SignatureError, + prelude::{ConstError, PANIC_OP_ID}, + }, ops::handle::DataflowOpID, types::{SumType, Type, TypeArg, TypeRow}, }; @@ -74,7 +77,8 @@ pub trait UnwrapBuilder: Dataflow { let tr_rv = sum_type.get_variant(i).unwrap().to_owned(); TypeRow::try_from(tr_rv) }) - .collect::>()?; + .collect::>() + .map_err(SignatureError::from)?; // TODO don't panic if tag >= num_variants let output_row = variants.get(tag).unwrap(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 9fa5ad0772..3f852fd789 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -25,7 +25,7 @@ mod weak_registry; pub use weak_registry::WeakExtensionRegistry; pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; -pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_type_exts}; +pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_term_exts}; pub(crate) use types_mut::resolve_op_types_extensions; use types_mut::{ resolve_custom_type_exts, resolve_term_exts, resolve_type_exts, resolve_value_exts, @@ -39,11 +39,11 @@ use crate::core::HugrNode; use crate::ops::constant::ValueName; use crate::ops::custom::OpaqueOpError; use crate::ops::{NamedOp, OpName, OpType, Value}; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, TypeArg, TypeBase, TypeName}; +use crate::types::{CustomType, Signature, Type, TypeArg, TypeName}; /// Update all weak Extension pointers inside a type. -pub fn resolve_type_extensions( - typ: &mut TypeBase, +pub fn resolve_type_extensions( + typ: &mut Type, extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = WeakExtensionRegistry::default(); @@ -257,8 +257,8 @@ impl ExtensionCollectionError { } /// Create a new error when signature extensions have been dropped. - pub fn dropped_signature( - signature: &FuncTypeBase, + pub fn dropped_signature( + signature: &Signature, missing_extension: impl IntoIterator, ) -> Self { Self::DroppedSignatureExtensions { @@ -268,8 +268,8 @@ impl ExtensionCollectionError { } /// Create a new error when signature extensions have been dropped. - pub fn dropped_type( - typ: &TypeBase, + pub fn dropped_type( + typ: &Type, missing_extension: impl IntoIterator, ) -> Self { Self::DroppedTypeExtensions { diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 01fdf109ac..edae047f48 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -7,12 +7,12 @@ use std::mem; use std::sync::Arc; +use crate::extension::resolution::types::collect_func_type_exts; use crate::extension::{ Extension, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, TypeDef, }; -use super::types::collect_signature_exts; -use super::types_mut::resolve_signature_exts; +use super::types_mut::resolve_func_type_exts; use super::{ExtensionCollectionError, ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -76,7 +76,7 @@ fn collect_extension_deps( for (_, op_def) in extension.operations() { if let Some(signature) = op_def.signature_func().poly_func_type() { let mut local_missing = ExtensionSet::new(); - collect_signature_exts(signature.body(), &mut used, &mut local_missing); + collect_func_type_exts(signature.body(), &mut used, &mut local_missing); for ext in local_missing { missing.insert(ext); } @@ -207,5 +207,5 @@ pub(super) fn resolve_signature_func_exts( return Ok(()); } }; - resolve_signature_exts(None, signature_body, extensions, used_extensions) + resolve_func_type_exts(None, signature_body, extensions, used_extensions) } diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index ceb8590f6d..92431ccbbe 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -10,8 +10,7 @@ use super::{ExtensionCollectionError, WeakExtensionRegistry}; use crate::Node; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::{DataflowOpTrait, OpType, Value}; -use crate::types::type_row::TypeRowBase; -use crate::types::{FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{FuncValueType, Signature, SumType, Term, TypeRow}; /// Collects every extension used to define the types in an operation. /// @@ -59,7 +58,7 @@ pub(crate) fn collect_op_types_extensions( } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), - OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing), + OpType::LoadConstant(lc) => collect_term_exts(&lc.datatype, &mut used, &mut missing), OpType::LoadFunction(lf) => { collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); @@ -121,7 +120,7 @@ pub(crate) fn collect_op_types_extensions( } } -/// Collect the Extension pointers in the [`CustomType`]s inside a signature. +/// Collect the Extension pointers in the [`CustomType`]s inside a [Signature]. /// /// # Attributes /// @@ -129,8 +128,8 @@ pub(crate) fn collect_op_types_extensions( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(crate) fn collect_signature_exts( - signature: &FuncTypeBase, +pub(crate) fn collect_signature_exts( + signature: &Signature, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { @@ -138,6 +137,23 @@ pub(crate) fn collect_signature_exts( collect_type_row_exts(&signature.output, used_extensions, missing_extensions); } +/// Collect the Extension pointers in the [`CustomType`]s inside a [FuncValueType]. +/// +/// # Attributes +/// +/// - `func_ty`: The function type to collect the extensions from. +/// - `used_extensions`: A The registry where to store the used extensions. +/// - `missing_extensions`: A set of `ExtensionId`s of which the +/// `Weak` pointer has been invalidated. +pub(crate) fn collect_func_type_exts( + func_ty: &FuncValueType, + used_extensions: &mut WeakExtensionRegistry, + missing_extensions: &mut ExtensionSet, +) { + collect_term_exts(&func_ty.input, used_extensions, missing_extensions); + collect_term_exts(&func_ty.output, used_extensions, missing_extensions); +} + /// Collect the Extension pointers in the [`CustomType`]s inside a type row. /// /// # Attributes @@ -146,31 +162,31 @@ pub(crate) fn collect_signature_exts( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -fn collect_type_row_exts( - row: &TypeRowBase, +fn collect_type_row_exts( + row: &TypeRow, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { for ty in row.iter() { - collect_type_exts(ty, used_extensions, missing_extensions); + collect_term_exts(ty, used_extensions, missing_extensions); } } -/// Collect the Extension pointers in the [`CustomType`]s inside a type. +/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// # Attributes /// -/// - `typ`: The type to collect the extensions from. +/// - `term`: The term argument to collect the extensions from. /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(crate) fn collect_type_exts( - typ: &TypeBase, +pub(crate) fn collect_term_exts( + term: &Term, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - match typ.as_type_enum() { - TypeEnum::Extension(custom) => { + match term { + Term::RuntimeExtension(custom) => { for arg in custom.args() { collect_term_exts(arg, used_extensions, missing_extensions); } @@ -185,39 +201,16 @@ pub(crate) fn collect_type_exts( } } } - TypeEnum::Function(f) => { - collect_type_row_exts(&f.input, used_extensions, missing_extensions); - collect_type_row_exts(&f.output, used_extensions, missing_extensions); + Term::RuntimeFunction(f) => { + collect_term_exts(&f.input, used_extensions, missing_extensions); + collect_term_exts(&f.output, used_extensions, missing_extensions); } - TypeEnum::Sum(SumType::General { rows }) => { - for row in rows { - collect_type_row_exts(row, used_extensions, missing_extensions); + Term::RuntimeSum(g @ SumType::General { .. }) => { + for row in g.variants() { + collect_term_exts(row, used_extensions, missing_extensions); } } - // Other types do not store extensions. - TypeEnum::Alias(_) - | TypeEnum::RowVar(_) - | TypeEnum::Variable(_, _) - | TypeEnum::Sum(SumType::Unit { .. }) => {} - } -} - -/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. -/// -/// # Attributes -/// -/// - `term`: The term argument to collect the extensions from. -/// - `used_extensions`: A The registry where to store the used extensions. -/// - `missing_extensions`: A set of `ExtensionId`s of which the -/// `Weak` pointer has been invalidated. -pub(super) fn collect_term_exts( - term: &Term, - used_extensions: &mut WeakExtensionRegistry, - missing_extensions: &mut ExtensionSet, -) { - match term { - Term::Runtime(ty) => collect_type_exts(ty, used_extensions, missing_extensions), - Term::ConstType(ty) => collect_type_exts(ty, used_extensions, missing_extensions), + Term::ConstType(ty) => collect_term_exts(ty, used_extensions, missing_extensions), Term::List(elems) => { for elem in elems.iter() { collect_term_exts(elem, used_extensions, missing_extensions); @@ -254,7 +247,8 @@ pub(super) fn collect_term_exts( | Term::BoundedNat(_) | Term::String(_) | Term::Bytes(_) - | Term::Float(_) => {} + | Term::Float(_) + | Term::RuntimeSum(SumType::Unit { .. }) => {} } } @@ -274,12 +268,12 @@ fn collect_value_exts( match value { Value::Extension { e } => { let typ = e.get_type(); - collect_type_exts(&typ, used_extensions, missing_extensions); + collect_term_exts(&typ, used_extensions, missing_extensions); } Value::Sum(s) => { - if let SumType::General { rows } = &s.sum_type { - for row in rows { - collect_type_row_exts(row, used_extensions, missing_extensions); + if matches!(s.sum_type, SumType::General { .. }) { + for row in s.sum_type.variants() { + collect_term_exts(row, used_extensions, missing_extensions); } } s.values diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 56faeec129..607cc02de7 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -5,12 +5,11 @@ use std::sync::Weak; -use super::types::collect_type_exts; +use super::types::collect_term_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; -use crate::types::type_row::TypeRowBase; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{CustomType, FuncValueType, Signature, SumType, Term, Type, TypeRow, TypeRowRV}; use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an @@ -125,12 +124,12 @@ pub fn resolve_op_types_extensions( Ok(used.into_iter()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [Signature]. /// /// Adds the extensions used in the signature to the `used_extensions` registry. -pub(super) fn resolve_signature_exts( +pub(super) fn resolve_signature_exts( node: Option, - signature: &mut FuncTypeBase, + signature: &mut Signature, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { @@ -139,12 +138,26 @@ pub(super) fn resolve_signature_exts( Ok(()) } +/// Update all weak Extension pointers in the [`CustomType`]s inside a [FuncValueType]. +/// +/// Adds the extensions used in the signature to the `used_extensions` registry. +pub(super) fn resolve_func_type_exts( + node: Option, + signature: &mut FuncValueType, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + resolve_typerow_rv_exts(node, &mut signature.input, extensions, used_extensions)?; + resolve_typerow_rv_exts(node, &mut signature.output, extensions, used_extensions)?; + Ok(()) +} + /// Update all weak Extension pointers in the [`CustomType`]s inside a type row. /// /// Adds the extensions used in the row to the `used_extensions` registry. -pub(super) fn resolve_type_row_exts( +pub(super) fn resolve_type_row_exts( node: Option, - row: &mut TypeRowBase, + row: &mut TypeRow, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { @@ -154,34 +167,18 @@ pub(super) fn resolve_type_row_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. /// -/// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_type_exts( +/// Adds the extensions used in the row to the `used_extensions` registry. +fn resolve_typerow_rv_exts( node: Option, - typ: &mut TypeBase, + row: &mut TypeRowRV, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - match typ.as_type_enum_mut() { - TypeEnum::Extension(custom) => { - resolve_custom_type_exts(node, custom, extensions, used_extensions)?; - } - TypeEnum::Function(f) => { - resolve_type_row_exts(node, &mut f.input, extensions, used_extensions)?; - resolve_type_row_exts(node, &mut f.output, extensions, used_extensions)?; - } - TypeEnum::Sum(SumType::General { rows }) => { - for row in rows.iter_mut() { - resolve_type_row_exts(node, row, extensions, used_extensions)?; - } - } - // Other types do not store extensions. - TypeEnum::Alias(_) - | TypeEnum::RowVar(_) - | TypeEnum::Variable(_, _) - | TypeEnum::Sum(SumType::Unit { .. }) => {} - } + let mut t = Term::from(std::mem::take(row)); + resolve_term_exts(node, &mut t, extensions, used_extensions)?; + *row = TypeRowRV::new_unchecked(t); Ok(()) } @@ -211,17 +208,43 @@ pub(super) fn resolve_custom_type_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Type`]. /// /// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_term_exts( +pub(crate) fn resolve_type_exts( + node: Option, + typ: &mut Type, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + const EMPTY: Type = Type::new_unit_sum(0); // as no Type::default() + let mut tm = std::mem::replace(typ, EMPTY).into(); + let r = resolve_term_exts(node, &mut tm, extensions, used_extensions); + *typ = tm.try_into().unwrap(); + r +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. +/// +/// Adds the extensions used in the term to the `used_extensions` registry. +pub(crate) fn resolve_term_exts( node: Option, term: &mut Term, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { match term { - Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, + Term::RuntimeExtension(custom) => { + resolve_custom_type_exts(node, custom, extensions, used_extensions)?; + } + Term::RuntimeFunction(f) => { + resolve_func_type_exts(node, &mut *f, extensions, used_extensions)?; + } + Term::RuntimeSum(SumType::General { rows }) => { + for row in rows.iter_mut() { + resolve_typerow_rv_exts(node, row, extensions, used_extensions)?; + } + } Term::ConstType(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, Term::List(children) | Term::ListConcat(children) @@ -247,7 +270,8 @@ pub(super) fn resolve_term_exts( | Term::BoundedNat(_) | Term::String(_) | Term::Bytes(_) - | Term::Float(_) => {} + | Term::Float(_) + | Term::RuntimeSum(SumType::Unit { .. }) => {} } Ok(()) } @@ -269,7 +293,7 @@ pub(super) fn resolve_value_exts( // return types with valid extensions after we call `update_extensions`. let typ = e.get_type(); let mut missing = ExtensionSet::new(); - collect_type_exts(&typ, used_extensions, &mut missing); + collect_term_exts(&typ, used_extensions, &mut missing); if !missing.is_empty() { return Err(ExtensionResolutionError::InvalidConstTypes { value: e.name(), @@ -280,7 +304,7 @@ pub(super) fn resolve_value_exts( Value::Sum(s) => { if let SumType::General { rows } = &mut s.sum_type { for row in rows.iter_mut() { - resolve_type_row_exts(node, row, extensions, used_extensions)?; + resolve_typerow_rv_exts(node, row, extensions, used_extensions)?; } } s.values diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index b848c7528f..bbb8001aa8 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -146,10 +146,8 @@ impl TypeDef { } least_upper_bound(indices.iter().map(|i| { let ta = args.get(*i); - match ta { - Some(TypeArg::Runtime(s)) => s.least_upper_bound(), - _ => panic!("TypeArg index does not refer to a type."), - } + ta.and_then(|t| t.least_upper_bound()) + .expect("TypeArg index does not refer to a type.") })) } } diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 1aa4384742..f47a717006 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -60,7 +60,7 @@ impl SimpleReplacement { node: replacement.entrypoint(), op: Box::new(replacement.get_optype(replacement.entrypoint()).to_owned()), })?; - if subgraph_sig != repl_sig { + if &subgraph_sig != repl_sig.as_ref() { return Err(InvalidReplacement::InvalidSignature { expected: Box::new(subgraph_sig), actual: Some(Box::new(repl_sig.into_owned())), diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 227fe035b6..7153c485b7 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -27,7 +27,7 @@ use crate::test_file; use crate::types::type_param::TypeParam; use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, - TypeRV, + TypeRowRV, }; use crate::{OutgoingPort, Visibility, type_row}; use std::fs::File; @@ -52,7 +52,7 @@ pub(super) struct HugrDeser(#[serde(deserialize_with = "Hugr::serde_deserialize" /// Version 1 of the Testing HUGR serialization format, see `testing_hugr.py`. #[derive(Serialize, Deserialize, PartialEq, Debug, Default)] struct SerTestingLatest { - typ: Option, + typ: Option, sum_type: Option, poly_func_type: Option, value: Option, @@ -142,7 +142,7 @@ macro_rules! impl_sertesting_from { }; } -impl_sertesting_from!(crate::types::TypeRV, typ); +impl_sertesting_from!(crate::types::Type, typ); impl_sertesting_from!(crate::types::SumType, sum_type); impl_sertesting_from!(crate::types::PolyFuncTypeRV, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); @@ -156,13 +156,6 @@ impl From for SerTestingLatest { } } -impl From for SerTestingLatest { - fn from(v: Type) -> Self { - let t: TypeRV = v.into(); - t.into() - } -} - #[test] fn empty_hugr_serialize() { check_hugr_json_roundtrip(&Hugr::default(), true); @@ -539,7 +532,7 @@ fn serialize_types_roundtrip() { check_testing_roundtrip(t); // A Classic sum - let t = TypeRV::new_sum([vec![usize_t()], vec![float64_type()]]); + let t = Type::new_sum([vec![usize_t()], vec![float64_type()]]); check_testing_roundtrip(t); let t = Type::new_unit_sum(4); @@ -550,7 +543,7 @@ fn serialize_types_roundtrip() { #[case(bool_t())] #[case(usize_t())] #[case(INT_TYPES[2].clone())] -#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Linear)))] +//#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Linear)))] #[case(Type::new_var_use(2, TypeBound::Copyable))] #[case(Type::new_tuple(vec![bool_t(),qb_t()]))] #[case(Type::new_sum([vec![bool_t(),qb_t()], vec![Type::new_unit_sum(4)]]))] @@ -583,14 +576,15 @@ fn polyfunctype1() -> PolyFuncType { } fn polyfunctype2() -> PolyFuncTypeRV { - let tv0 = TypeRV::new_row_var_use(0, TypeBound::Linear); - let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); + let tv0 = TypeRowRV::just_row_var(0, TypeBound::Linear); + let tv1 = TypeRowRV::just_row_var(1, TypeBound::Copyable); let params = [TypeBound::Linear, TypeBound::Copyable].map(TypeParam::new_list_type); - let inputs = vec![ - TypeRV::new_function(FuncValueType::new([tv0.clone()], [tv1.clone()])), - tv0, - ]; - let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, [tv1])); + let inputs = TypeRowRV::from([Type::new_function(FuncValueType::new( + tv0.clone(), + tv1.clone(), + ))]) + .concat(tv0); + let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, tv1)); // Just check we've got the arguments the right way round // (not that it really matters for the serialization schema we have) res.validate().unwrap(); @@ -606,7 +600,7 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[case(PolyFuncType::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( [TypeParam::new_list_type(TypeBound::Linear)], - Signature::new_endo([Type::new_tuple([TypeRV::new_row_var_use(0, TypeBound::Linear)])])))] + Signature::new_endo([Type::new_tuple(TypeRowRV::just_row_var(0, TypeBound::Linear))])))] fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type); } @@ -619,7 +613,7 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[case(PolyFuncTypeRV::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( [TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo([TypeRV::new_row_var_use(0, TypeBound::Linear)])))] + FuncValueType::new_endo(TypeRowRV::just_row_var(0, TypeBound::Linear))))] #[case(polyfunctype2())] fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { check_testing_roundtrip(poly_func_type); diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 107e3eb5c5..9915e7c86f 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -19,6 +19,7 @@ use crate::ops::validate::{ }; use crate::ops::{NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::EdgeKind; +use crate::types::Substitutable; use crate::types::type_param::TypeParam; use crate::{Direction, Port, Visibility}; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 7cec7d6e6d..2c2b48f948 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -25,7 +25,7 @@ use crate::std_extensions::logic::test::{and_op, or_op}; use crate::types::type_param::{TermTypeError, TypeArg}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Term, Type, TypeBound, - TypeRV, TypeRow, + TypeRow, TypeRowRV, }; use crate::{Direction, Hugr, IncomingPort, Node, const_extension_ids, test_file, type_row}; @@ -495,28 +495,27 @@ fn no_polymorphic_consts() -> Result<(), Box> { pub(crate) fn extension_with_eval_parallel() -> Arc { let rowp = TypeParam::new_list_type(TypeBound::Linear); Extension::new_test_arc(EXT_ID, |ext, extension_ref| { - let inputs = TypeRV::new_row_var_use(0, TypeBound::Linear); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Linear); - let evaled_fn = - TypeRV::new_function(FuncValueType::new([inputs.clone()], [outputs.clone()])); + let inputs = TypeRowRV::just_row_var(0, TypeBound::Linear); + let outputs = TypeRowRV::just_row_var(1, TypeBound::Linear); + let evaled_fn = Type::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone()], - FuncValueType::new([evaled_fn, inputs], [outputs]), + FuncValueType::new(TypeRowRV::from([evaled_fn]).concat(inputs), outputs), ); ext.add_op("eval".into(), String::new(), pf, extension_ref) .unwrap(); - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Linear); + let rv = |idx| TypeRowRV::just_row_var(idx, TypeBound::Linear); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], Signature::new( - vec![ - Type::new_function(FuncValueType::new([rv(0)], [rv(2)])), - Type::new_function(FuncValueType::new([rv(1)], [rv(3)])), + [ + Type::new_function(FuncValueType::new(rv(0), rv(2))), + Type::new_function(FuncValueType::new(rv(1), rv(3))), ], [Type::new_function(FuncValueType::new( - [rv(0), rv(1)], - [rv(2), rv(3)], + rv(0).concat(rv(1)), + rv(2).concat(rv(3)), ))], ), ); @@ -528,7 +527,7 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { #[test] fn instantiate_row_variables() -> Result<(), Box> { fn uint_seq(i: usize) -> Term { - vec![usize_t().into(); i].into() + vec![usize_t(); i].into() } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( @@ -552,16 +551,13 @@ fn instantiate_row_variables() -> Result<(), Box> { Ok(()) } -fn list1ty(t: TypeRV) -> Term { - Term::new_list([t.into()]) -} - #[test] fn row_variables() -> Result<(), Box> { let e = extension_with_eval_parallel(); - let tv = TypeRV::new_row_var_use(0, TypeBound::Linear); - let inner_ft = Type::new_function(FuncValueType::new_endo([tv.clone()])); - let ft_usz = Type::new_function(FuncValueType::new_endo([tv.clone(), usize_t().into()])); + let tv = Term::new_row_var_use(0, TypeBound::Linear); + let tv_row = TypeRowRV::try_from(tv.clone()).unwrap(); + let inner_ft = Type::new_function(FuncValueType::new_endo(tv_row.clone())); + let ft_usz = Type::new_function(FuncValueType::new_endo(tv_row.concat([usize_t()]))); let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( @@ -580,7 +576,12 @@ fn row_variables() -> Result<(), Box> { }; let par = e.instantiate_extension_op( "parallel", - [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(list1ty), + [ + tv.clone(), + [usize_t()].into(), + tv.clone(), + [usize_t()].into(), + ], )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs(par_func.outputs())?; diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index abdbacf749..eb68a8ac9c 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -12,7 +12,7 @@ use crate::{ OpParent, OpTrait, OpType, handle::{DataflowParentID, DfgID}, }, - types::{NoRV, Signature, Type, TypeBase}, + types::{Signature, Type}, }; use super::RootChecked; @@ -262,7 +262,7 @@ fn update_signature(hugr: &mut H, node: H::Node, new_sig: &Signature fn check_valid_inputs( old_ports: &[Vec], - old_sig: &[TypeBase], + old_sig: &[Type], map_sig: &[usize], ) -> Result<(), InvalidSignature> { if let Some(old_pos) = map_sig @@ -291,10 +291,7 @@ fn check_valid_inputs( Ok(()) } -fn check_valid_outputs( - old_sig: &[TypeBase], - map_sig: &[usize], -) -> Result<(), InvalidSignature> { +fn check_valid_outputs(old_sig: &[Type], map_sig: &[usize]) -> Result<(), InvalidSignature> { if let Some(old_pos) = map_sig .iter() .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos)) @@ -684,8 +681,8 @@ mod test { let new_inputs = vec![bool_t(), float64_type()]; dfg_view.extend_inputs(&new_inputs).unwrap(); assert_eq!( - dfg_view.hugr().inner_function_type().unwrap(), - Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()]) + dfg_view.hugr().inner_function_type().unwrap().as_ref(), + &Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()]) ); let new_inputs_fail = vec![qb_t()]; diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 1e62bc699e..95a57a640e 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use crate::envelope::description::GeneratorDesc; use crate::metadata::{self, Metadata}; +use crate::types::{FuncValueType, TypeRowRV}; use crate::{ Direction, Hugr, HugrView, Node, Port, envelope::description::{ExtensionDesc, ModuleDesc}, @@ -27,10 +28,8 @@ use crate::{ collections::array::ArrayValue, }, types::{ - CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, - Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, + CustomType, PolyFuncType, Signature, Term, Type, TypeArg, TypeBound, TypeName, TypeRow, type_param::{SeqPart, TypeParam}, - type_row::TypeRowBase, }, }; use hugr_model::v0::table; @@ -314,7 +313,7 @@ impl<'a> Context<'a> { let signature = node_data .signature .ok_or_else(|| error_uninferred!("node signature"))?; - self.import_func_type(signature) + self.import_signature(signature) } /// Get the node with the given `NodeId`, or return an error if it does not exist. @@ -687,7 +686,7 @@ impl<'a> Context<'a> { } let signature = self - .import_func_type( + .import_signature( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -839,7 +838,10 @@ impl<'a> Context<'a> { let sum_rows: Vec<_> = { let [variants] = self.expect_symbol(*first, model::CORE_ADT)?; - self.import_type_rows(variants)? + self.import_closed_list(variants)? + .into_iter() + .map(|term_id| self.import_type_row(term_id)) + .collect::>()? }; let rest = rest @@ -933,7 +935,7 @@ impl<'a> Context<'a> { for region in node_data.regions { let region_data = self.get_region(*region)?; - let signature = self.import_func_type( + let signature = self.import_signature( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -1378,11 +1380,11 @@ impl<'a> Context<'a> { Ok(node) } - fn import_poly_func_type( + fn import_poly_func_type( &mut self, node: table::NodeId, symbol: table::Symbol<'a>, - in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, + in_scope: impl FnOnce(&mut Self, PolyFuncType) -> Result, ) -> Result { (|| { let mut imported_params = Vec::with_capacity(symbol.params.len()); @@ -1425,8 +1427,8 @@ impl<'a> Context<'a> { ); } - let body = self.import_func_type::(symbol.signature)?; - in_scope(self, PolyFuncTypeBase::new(imported_params, body)) + let body = self.import_signature(symbol.signature)?; + in_scope(self, PolyFuncType::new(imported_params, body)) })() .map_err(|err| error_context!(err, "symbol `{}` defined by node {}", symbol.name, node)) } @@ -1436,6 +1438,10 @@ impl<'a> Context<'a> { self.import_term_with_bound(term_id, TypeBound::Linear) } + fn import_type(&mut self, term_id: table::TermId) -> Result { + Ok(Type::try_from(self.import_term(term_id)?).map_err(SignatureError::from)?) + } + fn import_term_with_bound( &mut self, term_id: table::TermId, @@ -1495,6 +1501,28 @@ impl<'a> Context<'a> { return Ok(TypeParam::new_tuple_type(item_types)); } + if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { + let func_type = self.import_func_type(term_id)?; + return Ok(Type::new_function(func_type).into()); + } + + if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { + let variants = (|| { + self.import_closed_list(variants)? + .iter() + .map(|variant| { + self.import_term(*variant).and_then(|tm| { + TypeRowRV::try_from(tm) + .map_err(|e| ImportErrorInner::Signature(e.into())) + }) + }) + .collect::, _>>() + })() + .map_err(|err| error_context!(err, "adt variants"))?; + + return Ok(Type::new_sum(variants).into()); + } + match self.get_term(term_id)? { table::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -1539,51 +1567,6 @@ impl<'a> Context<'a> { table::Term::Literal(model::Literal::Float(value)) => Ok(Term::Float(*value)), table::Term::Func { .. } => Err(error_unsupported!("function constant")), - table::Term::Apply { .. } => { - let ty: Type = self.import_type(term_id)?; - Ok(ty.into()) - } - } - })() - .map_err(|err| error_context!(err, "term {}", term_id)) - } - - fn import_seq_part( - &mut self, - seq_part: &'a table::SeqPart, - ) -> Result, ImportErrorInner> { - Ok(match seq_part { - table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), - table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), - }) - } - - /// Import a `Type` from a term that represents a runtime type. - fn import_type( - &mut self, - term_id: table::TermId, - ) -> Result, ImportErrorInner> { - (|| { - if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { - let func_type = self.import_func_type::(term_id)?; - return Ok(TypeBase::new_function(func_type)); - } - - if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { - let variants = (|| { - self.import_closed_list(variants)? - .iter() - .map(|variant| self.import_type_row::(*variant)) - .collect::, _>>() - })() - .map_err(|err| error_context!(err, "adt variants"))?; - - return Ok(TypeBase::new_sum(variants)); - } - - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), - table::Term::Apply(symbol, args) => { let name = self.get_symbol_name(*symbol)?; @@ -1615,32 +1598,28 @@ impl<'a> Context<'a> { let bound = ext_type.bound(&args); - Ok(TypeBase::new_extension(CustomType::new( + Ok(Type::new_extension(CustomType::new( id, args, extension, bound, &Arc::downgrade(extension_ref), - ))) - } - - table::Term::Var(var @ table::VarId(_, index)) => { - let local_var = self - .local_vars - .get(var) - .ok_or(error_invalid!("unknown var {}", var))?; - Ok(TypeBase::new_var_use(*index as _, local_var.bound)) + )) + .into()) } - - // The following terms are not runtime types, but the core `Type` only contains runtime types. - // We therefore report a type error here. - table::Term::Literal(_) - | table::Term::List { .. } - | table::Term::Tuple { .. } - | table::Term::Func { .. } => Err(error_invalid!("expected a runtime type")), } })() - .map_err(|err| error_context!(err, "term {} as `Type`", term_id)) + .map_err(|err| error_context!(err, "term {}", term_id)) + } + + fn import_seq_part( + &mut self, + seq_part: &'a table::SeqPart, + ) -> Result, ImportErrorInner> { + Ok(match seq_part { + table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), + table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), + }) } fn get_func_type( @@ -1667,23 +1646,31 @@ impl<'a> Context<'a> { /// /// Function types are not special-cased in `hugr-model` but are represented /// via the `core.fn` term constructor. - fn import_func_type( + fn import_func_type( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { + ) -> Result { (|| { let [inputs, outputs] = self.get_func_type(term_id)?; let inputs = self - .import_type_row(inputs) + .import_term(inputs) .map_err(|err| error_context!(err, "function inputs"))?; + let inputs = TypeRowRV::try_from(inputs).map_err(SignatureError::from)?; let outputs = self - .import_type_row(outputs) + .import_term(outputs) .map_err(|err| error_context!(err, "function outputs"))?; - Ok(FuncTypeBase::new(inputs, outputs)) + let outputs = TypeRowRV::try_from(outputs).map_err(SignatureError::from)?; + + Ok(FuncValueType::new(inputs, outputs)) })() .map_err(|err| error_context!(err, "function type")) } + fn import_signature(&mut self, term_id: table::TermId) -> Result { + let fvt = self.import_func_type(term_id)?; + Ok(fvt.try_into()?) + } + /// Import a closed list as a vector of term ids. /// /// This method supports list terms that contain spliced sublists as long as @@ -1778,64 +1765,19 @@ impl<'a> Context<'a> { Ok(types) } - /// Imports a list of lists as a vector of type rows. - /// - /// See [`Self::import_type_row`]. - fn import_type_rows( - &mut self, - term_id: table::TermId, - ) -> Result>, ImportErrorInner> { - self.import_closed_list(term_id)? - .into_iter() - .map(|term_id| self.import_type_row::(term_id)) - .collect() - } - - /// Imports a list as a type row. + /// Imports a closed list as a type row. /// /// This method works to produce a [`TypeRow`] or a [`TypeRowRV`], depending /// on the `RV` type argument. For [`TypeRow`] a closed list is expected. /// For [`TypeRowRV`] we import spliced variables as row variables. - fn import_type_row( - &mut self, - term_id: table::TermId, - ) -> Result, ImportErrorInner> { - fn import_into( - ctx: &mut Context, - term_id: table::TermId, - types: &mut Vec>, - ) -> Result<(), ImportErrorInner> { - match ctx.get_term(term_id)? { - table::Term::List(parts) => { - types.reserve(parts.len()); - - for item in *parts { - match item { - table::SeqPart::Item(term_id) => { - types.push(ctx.import_type::(*term_id)?); - } - table::SeqPart::Splice(term_id) => { - import_into(ctx, *term_id, types)?; - } - } - } - } - table::Term::Var(table::VarId(_, index)) => { - let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Linear)) - .map_err(|_| { - error_invalid!("Expected a closed list.\n{}", CLOSED_LIST_HINT) - })?; - types.push(TypeBase::new(TypeEnum::RowVar(var))); - } - _ => return Err(error_invalid!("expected a list")), - } - - Ok(()) - } - - let mut types = Vec::new(); - import_into(self, term_id, &mut types)?; - Ok(types.into()) + fn import_type_row(&mut self, term_id: table::TermId) -> Result { + let elems = self.import_closed_list(term_id)?; + Ok(elems + .into_iter() + .map(|id| self.import_term(id)) + .collect::, _>>()? + .try_into() + .map_err(SignatureError::from)?) } fn import_custom_name( @@ -1987,13 +1929,8 @@ impl<'a> Context<'a> { .map(|(value, ty)| self.import_value(*value, *ty)) .collect::, _>>()?; - let ty = { - // TODO: Import as a `SumType` directly and avoid the copy. - let ty: Type = self.import_type(type_id)?; - match ty.as_type_enum() { - TypeEnum::Sum(sum) => sum.clone(), - _ => unreachable!(), - } + let Term::RuntimeSum(ty) = self.import_term(type_id)? else { + unreachable!() }; return Ok(Value::sum(*tag as _, items, ty).unwrap()); @@ -2138,7 +2075,7 @@ impl<'a> Context<'a> { struct LocalVar { /// The type of the variable. r#type: table::TermId, - /// The type bound of the variable. + /// The type bound of the variable. Overwritten if a constraint is seen. bound: TypeBound, } diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 45a06b16f5..a06d7844d0 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use crate::Direction; -use crate::types::{EdgeKind, Signature, Type, TypeRow}; +use crate::types::{EdgeKind, Signature, Substitutable, Type, TypeRow}; use super::OpTag; use super::dataflow::{DataflowOpTrait, DataflowParent}; @@ -351,7 +351,7 @@ mod test { use crate::{ extension::prelude::{qb_t, usize_t}, ops::{Conditional, DataflowOpTrait, DataflowParent}, - types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, + types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRowRV}, }; use super::{DataflowBlock, TailLoop}; @@ -368,8 +368,8 @@ mod test { let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( - dfb2.inner_signature(), - Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) + dfb2.inner_signature().as_ref(), + &Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) ); } @@ -378,10 +378,10 @@ mod test { let tv1 = Type::new_var_use(1, TypeBound::Linear); let cond = Conditional { sum_rows: vec![[usize_t()].into(), [tv1.clone()].into()], - other_inputs: vec![Type::new_tuple([TypeRV::new_row_var_use( + other_inputs: vec![Type::new_tuple(TypeRowRV::just_row_var( 0, TypeBound::Linear, - )])] + ))] .into(), outputs: vec![usize_t(), tv1].into(), }; @@ -391,8 +391,8 @@ mod test { ])); let st = Type::new_sum([[usize_t()], [qb_t()]]); assert_eq!( - cond2.signature(), - Signature::new( + cond2.signature().as_ref(), + &Signature::new( [st, Type::new_tuple(vec![usize_t(); 3])], [usize_t(), qb_t()] ) @@ -409,8 +409,8 @@ mod test { }; let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()])); assert_eq!( - tail2.signature(), - Signature::new( + tail2.signature().as_ref(), + &Signature::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 878fbe04ca..16dc754444 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -14,7 +14,7 @@ use { use crate::core::HugrNode; use crate::extension::simple_op::MakeExtensionOp; use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError}; -use crate::types::{Signature, type_param::TypeArg}; +use crate::types::{Signature, Substitutable, type_param::TypeArg}; use crate::{IncomingPort, ops}; use super::dataflow::DataflowOpTrait; diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 9e46764728..a08e841b66 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -6,7 +6,9 @@ use super::{OpTag, OpTrait, impl_op_name}; use crate::extension::SignatureError; use crate::ops::StaticTag; -use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; +use crate::types::{ + EdgeKind, PolyFuncType, Signature, Substitutable, Substitution, Type, TypeArg, TypeRow, +}; use crate::{IncomingPort, type_row}; #[cfg(test)] diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index 71955bdc1b..eb90b9f777 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -6,7 +6,7 @@ use crate::types::{Type, TypeBound}; use derive_more::From as DerFrom; use smol_str::SmolStr; -use super::{AliasDecl, OpTag}; +use super::OpTag; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. @@ -91,7 +91,8 @@ impl AliasID { /// Construct new `AliasID` pub fn get_alias_type(&self) -> Type { - Type::new_alias(AliasDecl::new(self.name.clone(), self.bound)) + unimplemented!("Type aliases") + //Type::new_alias(AliasDecl::new(self.name.clone(), self.bound)) } /// Retrieve the underlying core type pub fn get_name(&self) -> &SmolStr { diff --git a/hugr-core/src/ops/sum.rs b/hugr-core/src/ops/sum.rs index 1c535683fc..34f1a6db0d 100644 --- a/hugr-core/src/ops/sum.rs +++ b/hugr-core/src/ops/sum.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use super::dataflow::DataflowOpTrait; use super::{OpTag, impl_op_name}; -use crate::types::{EdgeKind, Signature, Type, TypeRow}; +use crate::types::{EdgeKind, Signature, Substitutable, Type, TypeRow}; /// An operation that creates a tagged sum value from one of its variants. #[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] diff --git a/hugr-core/src/proptest.rs b/hugr-core/src/proptest.rs index 051e300bf8..c569f923de 100644 --- a/hugr-core/src/proptest.rs +++ b/hugr-core/src/proptest.rs @@ -8,8 +8,8 @@ use std::sync::LazyLock; use crate::Hugr; #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -/// The types [Type], [`TypeEnum`], [`SumType`], [`FunctionType`], [`TypeArg`], -/// [`TypeParam`], as well as several others, form a mutually recursive hierarchy. +/// The types [Type], [`Term`], [`SumType`], [`FunctionType`], [`CustomType`], +/// as well as several others, form a mutually recursive hierarchy. /// /// The proptest [`proptest::strategy::Strategy::prop_recursive`] is inadequate to /// generate values for these types. Instead, the Arbitrary instances take a diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index 586b1f65ca..264afa1b1e 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -13,7 +13,7 @@ use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc}; use crate::ops::{ExtensionOp, OpName}; use crate::std_extensions::arithmetic::int_ops::int_polytype; use crate::std_extensions::arithmetic::int_types::int_type; -use crate::types::{TypeArg, TypeRV}; +use crate::types::TypeArg; use super::float_types::float64_type; use super::int_types::{get_log_width, int_tv}; @@ -62,11 +62,9 @@ impl MakeOpDef for ConvertOpDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { use ConvertOpDef::*; match self { - trunc_s | trunc_u => int_polytype( - 1, - [float64_type()], - [TypeRV::from(sum_with_error([int_tv(0)]))], - ), + trunc_s | trunc_u => { + int_polytype(1, [float64_type()], [sum_with_error([int_tv(0)]).into()]) + } convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]), itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]), ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]), diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 11c16b14a8..05512383a9 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -10,7 +10,7 @@ use crate::extension::simple_op::{ use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs}; use crate::ops::OpName; use crate::ops::custom::ExtensionOp; -use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; +use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRow, TypeRowRV}; use crate::utils::collect_array; use crate::{ @@ -136,7 +136,7 @@ impl MakeOpDef for IntOpDef { } ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(), idivmod_checked_u | idivmod_checked_s => { - let intpair: TypeRowRV = vec![tv0; 2].into(); + let intpair: TypeRow = vec![tv0; 2].into(); int_polytype( 1, intpair.clone(), diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs index 566ee12c70..742c40f549 100644 --- a/hugr-core/src/std_extensions/collections/array/array_clone.rs +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -180,8 +180,9 @@ impl HasConcrete for GenericArrayCloneDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { - Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) + [TypeArg::BoundedNat(n), ty] if ty.copyable() => { + let ty = Type::try_from(ty.clone()).unwrap(); // succeeds as copyable + Ok(GenericArrayClone::new(ty, *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs index 61b013a062..8415264801 100644 --- a/hugr-core/src/std_extensions/collections/array/array_conversion.rs +++ b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -231,9 +231,10 @@ impl HasConcrete fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { - Ok(GenericArrayConvert::new(ty.clone(), *n)) - } + [TypeArg::BoundedNat(n), ty] => Ok(GenericArrayConvert::new( + ty.clone().try_into().map_err(SignatureError::from)?, + *n, + )), _ => Err(SignatureError::InvalidTypeArgs.into()), } } diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs index 17e2be1577..be3021bafd 100644 --- a/hugr-core/src/std_extensions/collections/array/array_discard.rs +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -164,8 +164,9 @@ impl HasConcrete for GenericArrayDiscardDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { - Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) + [TypeArg::BoundedNat(n), ty] if ty.copyable() => { + let ty = ty.clone().try_into().unwrap(); // succeeds as copyable + Ok(GenericArrayDiscard::new(ty, *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 26ebb5b5f4..e937b5c895 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -326,12 +326,12 @@ impl HasConcrete for GenericArrayOpDef { fn instantiate(&self, type_args: &[Term]) -> Result { let (ty, size) = match (self, type_args) { - (GenericArrayOpDef::discard_empty, [Term::Runtime(ty)]) => (ty.clone(), 0), - (_, [Term::BoundedNat(n), Term::Runtime(ty)]) => (ty.clone(), *n), + (GenericArrayOpDef::discard_empty, [ty]) => (ty.clone(), 0), + (_, [Term::BoundedNat(n), ty]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(self.to_concrete(ty.clone(), size)) + Ok(self.to_concrete(ty.try_into().map_err(SignatureError::from)?, size)) } } diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 3fb121980f..ee3d2ec9ca 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -170,8 +170,9 @@ impl HasConcrete for GenericArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { - Ok(GenericArrayRepeat::new(ty.clone(), *n)) + [TypeArg::BoundedNat(n), ty] => { + let ty = Type::try_from(ty.clone()).map_err(SignatureError::from)?; + Ok(GenericArrayRepeat::new(ty, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 5bd62466c2..f8eaca5033 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -13,7 +13,7 @@ use crate::extension::simple_op::{ use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound, TypeRowRV}; use super::array_kind::ArrayKind; @@ -62,29 +62,26 @@ impl GenericArrayScanDef { TypeParam::new_list_type(TypeBound::Linear), ]; let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let t1 = Type::new_var_use(1, TypeBound::Linear); - let t2 = Type::new_var_use(2, TypeBound::Linear); - let s = TypeRV::new_row_var_use(3, TypeBound::Linear); + let src_elem = Type::new_var_use(1, TypeBound::Linear); + let tgt_elem = Type::new_var_use(2, TypeBound::Linear); + let with_rest = |tys: Vec| { + TypeRowRV::from(tys).concat(TypeRowRV::just_row_var(3, TypeBound::Linear)) + }; PolyFuncTypeRV::new( params, - FuncTypeBase::::new( - vec![ - AK::instantiate_ty(array_def, n.clone(), t1.clone()) - .expect("Array type instantiation failed") - .into(), - Type::new_function(FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - )) - .into(), - s.clone(), - ], - vec![ - AK::instantiate_ty(array_def, n, t2) - .expect("Array type instantiation failed") - .into(), - s, - ], + FuncValueType::new( + with_rest(vec![ + AK::instantiate_ty(array_def, n.clone(), src_elem.clone()) + .expect("Array type instantiation failed"), + Type::new_function(FuncValueType::new( + with_rest(vec![src_elem]), + with_rest(vec![tgt_elem.clone()]), + )), + ]), + with_rest(vec![ + AK::instantiate_ty(array_def, n, tgt_elem) + .expect("Array type instantiation failed"), + ]), ), ) .into() @@ -214,23 +211,20 @@ impl HasConcrete for GenericArrayScanDef { match type_args { [ TypeArg::BoundedNat(n), - TypeArg::Runtime(src_ty), - TypeArg::Runtime(tgt_ty), + src_elem_ty, + tgt_elem_ty, TypeArg::List(acc_tys), ] => { - let acc_tys: Result<_, OpLoadError> = acc_tys + let src_elem_ty = + Type::try_from(src_elem_ty.clone()).map_err(SignatureError::from)?; + let tgt_elem_ty = + Type::try_from(tgt_elem_ty.clone()).map_err(SignatureError::from)?; + let acc_tys = acc_tys .iter() - .map(|acc_ty| match acc_ty { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs.into()), - }) - .collect(); - Ok(GenericArrayScan::new( - src_ty.clone(), - tgt_ty.clone(), - acc_tys?, - *n, - )) + .map(|tm| Type::try_from(tm.clone())) + .collect::, _>>() + .map_err(SignatureError::from)?; + Ok(GenericArrayScan::new(src_elem_ty, tgt_elem_ty, acc_tys, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs index 33828d9e0d..0227a68aa5 100644 --- a/hugr-core/src/std_extensions/collections/array/array_value.rs +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -94,8 +94,9 @@ impl GenericArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if *n as usize == self.values.len() => { - ty + // ALAN checking copyable here might be a bugfix but sounds like we should? + [TypeArg::BoundedNat(n), ty] if *n as usize == self.values.len() && ty.copyable() => { + Type::try_from(ty.clone()).unwrap() // succeeds as copyable } _ => { return Err(CustomCheckFailure::Message(format!( @@ -107,7 +108,7 @@ impl GenericArrayValue { // check all values are instances of the element type for v in &self.values { - if v.get_type() != *ty { + if v.get_type() != ty { return Err(CustomCheckFailure::Message(format!( "Array element {v:?} is not of expected type {ty}" ))); diff --git a/hugr-core/src/std_extensions/collections/borrow_array.rs b/hugr-core/src/std_extensions/collections/borrow_array.rs index 534a28655c..54d2877142 100644 --- a/hugr-core/src/std_extensions/collections/borrow_array.rs +++ b/hugr-core/src/std_extensions/collections/borrow_array.rs @@ -280,7 +280,10 @@ impl HasConcrete for BArrayUnsafeOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [Term::BoundedNat(n), Term::Runtime(ty)] => Ok(self.to_concrete(ty.clone(), *n)), + [Term::BoundedNat(n), ty] => { + let ty = Type::try_from(ty.clone()).map_err(SignatureError::from)?; + Ok(self.to_concrete(ty, *n)) + } _ => Err(SignatureError::InvalidTypeArgs.into()), } } diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 495ab0e003..0a3ed8b258 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -111,13 +111,14 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Runtime(ty)] = typ.args() else { - return Err(error()); + let ty = match typ.args() { + [ty] if ty.least_upper_bound().is_some() => Type::try_from(ty.clone()).unwrap(), // succeeds as has l-u-b + _ => return Err(error()), }; // check all values are instances of the element type for v in &self.0 { - if v.get_type() != *ty { + if v.get_type() != ty { return Err(error()); } } @@ -349,18 +350,16 @@ impl MakeExtensionOp for ListOpInst { fn from_extension_op( ext_op: &ExtensionOp, ) -> Result { - let [Term::Runtime(ty)] = ext_op.args() else { + let [ty] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs.into()); }; + let elem_type = ty.clone().try_into().map_err(SignatureError::from)?; let name = ext_op.unqualified_id(); let Ok(op) = ListOp::from_str(name) else { return Err(OpLoadError::NotMember(name.to_string())); }; - Ok(Self { - elem_type: ty.clone(), - op, - }) + Ok(Self { elem_type, op }) } fn type_args(&self) -> Vec { diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 007be5ecc4..a097f4b568 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -310,15 +310,13 @@ impl HasConcrete for StaticArrayOpDef { use TypeBound::Copyable; match type_args { [arg] => { - let elem_ty = arg - .as_runtime() - .filter(|t| Copyable.contains(t.least_upper_bound())) - .ok_or(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - type_: Box::new(Copyable.into()), - term: Box::new(arg.clone()), - }, - ))?; + if !arg.copyable() { + Err(SignatureError::from(TermTypeError::TypeMismatch { + type_: Box::new(Copyable.into()), + term: Box::new(arg.clone()), + }))? + } + let elem_ty = Type::try_from(arg.clone()).unwrap(); // succeeds as copyable Ok(StaticArrayOp { def: *self, diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 7816e9b03c..7deb1fb53c 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -204,7 +204,7 @@ impl HasConcrete for PtrOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { let ty = match type_args { - [TypeArg::Runtime(ty)] => ty.clone(), + [ty] => Type::try_from(ty.clone()).map_err(SignatureError::from)?, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 1f67edf594..546b4b2e68 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -3,41 +3,35 @@ mod check; pub mod custom; mod poly_func; -mod row_var; pub(crate) mod serialize; mod signature; pub mod type_param; pub mod type_row; -pub(crate) use row_var::MaybeRV; -pub use row_var::{NoRV, RowVariable}; use crate::extension::resolution::{ - ExtensionCollectionError, WeakExtensionRegistry, collect_type_exts, + ExtensionCollectionError, WeakExtensionRegistry, collect_term_exts, }; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; -use crate::types::type_param::check_term_type; +use crate::types::type_param::{TermTypeError, check_term_type}; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; -pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; +pub use poly_func::{PolyFuncType, PolyFuncTypeBase, PolyFuncTypeRV}; pub use signature::{FuncTypeBase, FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::{Term, TypeArg}; pub use type_row::{TypeRow, TypeRowRV}; -pub(crate) use poly_func::PolyFuncTypeBase; - use itertools::FoldWhile::{Continue, Done}; use itertools::{Either, Itertools as _}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; +use std::ops::Deref; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; -use crate::ops::AliasDecl; use self::type_param::TypeParam; -use self::type_row::TypeRowBase; /// A unique identifier for a type. pub type TypeName = SmolStr; @@ -211,8 +205,8 @@ impl std::fmt::Display for SumType { display_list_with_separator(itertools::repeat_n("[]", *size as usize), f, "+") } SumType::General { rows } => match rows.len() { - 1 if rows[0].is_empty() => write!(f, "Unit"), - 2 if rows[0].is_empty() && rows[1].is_empty() => write!(f, "Bool"), + 1 if rows[0].is_empty_list() => write!(f, "Unit"), + 2 if rows[0].is_empty_list() && rows[1].is_empty_list() => write!(f, "Bool"), _ => display_list_with_separator(rows.iter(), f, "+"), }, } @@ -225,13 +219,12 @@ impl SumType { where V: Into, { - let rows = variants.into_iter().map(Into::into).collect_vec(); - - let len: usize = rows.len(); - if u8::try_from(len).is_ok() && rows.iter().all(TypeRowRV::is_empty) { + let variants = variants.into_iter().map(V::into).collect_vec(); + let len = variants.len(); + if u8::try_from(len).is_ok() && variants.iter().all(|tr| tr.is_empty_list()) { Self::new_unary(len as u8) } else { - Self::General { rows } + Self::General { rows: variants } } } @@ -242,6 +235,10 @@ impl SumType { } /// New tuple (single row of variants). + /// + /// # Panics + /// + /// If the argument is not of type [Term::ListType]`(`[Term::RuntimeType]`)` pub fn new_tuple(types: impl Into) -> Self { Self::new([types.into()]) } @@ -255,7 +252,7 @@ impl SumType { #[must_use] pub fn get_variant(&self, tag: usize) -> Option<&TypeRowRV> { match self { - SumType::Unit { size } if tag < (*size as usize) => Some(TypeRV::EMPTY_TYPEROW_REF), + SumType::Unit { size } if tag < (*size as usize) => Some(TypeRowRV::EMPTY_REF), SumType::General { rows } => rows.get(tag), _ => None, } @@ -270,43 +267,59 @@ impl SumType { } } - /// Returns variant row if there is only one variant. + /// Returns variant row if there is only one variant + /// (will be an instance of [Term::ListType]([Term::RuntimeType]). #[must_use] - pub fn as_tuple(&self) -> Option<&TypeRowRV> { + pub fn as_tuple(&self) -> Option<&Term> { match self { - SumType::Unit { size } if *size == 1 => Some(TypeRV::EMPTY_TYPEROW_REF), + SumType::Unit { size } if *size == 1 => Some(Term::EMPTY_LIST_REF), SumType::General { rows } if rows.len() == 1 => Some(&rows[0]), _ => None, } } - /// If the sum matches the convention of `Option[row]`, return the row. + /// If the sum matches the convention of `Option[row]`, return the row + /// (an instance of [Term::ListType]([Term::RuntimeType]). #[must_use] - pub fn as_option(&self) -> Option<&TypeRowRV> { + pub fn as_option(&self) -> Option<&Term> { match self { - SumType::Unit { size } if *size == 2 => Some(TypeRV::EMPTY_TYPEROW_REF), - SumType::General { rows } if rows.len() == 2 && rows[0].is_empty() => Some(&rows[1]), + SumType::Unit { size } if *size == 2 => Some(Term::EMPTY_LIST_REF), + SumType::General { rows } if rows.len() == 2 && rows[0].is_empty_list() => { + Some(&rows[1]) + } _ => None, } } - /// If a sum is an option of a single type, return the type. - #[must_use] - pub fn as_unary_option(&self) -> Option<&TypeRV> { - self.as_option() - .and_then(|row| row.iter().exactly_one().ok()) - } + // ALAN removing as_unary_option. + // "If a sum is an option of a single type, return the type. pub fn as_unary_option(&self) -> Option<&Term>" + // But of course a Term was not necessarily a single type... - /// Returns an iterator over the variants. + /// Returns an iterator over the variants, each an instance of [Term::ListType]`(`[Term::RuntimeType]`)` pub fn variants(&self) -> impl Iterator { match self { - SumType::Unit { size } => Either::Left(itertools::repeat_n( - TypeRV::EMPTY_TYPEROW_REF, - *size as usize, - )), + SumType::Unit { size } => { + Either::Left(itertools::repeat_n(TypeRowRV::EMPTY_REF, *size as usize)) + } SumType::General { rows } => Either::Right(rows.iter()), } } + + fn bound(&self) -> TypeBound { + match self { + SumType::Unit { .. } => TypeBound::Copyable, + SumType::General { rows } => { + if rows + .iter() + .all(|t| check_term_type(t, &Term::new_list_type(TypeBound::Copyable)).is_ok()) + { + TypeBound::Copyable + } else { + TypeBound::Linear + } + } + } + } } impl Transformable for SumType { @@ -318,72 +331,29 @@ impl Transformable for SumType { } } -impl From for TypeBase { +impl From for Type { fn from(sum: SumType) -> Self { match sum { - SumType::Unit { size } => TypeBase::new_unit_sum(size), - SumType::General { rows } => TypeBase::new_sum(rows), + SumType::Unit { size } => Type::new_unit_sum(size), + SumType::General { rows } => Type::new_sum(rows), } } } -#[derive(Clone, Debug, Eq, Hash, derive_more::Display)] -/// Core types -pub enum TypeEnum { - /// An extension type. - // - // TODO optimise with `Box`? - // or some static version of this? - Extension(CustomType), - /// An alias of a type. - #[display("Alias({})", _0.name())] - Alias(AliasDecl), - /// A function type. - #[display("{_0}")] - Function(Box), - /// A type variable, defined by an index into a list of type parameters. - // - // We cache the TypeBound here (checked in validation) - #[display("#{_0}")] - Variable(usize, TypeBound), - /// `RowVariable`. Of course, this requires that `RV` has instances, [`NoRV`] doesn't. - #[display("RowVar({_0})")] - RowVar(RV), - /// Sum of types. - #[display("{_0}")] - Sum(SumType), -} - -impl TypeEnum { - /// The smallest type bound that covers the whole type. - fn least_upper_bound(&self) -> TypeBound { - match self { - TypeEnum::Extension(c) => c.bound(), - TypeEnum::Alias(a) => a.bound, - TypeEnum::Function(_) => TypeBound::Copyable, - TypeEnum::Variable(_, b) => *b, - TypeEnum::RowVar(b) => b.bound(), - TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Copyable, - TypeEnum::Sum(SumType::General { rows }) => least_upper_bound( - rows.iter() - .flat_map(TypeRowRV::iter) - .map(TypeRV::least_upper_bound), - ), - } - } -} - -#[derive(Clone, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize)] +#[derive( + Clone, Debug, Eq, Hash, PartialEq, derive_more::Display, serde::Serialize, serde::Deserialize, +)] #[display("{_0}")] #[serde( into = "serialize::SerSimpleType", try_from = "serialize::SerSimpleType" )] -/// A HUGR type - the valid types of [`EdgeKind::Value`] and [`EdgeKind::Const`] edges. +/// A HUGR type - a single value, that can be sent down a wire. /// +/// The valid types of [`EdgeKind::Value`] and [`EdgeKind::Const`] edges. /// Such an edge is valid if the ports on either end agree on the [Type]. -/// Types have an optional [`TypeBound`] which places limits on the valid -/// operations on a type. +/// Types have a [`TypeBound`] which specifies the number of inports +/// to which a particular outport (of that type) may be connected. /// /// Examples: /// ``` @@ -400,68 +370,48 @@ impl TypeEnum { /// let func_type: Type = Type::new_function(Signature::new_endo([])); /// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable); /// ``` -pub struct TypeBase(TypeEnum, TypeBound); - -/// The type of a single value, that can be sent down a wire -pub type Type = TypeBase; - -/// One or more types - either a single type, or a row variable -/// standing for multiple types. -pub type TypeRV = TypeBase; - -impl PartialEq> for TypeEnum { - fn eq(&self, other: &TypeEnum) -> bool { - match (self, other) { - (TypeEnum::Extension(e1), TypeEnum::Extension(e2)) => e1 == e2, - (TypeEnum::Alias(a1), TypeEnum::Alias(a2)) => a1 == a2, - (TypeEnum::Function(f1), TypeEnum::Function(f2)) => f1 == f2, - (TypeEnum::Variable(i1, b1), TypeEnum::Variable(i2, b2)) => i1 == i2 && b1 == b2, - (TypeEnum::RowVar(v1), TypeEnum::RowVar(v2)) => v1.as_rv() == v2.as_rv(), - (TypeEnum::Sum(s1), TypeEnum::Sum(s2)) => s1 == s2, - _ => false, - } - } -} - -impl PartialEq> for TypeBase { - fn eq(&self, other: &TypeBase) -> bool { - self.0 == other.0 && self.1 == other.1 - } -} +pub struct Type(Term, TypeBound); -impl TypeBase { +impl Type { /// An empty `TypeRow` or `TypeRowRV`. Provided here for convenience - pub const EMPTY_TYPEROW: TypeRowBase = TypeRowBase::::new(); + pub const EMPTY_TYPEROW: TypeRow = TypeRow::new(); /// Unit type (empty tuple). pub const UNIT: Self = Self( - TypeEnum::Sum(SumType::Unit { size: 1 }), + Term::RuntimeSum(SumType::Unit { size: 1 }), TypeBound::Copyable, ); - const EMPTY_TYPEROW_REF: &'static TypeRowBase = &Self::EMPTY_TYPEROW; - /// Initialize a new function type. pub fn new_function(fun_ty: impl Into) -> Self { - Self::new(TypeEnum::Function(Box::new(fun_ty.into()))) + Self( + Term::RuntimeFunction(Box::new(fun_ty.into())), + TypeBound::Copyable, + ) } /// Initialize a new tuple type by providing the elements. #[inline(always)] pub fn new_tuple(types: impl Into) -> Self { let row = types.into(); - match row.len() { - 0 => Self::UNIT, - _ => Self::new_sum([row]), + match row.is_empty_list() { + true => Self::UNIT, + false => Self::new_sum([row]), } } /// Initialize a new sum type by providing the possible variant types. + /// + /// # Panics + /// + /// If any element is not a type or row variable #[inline(always)] pub fn new_sum(variants: impl IntoIterator) -> Self where R: Into, { - Self::new(TypeEnum::Sum(SumType::new(variants))) + let st = SumType::new(variants); + let b = st.bound(); + Self(Term::RuntimeSum(st), b) } /// Initialize a new custom type. @@ -469,25 +419,17 @@ impl TypeBase { #[must_use] pub const fn new_extension(opaque: CustomType) -> Self { let bound = opaque.bound(); - TypeBase(TypeEnum::Extension(opaque), bound) - } - - /// Initialize a new alias. - #[must_use] - pub fn new_alias(alias: AliasDecl) -> Self { - Self::new(TypeEnum::Alias(alias)) - } - - pub(crate) fn new(type_e: TypeEnum) -> Self { - let bound = type_e.least_upper_bound(); - Self(type_e, bound) + Self(Term::RuntimeExtension(opaque), bound) } /// New `UnitSum` with empty Tuple variants #[must_use] pub const fn new_unit_sum(size: u8) -> Self { // should be the only way to avoid going through SumType::new - Self(TypeEnum::Sum(SumType::new_unary(size)), TypeBound::Copyable) + Self( + Term::RuntimeSum(SumType::new_unary(size)), + TypeBound::Copyable, + ) } /// New use (occurrence) of the type variable with specified index. @@ -495,8 +437,8 @@ impl TypeBase { /// (i.e. as a [`Term::RuntimeType`]`(bound)`), which may be narrower /// than required for the use. #[must_use] - pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::Variable(idx, bound), bound) + pub fn new_var_use(idx: usize, bound: TypeBound) -> Self { + Self(Term::new_var_use(idx, bound), bound) } /// Report the least upper [`TypeBound`] @@ -505,91 +447,12 @@ impl TypeBase { self.1 } - /// Report the component `TypeEnum`. - #[inline(always)] - pub const fn as_type_enum(&self) -> &TypeEnum { - &self.0 - } - - /// Report a mutable reference to the component `TypeEnum`. - #[inline(always)] - pub fn as_type_enum_mut(&mut self) -> &mut TypeEnum { - &mut self.0 - } - - /// Returns the inner [`SumType`] if the type is a sum. - pub fn as_sum(&self) -> Option<&SumType> { - match &self.0 { - TypeEnum::Sum(s) => Some(s), - _ => None, - } - } - - /// Returns the inner [`CustomType`] if the type is from an extension. - pub fn as_extension(&self) -> Option<&CustomType> { - match &self.0 { - TypeEnum::Extension(ct) => Some(ct), - _ => None, - } - } - /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { TypeBound::Copyable.contains(self.least_upper_bound()) } - /// Checks all variables used in the type are in the provided list - /// of bound variables, rejecting any [`RowVariable`]s if `allow_row_vars` is False; - /// and that for each [`CustomType`] the corresponding - /// [`TypeDef`] is in the [`ExtensionRegistry`] and the type arguments - /// [validate] and fit into the def's declared parameters. - /// - /// [RowVariable]: TypeEnum::RowVariable - /// [validate]: crate::types::type_param::TypeArg::validate - /// [TypeDef]: crate::extension::TypeDef - pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - // There is no need to check the components against the bound, - // that is guaranteed by construction (even for deserialization) - match &self.0 { - TypeEnum::Sum(SumType::General { rows }) => { - rows.iter().try_for_each(|row| row.validate(var_decls)) - } - TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there - TypeEnum::Alias(_) => Ok(()), - TypeEnum::Extension(custy) => custy.validate(var_decls), - // Function values may be passed around without knowing their arity - // (i.e. with row vars) as long as they are not called: - TypeEnum::Function(ft) => ft.validate(var_decls), - TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()), - TypeEnum::RowVar(rv) => rv.validate(var_decls), - } - } - - /// Applies a substitution to a type. - /// This may result in a row of types, if this [Type] is not really a single type but actually a row variable - /// Invariants may be confirmed by validation: - /// * If [`Type::validate`]`(false)` returns successfully, this method will return a Vec containing exactly one type - /// * If [`Type::validate`]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self) - /// return a Vec containing any number of [Type]s. These may (or not) pass [`Type::validate`] - fn substitute(&self, t: &Substitution) -> Vec { - match &self.0 { - TypeEnum::RowVar(rv) => rv.substitute(t), - TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], - TypeEnum::Variable(idx, bound) => { - let TypeArg::Runtime(ty) = t.apply_var(*idx, &((*bound).into())) else { - panic!("Variable was not a type - try validate() first") - }; - vec![ty.into_()] - } - TypeEnum::Extension(cty) => vec![TypeBase::new_extension(cty.substitute(t))], - TypeEnum::Function(bf) => vec![TypeBase::new_function(bf.substitute(t))], - TypeEnum::Sum(SumType::General { rows }) => { - vec![TypeBase::new_sum(rows.iter().map(|r| r.substitute(t)))] - } - } - } - /// Returns a registry with the concrete extensions used by this type. /// /// This includes the extensions of custom types that may be nested @@ -598,7 +461,7 @@ impl TypeBase { let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); - collect_type_exts(self, &mut used, &mut missing); + collect_term_exts(self, &mut used, &mut missing); if missing.is_empty() { Ok(used.try_into().expect("all extensions are present")) @@ -608,114 +471,41 @@ impl TypeBase { } } -impl Transformable for TypeBase { - fn transform(&mut self, tr: &T) -> Result { - match &mut self.0 { - TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => Ok(false), - TypeEnum::Extension(custom_type) => { - if let Some(nt) = tr.apply_custom(custom_type)? { - *self = nt.into_(); - Ok(true) - } else { - let args_changed = custom_type.args_mut().transform(tr)?; - if args_changed { - *self = Self::new_extension( - custom_type - .get_type_def(&custom_type.get_extension()?)? - .instantiate(custom_type.args())?, - ); - } - Ok(args_changed) - } - } - TypeEnum::Function(fty) => fty.transform(tr), - TypeEnum::Sum(sum_type) => { - let ch = sum_type.transform(tr)?; - self.1 = self.0.least_upper_bound(); - Ok(ch) - } - } - } -} +impl Deref for Type { + type Target = Term; -impl Type { - fn substitute1(&self, s: &Substitution) -> Self { - let v = self.substitute(s); - let [r] = v.try_into().unwrap(); // No row vars, so every Type produces exactly one - r - } -} - -impl TypeRV { - /// Tells if this Type is a row variable, i.e. could stand for any number >=0 of Types - #[must_use] - pub fn is_row_var(&self) -> bool { - matches!(self.0, TypeEnum::RowVar(_)) - } - - /// New use (occurrence) of the row variable with specified index. - /// `bound` must match that with which the variable was declared - /// (i.e. as a list of runtime types of that bound). - /// For use in [OpDef], not [FuncDefn], type schemes only. - /// - /// [OpDef]: crate::extension::OpDef - /// [FuncDefn]: crate::ops::FuncDefn - #[must_use] - pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::RowVar(RowVariable(idx, bound)), bound) + fn deref(&self) -> &Self::Target { + &self.0 } } -// ====== Conversions ====== -impl TypeBase { - /// (Fallibly) converts a `TypeBase` (parameterized, so may or may not be able - /// to contain [`RowVariable`]s) into a [Type] that definitely does not. - pub fn try_into_type(self) -> Result { - Ok(TypeBase( - match self.0 { - TypeEnum::Extension(e) => TypeEnum::Extension(e), - TypeEnum::Alias(a) => TypeEnum::Alias(a), - TypeEnum::Function(f) => TypeEnum::Function(f), - TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), - TypeEnum::RowVar(rv) => Err(rv.as_rv().clone())?, - TypeEnum::Sum(s) => TypeEnum::Sum(s), - }, - self.1, - )) - } -} +impl TryFrom for Type { + type Error = TermTypeError; -impl TryFrom for Type { - type Error = RowVariable; - fn try_from(value: TypeRV) -> Result { - value.try_into_type() + fn try_from(t: Term) -> Result { + match t.least_upper_bound() { + Some(b) => Ok(Self(t, b)), + None => Err(TermTypeError::TypeMismatch { + term: Box::new(t), + type_: Box::new(TypeBound::Linear.into()), + }), + } } } -impl TypeBase { - /// A swiss-army-knife for any safe conversion of the type argument `RV1` - /// to/from [`NoRV`]/RowVariable/rust-type-variable. - fn into_(self) -> TypeBase - where - RV1: Into, - { - TypeBase( - match self.0 { - TypeEnum::Extension(e) => TypeEnum::Extension(e), - TypeEnum::Alias(a) => TypeEnum::Alias(a), - TypeEnum::Function(f) => TypeEnum::Function(f), - TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), - TypeEnum::RowVar(rv) => TypeEnum::RowVar(rv.into()), - TypeEnum::Sum(s) => TypeEnum::Sum(s), - }, - self.1, - ) +impl From for Term { + fn from(t: Type) -> Self { + t.0 } } -impl From for TypeRV { - fn from(value: Type) -> Self { - value.into_() +impl Transformable for Type { + fn transform(&mut self, tr: &T) -> Result { + let res = self.0.transform(tr)?; + if res { + self.1 = self.0.least_upper_bound().unwrap() + } + Ok(res) } } @@ -745,36 +535,6 @@ impl<'a> Substitution<'a> { debug_assert_eq!(check_term_type(arg, decl), Ok(())); arg.clone() } - - fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec { - let arg = self - .0 - .get(idx) - .expect("Undeclared type variable - call validate() ?"); - debug_assert!(check_term_type(arg, &TypeParam::new_list_type(bound)).is_ok()); - match arg { - TypeArg::List(elems) => elems - .iter() - .map(|ta| { - match ta { - Term::Runtime(ty) => return ty.clone().into(), - Term::Variable(v) => { - if let Some(b) = v.bound_if_row_var() { - return TypeRV::new_row_var_use(v.index(), b); - } - } - _ => (), - } - panic!("Not a list of types - call validate() ?") - }) - .collect(), - Term::Runtime(ty) if matches!(ty.0, TypeEnum::RowVar(_)) => { - // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type - vec![ty.clone().into()] - } - _ => panic!("Not a type or list of types - call validate() ?"), - } - } } /// A transformation that can be applied to a [Type] or [`TypeArg`]. @@ -825,28 +585,59 @@ impl Transformable for [E] { } } -pub(crate) fn check_typevar_decl( - decls: &[TypeParam], - idx: usize, - cached_decl: &TypeParam, -) -> Result<(), SignatureError> { - match decls.get(idx) { - None => Err(SignatureError::FreeTypeVar { - idx, - num_decls: decls.len(), - }), - Some(actual) => { - // The cache here just mirrors the declaration. The typevar can be used - // anywhere expecting a kind *containing* the decl - see `check_type_arg`. - if actual == cached_decl { - Ok(()) - } else { - Err(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: Box::new(cached_decl.clone()), - actual: Box::new(actual.clone()), - }) - } - } +/// Compared to just `pub(crate) trait Substitutable: Transformable`, this avoids a +/// private_bounds warning when the trait is used as a type bound on a public struct. +mod internal { + use super::{SignatureError, Substitution, Transformable, TypeParam}; + + /// Sub-trait of [`Transformable`] for types that support substitution of + /// type variables and validation of type-variable scopes. + pub trait Substitutable: Transformable { + /// Checks all variables used in `self` are in the provided list of bound + /// variables, and that for each [`CustomType`] the corresponding [`TypeDef`] + /// is in the [`ExtensionRegistry`] and the type arguments validate (recursively) + /// and fit into the declared parameters of the [`TypeDef`]. + /// + /// [`TypeDef`]: crate::extension::TypeDef + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError>; + + /// Applies a [`Substitution`] to this instance, returning a new value. + /// + /// Infallible (assuming the `subst` covers all variables) and will + /// not invalidate the instance (assuming all values substituted in are + /// valid instances of the variables they replace). + /// + /// # Panics + /// + /// If the substitution does not cover all type variables in `self`. + fn substitute(&self, s: &Substitution) -> Self; + } +} + +pub(crate) use internal::Substitutable; + +impl Substitutable for Type { + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { + self.0.validate(var_decls)?; + // ALAN even this should be only a debug-assert really: + // we have no unchecked access from outside crate::types + // so it must be a bug in our caching logic if this is wrong: + check_term_type(&self.0, &self.1.into())?; + debug_assert!( + self.1 == TypeBound::Copyable + || check_term_type(&self.0, &TypeBound::Copyable.into()).is_err() + ); + Ok(()) + } + + /// Always produces exactly one type, but may narrow the bound (from + /// [TypeBound::Linear] to [TypeBound::Copyable]). + fn substitute(&self, s: &Substitution) -> Self { + let t = self.0.substitute(s); + // Must succeed and produce a type assuming substitution valid (RHSes + // fit within LHS). However, may *narrow* the bound, so recompute. + let b = t.least_upper_bound().unwrap(); + Self(t, b) } } @@ -876,18 +667,15 @@ pub(crate) mod test { // Dummy extension reference. &Weak::default(), )), - Type::new_alias(AliasDecl::new("my_alias", TypeBound::Copyable)), + //Type::new_alias(AliasDecl::new("my_alias", TypeBound::Copyable)), ]); - assert_eq!( - &t.to_string(), - "[usize, [] -> [], my_custom, Alias(my_alias)]" - ); + assert_eq!(&t.to_string(), "[usize, [] -> [], my_custom]"); } #[rstest::rstest] fn sum_construct() { let pred1 = Type::new_sum([type_row![], type_row![]]); - let pred2 = TypeRV::new_unit_sum(2); + let pred2 = Type::new_unit_sum(2); assert_eq!(pred1, pred2); @@ -905,9 +693,22 @@ pub(crate) mod test { fn as_option() { let opt = option_type([usize_t()]); - assert_eq!(opt.as_unary_option().unwrap().clone(), usize_t()); assert_eq!( - Type::new_unit_sum(2).as_sum().unwrap().as_unary_option(), + opt.as_option().unwrap().clone(), + Term::new_list([usize_t().into()]) + ); + // Two empty variants is like an option of empty. + // ALAN note there used to be as_unary_option... + assert_eq!( + Type::new_unit_sum(2).as_sum().unwrap().as_option(), + Some(&Term::new_list([])) + ); + + assert_eq!( + Type::new_sum(vec![[usize_t()]; 2]) + .as_sum() + .unwrap() + .as_option(), None ); @@ -932,18 +733,16 @@ pub(crate) mod test { #[test] fn sum_variants() { let variants: Vec = vec![ - [TypeRV::UNIT].into(), - vec![TypeRV::new_row_var_use(0, TypeBound::Linear)].into(), + [Type::UNIT].into(), + TypeRowRV::just_row_var(0, TypeBound::Linear), ]; let t = SumType::new(variants.clone()); assert_eq!(variants, t.variants().cloned().collect_vec()); - let empty_rows = vec![TypeRV::EMPTY_TYPEROW; 3]; + let empty_rows = vec![TypeRowRV::new(); 3]; let sum_unary = SumType::new_unary(3); - let sum_general = SumType::General { - rows: empty_rows.clone(), - }; - assert_eq!(&empty_rows, &sum_unary.variants().cloned().collect_vec()); + assert_eq!(empty_rows, sum_unary.variants().cloned().collect_vec()); + let sum_general = SumType::General { rows: empty_rows }; assert_eq!(sum_general, sum_unary); let mut hasher_general = std::hash::DefaultHasher::new(); @@ -1070,11 +869,10 @@ pub(crate) mod test { use crate::proptest::RecursionDepth; - use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum}; - use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV}; - use proptest::prelude::*; + use crate::types::{CustomType, FuncValueType, SumType, Type, TypeBound, TypeRow}; + use proptest::{prelude::*, strategy::Union}; - impl Arbitrary for super::SumType { + impl Arbitrary for SumType { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { @@ -1082,29 +880,34 @@ pub(crate) mod test { if depth.leaf() { any::().prop_map(Self::new_unary).boxed() } else { - vec(any_with::(depth), 0..3) + vec(any_with::(depth), 0..3) .prop_map(SumType::new) .boxed() } } } - impl Arbitrary for TypeBase { + impl Arbitrary for Type { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - // We descend here, because a TypeEnum may contain a Type + let strat = Union::new([ + (any::(), any::()) + .prop_map(|(i, b)| Type::new_var_use(i, b)) + .boxed(), + any_with::(depth.into()) + .prop_map(Type::new_extension) + .boxed(), + ]); + if depth.leaf() { + return strat.boxed(); + } let depth = depth.descend(); - prop_oneof![ - 1 => any::().prop_map(TypeBase::new_alias), - 1 => any_with::(depth.into()).prop_map(TypeBase::new_extension), - 1 => any_with::(depth).prop_map(TypeBase::new_function), - 1 => any_with::(depth).prop_map(TypeBase::from), - 1 => (any::(), any::()).prop_map(|(i,b)| TypeBase::new_var_use(i,b)), - // proptest_derive::Arbitrary's weight attribute requires a constant, - // rather than this expression, hence the manual impl: - RV::weight() => RV::arb().prop_map(|rv| TypeBase::new(TypeEnum::RowVar(rv))) - ] + strat + .or(any_with::(depth) + .prop_map(Type::new_function) + .boxed()) + .or(any_with::(depth).prop_map(Type::from).boxed()) .boxed() } } @@ -1116,11 +919,10 @@ pub(super) mod proptest_utils { use proptest::collection::vec; use proptest::prelude::{Strategy, any_with}; - use super::serialize::{TermSer, TypeArgSer, TypeParamSer}; - use super::type_param::Term; - use crate::proptest::RecursionDepth; - use crate::types::serialize::ArrayOrTermSer; + + use super::serialize::{ArrayOrTermSer, TermSer, TypeArgSer, TypeParamSer}; + use super::type_param::Term; fn term_is_serde_type_arg(t: &Term) -> bool { let TermSer::TypeArg(arg) = TermSer::from(t.clone()) else { @@ -1132,13 +934,11 @@ pub(super) mod proptest_utils { | TypeArgSer::Tuple { elems: terms } | TypeArgSer::TupleConcat { tuples: terms } => terms.iter().all(term_is_serde_type_arg), TypeArgSer::Variable { v } => term_is_serde_type_param(&v.cached_decl), - TypeArgSer::Type { ty } => { - if let Some(cty) = ty.as_extension() { - cty.args().iter().all(term_is_serde_type_arg) - } else { - true - } - } // Do we need to inspect inside function types? sum types? + TypeArgSer::Type { ty } => match Term::from(ty) { + Term::RuntimeExtension(cty) => cty.args().iter().all(term_is_serde_type_arg), + // Do we need to inspect inside function types? sum types? + _ => true, + }, TypeArgSer::BoundedNat { .. } | TypeArgSer::String { .. } | TypeArgSer::Bytes { .. } diff --git a/hugr-core/src/types/check.rs b/hugr-core/src/types/check.rs index 072da5884e..9ef8dc783b 100644 --- a/hugr-core/src/types/check.rs +++ b/hugr-core/src/types/check.rs @@ -3,7 +3,10 @@ use thiserror::Error; use super::{Type, TypeRow}; -use crate::{extension::SignatureError, ops::Value}; +use crate::{ + ops::Value, + types::{Term, type_param::TermTypeError}, +}; /// Errors that arise from typechecking constants #[derive(Clone, Debug, PartialEq, Error)] @@ -69,10 +72,16 @@ impl super::SumType { num_variants: self.num_variants(), })?; let variant: TypeRow = variant.clone().try_into().map_err(|e| { - let SignatureError::RowVarWhereTypeExpected { var } = e else { - panic!("Unexpected error") + let TermTypeError::TypeMismatch { term, .. } = e else { + panic!("Unexpected error {e}") }; - SumTypeError::VariantNotConcrete { tag, varidx: var.0 } + let Term::Variable(tv) = &*term else { + panic!("Unexpected term {term}"); + }; + SumTypeError::VariantNotConcrete { + tag, + varidx: tv.index(), + } })?; if variant.len() != val.len() { diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 248e0f6253..ae1901e5d9 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -8,7 +8,7 @@ use crate::Extension; use crate::extension::{ExtensionId, SignatureError, TypeDef}; use super::{ - Substitution, TypeBound, + Substitutable, Substitution, TypeBound, type_param::{TypeArg, TypeParam}, }; use super::{Type, TypeName}; diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index ea16ab958b..6f00cc8d26 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -4,18 +4,14 @@ use std::borrow::Cow; use itertools::Itertools; -use crate::extension::SignatureError; -#[cfg(test)] -use { - super::proptest_utils::any_serde_type_param, - crate::proptest::RecursionDepth, - ::proptest::{collection::vec, prelude::*}, - proptest_derive::Arbitrary, +use crate::{ + extension::SignatureError, + types::{Substitutable, TypeRow, TypeRowRV}, }; use super::Substitution; +use super::signature::FuncTypeBase; use super::type_param::{TypeArg, TypeParam, check_term_types}; -use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// A polymorphic type scheme, i.e. of a [`FuncDecl`], [`FuncDefn`] or [`OpDef`]. /// (Nodes/operations in the Hugr are not polymorphic.) @@ -24,19 +20,24 @@ use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// [`FuncDefn`]: crate::ops::module::FuncDefn /// [`OpDef`]: crate::extension::OpDef #[derive( - Clone, PartialEq, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, + PartialEq, + Debug, + Default, + Eq, + Hash, + derive_more::Display, + serde::Serialize, + serde::Deserialize, )] -#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] #[display("{}{body}", self.display_params())] -pub struct PolyFuncTypeBase { +pub struct PolyFuncTypeBase { /// The declared type parameters, i.e., these must be instantiated with /// the same number of [`TypeArg`]s before the function can be called. This /// defines the indices used by variables inside the body. - #[cfg_attr(test, proptest(strategy = "vec(any_serde_type_param(params), 0..3)"))] params: Vec, /// Template for the function. May contain variables up to length of [`Self::params`] - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - body: FuncTypeBase, + body: FuncTypeBase, } /// The polymorphic type of a [`Call`]-able function ([`FuncDecl`] or [`FuncDefn`]). @@ -45,26 +46,16 @@ pub struct PolyFuncTypeBase { /// [`Call`]: crate::ops::Call /// [`FuncDefn`]: crate::ops::FuncDefn /// [`FuncDecl`]: crate::ops::FuncDecl -pub type PolyFuncType = PolyFuncTypeBase; +pub type PolyFuncType = PolyFuncTypeBase; /// The polymorphic type of an [`OpDef`], whose number of input and outputs -/// may vary according to how [`RowVariable`]s therein are instantiated. +/// may vary according to how row variables therein are instantiated. /// /// [`OpDef`]: crate::extension::OpDef -pub type PolyFuncTypeRV = PolyFuncTypeBase; +pub type PolyFuncTypeRV = PolyFuncTypeBase; -// deriving Default leads to an impl that only applies for RV: Default -impl Default for PolyFuncTypeBase { - fn default() -> Self { - Self { - params: Default::default(), - body: Default::default(), - } - } -} - -impl From> for PolyFuncTypeBase { - fn from(body: FuncTypeBase) -> Self { +impl From> for PolyFuncTypeBase { + fn from(body: FuncTypeBase) -> Self { Self { params: vec![], body, @@ -81,11 +72,11 @@ impl From for PolyFuncTypeRV { } } -impl TryFrom> for FuncTypeBase { +impl TryFrom> for FuncTypeBase { /// If the `PolyFuncTypeBase` is not monomorphic, fail with its binders type Error = Vec; - fn try_from(value: PolyFuncTypeBase) -> Result { + fn try_from(value: PolyFuncTypeBase) -> Result { if value.params.is_empty() { Ok(value.body) } else { @@ -94,45 +85,26 @@ impl TryFrom> for FuncTypeBase { } } -impl PolyFuncTypeBase { +impl PolyFuncTypeBase { /// The type parameters, aka binders, over which this type is polymorphic pub fn params(&self) -> &[TypeParam] { &self.params } /// The body of the type, a function type. - pub fn body(&self) -> &FuncTypeBase { + pub fn body(&self) -> &FuncTypeBase { &self.body } /// Create a new `PolyFuncTypeBase` given the kinds of the variables it declares /// and the underlying [`FuncTypeBase`]. - pub fn new(params: impl Into>, body: impl Into>) -> Self { + pub fn new(params: impl Into>, body: impl Into>) -> Self { Self { params: params.into(), body: body.into(), } } - /// Instantiates an outer [`PolyFuncTypeBase`], i.e. with no free variables - /// (as ensured by [`Self::validate`]), into a monomorphic type. - /// - /// # Errors - /// If there is not exactly one [`TypeArg`] for each binder ([`Self::params`]), - /// or an arg does not fit into its corresponding [`TypeParam`] - pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { - // Check that args are applicable, and that we have a value for each binder, - // i.e. each possible free variable within the body. - check_term_types(args, &self.params)?; - Ok(self.body.substitute(&Substitution(args))) - } - - /// Validates this instance, checking that the types in the body are - /// wellformed with respect to the registry, and the type variables declared. - pub fn validate(&self) -> Result<(), SignatureError> { - self.body.validate(&self.params) - } - /// Helper function for the Display implementation fn display_params(&self) -> Cow<'static, str> { if self.params.is_empty() { @@ -148,11 +120,34 @@ impl PolyFuncTypeBase { } /// Returns a mutable reference to the body of the function type. - pub fn body_mut(&mut self) -> &mut FuncTypeBase { + pub fn body_mut(&mut self) -> &mut FuncTypeBase { &mut self.body } } +// Do not implement Substitutable: we never need to substitute into a PolyFuncType +// (i.e. under a binder). +impl PolyFuncTypeBase { + /// Instantiates an outer [`PolyFuncTypeBase`], i.e. with no free variables + /// (as ensured by [`Self::validate`]), into a monomorphic type. + /// + /// # Errors + /// If there is not exactly one [`TypeArg`] for each binder ([`Self::params`]), + /// or an arg does not fit into its corresponding [`TypeParam`] + pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { + // Check that args are applicable, and that we have a value for each binder, + // i.e. each possible free variable within the body. + check_term_types(args, &self.params)?; + Ok(self.body.substitute(&Substitution(args))) + } + + /// Validates this instance, checking that the types in the body are + /// wellformed with respect to the registry, and the type variables declared. + pub fn validate(&self) -> Result<(), SignatureError> { + self.body.validate(&self.params) + } +} + #[cfg(test)] pub(crate) mod test { use std::num::NonZeroU64; @@ -163,20 +158,47 @@ pub(crate) mod test { use crate::Extension; use crate::extension::prelude::{bool_t, usize_t}; use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound}; + use crate::std_extensions::collections::array::{self, array_type_parametric}; use crate::std_extensions::collections::list; + use crate::types::signature::FuncTypeBase; use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, MaybeRV, Signature, Term, Type, TypeBound, TypeName, TypeRV, + CustomType, FuncValueType, Signature, Substitutable, Term, Type, TypeBound, TypeName, + TypeRowRV, }; use super::PolyFuncTypeBase; - - impl PolyFuncTypeBase { + mod proptest { + use proptest::collection::vec; + use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any_with}; + + use super::PolyFuncTypeBase; + use crate::proptest::RecursionDepth; + use crate::types::proptest_utils::any_serde_type_param; + use crate::types::signature::FuncTypeBase; + + impl + 'static> Arbitrary for PolyFuncTypeBase { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + // We want to generate a random number of type parameters, and then generate a body that can refer to those parameters. + // To do this, we first generate the type parameters, and then pass them as parameters to the body strategy. + ( + vec(any_serde_type_param(params), 0..3), + any_with::>(params), + ) + .prop_map(|(params, body)| Self::new(params, body)) + .boxed() + } + } + } + impl PolyFuncTypeBase { fn new_validated( params: impl Into>, - body: FuncTypeBase, + body: FuncTypeBase, ) -> Result { let res = Self::new(params, body); res.validate()?; @@ -187,7 +209,7 @@ pub(crate) mod test { #[test] fn test_opaque() -> Result<(), SignatureError> { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); - let tyvar = TypeArg::new_var_use(0, TypeBound::Linear.into()); + let tyvar = TypeArg::new_var_use(0, TypeBound::Linear); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); let list_len = PolyFuncTypeBase::new_validated( [TypeBound::Linear.into()], @@ -211,7 +233,7 @@ pub(crate) mod test { #[test] fn test_mismatched_args() -> Result<(), SignatureError> { let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let ty_var = TypeArg::new_var_use(1, TypeBound::Linear.into()); + let ty_var = TypeArg::new_var_use(1, TypeBound::Linear); let type_params = [TypeParam::max_nat_type(), TypeBound::Linear.into()]; // Valid schema... @@ -262,7 +284,7 @@ pub(crate) mod test { #[test] fn test_misused_variables() -> Result<(), SignatureError> { // Variables in args have different bounds from variable declaration - let tv = TypeArg::new_var_use(0, TypeBound::Copyable.into()); + let tv = TypeArg::new_var_use(0, TypeBound::Copyable); let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo([Type::new_extension(list_def.instantiate([tv])?)]); for decl in [ @@ -377,10 +399,7 @@ pub(crate) mod test { let decl = Term::new_list_type(TP_ANY); let e = PolyFuncTypeBase::new_validated( [decl.clone()], - FuncValueType::new( - vec![usize_t()], - vec![TypeRV::new_row_var_use(0, TypeBound::Copyable)], - ), + FuncValueType::new([usize_t()], TypeRowRV::just_row_var(0, TypeBound::Copyable)), ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { @@ -401,10 +420,13 @@ pub(crate) mod test { #[test] fn row_variables() { - let rty = TypeRV::new_row_var_use(0, TypeBound::Linear); + let rty = TypeRowRV::just_row_var(0, TypeBound::Linear); let pf = PolyFuncTypeBase::new_validated( [TypeParam::new_list_type(TP_ANY)], - FuncValueType::new([usize_t().into(), rty.clone()], [TypeRV::new_tuple([rty])]), + FuncValueType::new( + TypeRowRV::from([usize_t()]).concat(rty.clone()), + [Type::new_tuple(rty)], + ), ) .unwrap(); @@ -417,20 +439,20 @@ pub(crate) mod test { let t2 = pf.instantiate(&[Term::new_list(seq2())]).unwrap(); assert_eq!( - t2, Signature::new( vec![usize_t(), usize_t(), bool_t()], vec![Type::new_tuple(vec![usize_t(), bool_t()])] - ) + ), + t2 ); } #[test] fn row_variables_inner() { - let inner_fty = Type::new_function(FuncValueType::new_endo([TypeRV::new_row_var_use( + let inner_fty = Type::new_function(FuncValueType::new_endo(TypeRowRV::just_row_var( 0, TypeBound::Copyable, - )])); + ))); let pf = PolyFuncTypeBase::new_validated( [Term::new_list_type(TypeBound::Copyable)], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), @@ -439,11 +461,7 @@ pub(crate) mod test { let inner3 = Type::new_function(Signature::new_endo([usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate(&[Term::new_list([ - usize_t().into(), - bool_t().into(), - usize_t().into(), - ])]) + .instantiate(&[[usize_t(), bool_t(), usize_t()].into()]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/row_var.rs b/hugr-core/src/types/row_var.rs deleted file mode 100644 index 086ab7b076..0000000000 --- a/hugr-core/src/types/row_var.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! Classes for row variables (i.e. Type variables that can stand for multiple types) - -use super::type_param::TypeParam; -use super::{Substitution, TypeBase, TypeBound, check_typevar_decl}; -use crate::extension::SignatureError; - -#[cfg(test)] -use proptest::prelude::{BoxedStrategy, Strategy, any}; -/// Describes a row variable - a type variable bound with a list of runtime types -/// of the specified bound (checked in validation) -// The serde derives here are not used except as markers -// so that other types containing this can also #derive-serde the same way. -#[derive( - Clone, Debug, Eq, Hash, PartialEq, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -#[display("{_0}")] -pub struct RowVariable(pub usize, pub TypeBound); - -// Note that whilst 'pub' this is not re-exported outside private module `row_var` -// so is effectively sealed. -pub trait MaybeRV: - Clone - + std::fmt::Debug - + std::fmt::Display - + From - + Into - + Eq - + PartialEq - + 'static -{ - fn as_rv(&self) -> &RowVariable; - fn try_from_rv(rv: RowVariable) -> Result; - fn bound(&self) -> TypeBound; - fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError>; - #[allow(private_interfaces)] - fn substitute(&self, s: &Substitution) -> Vec>; - #[cfg(test)] - fn weight() -> u32 { - 1 - } - #[cfg(test)] - fn arb() -> BoxedStrategy; -} - -/// Has no instances - used as parameter to [`Type`] to rule out the possibility -/// of there being any [`TypeEnum::RowVar`]s -/// -/// [`TypeEnum::RowVar`]: super::TypeEnum::RowVar -/// [`Type`]: super::Type -// The serde derives here are not used except as markers -// so that other types containing this can also #derive-serde the same way. -#[derive( - Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -pub enum NoRV {} - -impl From for RowVariable { - fn from(value: NoRV) -> Self { - match value {} - } -} - -impl MaybeRV for RowVariable { - fn as_rv(&self) -> &RowVariable { - self - } - - fn bound(&self) -> TypeBound { - self.1 - } - - fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - check_typevar_decl(var_decls, self.0, &TypeParam::new_list_type(self.1)) - } - - #[allow(private_interfaces)] - fn substitute(&self, s: &Substitution) -> Vec> { - s.apply_rowvar(self.0, self.1) - } - - fn try_from_rv(rv: RowVariable) -> Result { - Ok(rv) - } - - #[cfg(test)] - fn arb() -> BoxedStrategy { - (any::(), any::()) - .prop_map(|(i, b)| Self(i, b)) - .boxed() - } -} - -impl MaybeRV for NoRV { - fn as_rv(&self) -> &RowVariable { - match *self {} - } - - fn bound(&self) -> TypeBound { - match *self {} - } - - fn validate(&self, _var_decls: &[TypeParam]) -> Result<(), SignatureError> { - match *self {} - } - - #[allow(private_interfaces)] - fn substitute(&self, _s: &Substitution) -> Vec> { - match *self {} - } - - fn try_from_rv(rv: RowVariable) -> Result { - Err(rv) - } - - #[cfg(test)] - fn weight() -> u32 { - 0 - } - - #[cfg(test)] - fn arb() -> BoxedStrategy { - any::() - .prop_map(|_| panic!("Should be ruled out by weight==0")) - .boxed() - } -} diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index eeff6f2e14..edb5f44da6 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -1,18 +1,18 @@ use std::sync::Arc; use ordered_float::OrderedFloat; +use serde::Serialize; -use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeBase, TypeBound, TypeEnum}; +use super::{FuncValueType, SumType, Term, Type, TypeBound}; use super::custom::CustomType; -use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::AliasDecl; -use crate::types::type_param::{TermVar, UpperBound}; -use crate::types::{Term, Type}; +use crate::types::TypeRowRV; +use crate::types::type_param::{SeqPart, TermTypeError, TermVar, UpperBound}; -#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] +#[derive(Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] pub(crate) enum SerSimpleType { Q, @@ -25,45 +25,69 @@ pub(crate) enum SerSimpleType { R { i: usize, b: TypeBound }, } -impl From> for SerSimpleType { - fn from(value: TypeBase) -> Self { +/// For the things that used to be supported as Types +impl From for SerSimpleType { + fn from(value: Type) -> Self { if value == qb_t() { return SerSimpleType::Q; } if value == usize_t() { return SerSimpleType::I; } - match value.0 { - TypeEnum::Extension(o) => SerSimpleType::Opaque(o), - TypeEnum::Alias(a) => SerSimpleType::Alias(a), - TypeEnum::Function(sig) => SerSimpleType::G(sig), - TypeEnum::Variable(i, b) => SerSimpleType::V { i, b }, - TypeEnum::RowVar(rv) => { - let RowVariable(idx, bound) = rv.as_rv(); - SerSimpleType::R { i: *idx, b: *bound } + match value.into() { + Term::RuntimeExtension(o) => SerSimpleType::Opaque(o), + //TypeEnum::Alias(a) => SerSimpleType::Alias(a), + Term::RuntimeFunction(sig) => SerSimpleType::G(sig), + Term::Variable(tv) => { + let i = tv.index(); + let Term::RuntimeType(b) = &*tv.cached_decl else { + panic!("Variable with bound {} is not a valid Type", tv.cached_decl); + }; + SerSimpleType::V { i, b: *b } } - TypeEnum::Sum(st) => SerSimpleType::Sum(st), + Term::RuntimeSum(st) => SerSimpleType::Sum(st), + v => panic!("{} was not a valid Type", v), } } } -impl TryFrom for TypeBase { - type Error = SignatureError; +impl TryFrom for SerSimpleType { + type Error = TermTypeError; + + fn try_from(value: Term) -> Result { + if let Term::Variable(tv) = &value + && let Term::ListType(t) = &*tv.cached_decl + && let Term::RuntimeType(b) = &**t + { + return Ok(SerSimpleType::R { + i: tv.index(), + b: *b, + }); + } + Type::try_from(value).map(SerSimpleType::from) + } +} + +impl From for Term { + fn from(value: SerSimpleType) -> Self { + match value { + SerSimpleType::Q => qb_t().into(), + SerSimpleType::I => usize_t().into(), + SerSimpleType::G(sig) => Type::new_function(*sig).into(), + SerSimpleType::Sum(st) => Type::from(st).into(), + SerSimpleType::Opaque(o) => Type::new_extension(o).into(), + SerSimpleType::Alias(_) => todo!("alias?"), + SerSimpleType::V { i, b } => Type::new_var_use(i, b).into(), + SerSimpleType::R { i, b } => Term::new_row_var_use(i, b), + } + } +} + +impl TryFrom for Type { + type Error = TermTypeError; + fn try_from(value: SerSimpleType) -> Result { - Ok(match value { - SerSimpleType::Q => qb_t().into_(), - SerSimpleType::I => usize_t().into_(), - SerSimpleType::G(sig) => TypeBase::new_function(*sig), - SerSimpleType::Sum(st) => st.into(), - SerSimpleType::Opaque(o) => TypeBase::new_extension(o), - SerSimpleType::Alias(a) => TypeBase::new_alias(a), - SerSimpleType::V { i, b } => TypeBase::new_var_use(i, b), - // We can't use new_row_var because that returns TypeRV not TypeBase. - SerSimpleType::R { i, b } => TypeBase::new(TypeEnum::RowVar( - RV::try_from_rv(RowVariable(i, b)) - .map_err(|var| SignatureError::RowVarWhereTypeExpected { var })?, - )), - }) + Term::from(value).try_into() } } @@ -87,7 +111,7 @@ pub(super) enum TypeParamSer { #[serde(tag = "tya")] pub(super) enum TypeArgSer { Type { - ty: Type, + ty: SerSimpleType, }, BoundedNat { n: u64, @@ -138,7 +162,11 @@ impl From for TermSer { Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), Term::ConstType(ty) => TermSer::TypeParam(TypeParamSer::ConstType { ty: *ty }), - Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), + Term::RuntimeFunction(_) | Term::RuntimeExtension(_) | Term::RuntimeSum(_) => { + TermSer::TypeArg(TypeArgSer::Type { + ty: value.try_into().unwrap(), + }) + } Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { params: (*params).into(), }), @@ -148,7 +176,15 @@ impl From for TermSer { Term::Float(value) => TermSer::TypeArg(TypeArgSer::Float { value }), Term::List(elems) => TermSer::TypeArg(TypeArgSer::List { elems }), Term::Tuple(elems) => TermSer::TypeArg(TypeArgSer::Tuple { elems }), - Term::Variable(v) => TermSer::TypeArg(TypeArgSer::Variable { v }), + Term::Variable(v) => { + TermSer::TypeArg(if matches!(&*v.cached_decl, Term::RuntimeType(_)) { + TypeArgSer::Type { + ty: SerSimpleType::try_from(Term::Variable(v)).unwrap(), + } + } else { + TypeArgSer::Variable { v } + }) + } Term::ListConcat(lists) => TermSer::TypeArg(TypeArgSer::ListConcat { lists }), Term::TupleConcat(tuples) => TermSer::TypeArg(TypeArgSer::TupleConcat { tuples }), } @@ -170,7 +206,7 @@ impl From for Term { TypeParamSer::ConstType { ty } => Term::ConstType(Box::new(ty)), }, TermSer::TypeArg(arg) => match arg { - TypeArgSer::Type { ty } => Term::Runtime(ty), + TypeArgSer::Type { ty } => Term::from(ty), TypeArgSer::BoundedNat { n } => Term::BoundedNat(n), TypeArgSer::String { arg } => Term::String(arg), TypeArgSer::Bytes { value } => Term::Bytes(value), @@ -233,3 +269,38 @@ mod base64 { .map_err(serde::de::Error::custom) } } + +impl serde::Serialize for TypeRowRV { + fn serialize(&self, serializer: S) -> Result { + let items: Vec = self + .0 + .clone() + .into_list_parts() + .map(|part| match part { + SeqPart::Item(t) => { + let t = Type::try_from(t).unwrap(); + let s = SerSimpleType::from(t); + assert!(!matches!(s, SerSimpleType::R { .. })); + s + } + SeqPart::Splice(t) => { + let s = SerSimpleType::try_from(t).unwrap(); + assert!(matches!(s, SerSimpleType::R { .. })); + s + } + }) + .collect(); + items.serialize(serializer) + } +} + +impl<'de> serde::Deserialize<'de> for TypeRowRV { + fn deserialize>(deser: D) -> Result { + let items: Vec = serde::Deserialize::deserialize(deser)?; + let list_parts = items.into_iter().map(|s| match s { + SerSimpleType::R { i, b } => SeqPart::Splice(Term::new_row_var_use(i, b)), + s => SeqPart::Item(Term::from(s)), + }); + Ok(TypeRowRV::try_from(Term::new_list_from_parts(list_parts)).unwrap()) + } +} diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 0a7200ed92..8d2e3ff9dc 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -2,115 +2,89 @@ use itertools::Either; -use std::borrow::Cow; use std::fmt::{self, Display}; use super::type_param::TypeParam; -use super::type_row::TypeRowBase; -use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeRow, TypeTransformer, -}; +use super::{Substitution, Transformable, Type, TypeRow, TypeTransformer}; use crate::core::PortIndex; use crate::extension::resolution::{ ExtensionCollectionError, WeakExtensionRegistry, collect_signature_exts, }; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::types::{Substitutable, TypeRowRV}; use crate::{Direction, IncomingPort, OutgoingPort, Port}; -#[cfg(test)] -use {crate::proptest::RecursionDepth, proptest::prelude::*, proptest_derive::Arbitrary}; - -#[derive(Clone, Debug, Eq, Hash, serde::Serialize, serde::Deserialize)] -#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] /// Base type for listing inputs and output types. /// -/// The exact semantics depend on the use case: -/// - If `ROWVARS=`[`NoRV`], describes the edges required to/from a node or inside a [`FuncDefn`]. -/// - If `ROWVARS=`[`RowVariable`], describes the type of the inputs/outputs from an `OpDef`. -/// -/// `ROWVARS` specifies whether the type lists may contain [`RowVariable`]s or not. -/// -/// [`FuncDefn`]: crate::ops::FuncDefn -pub struct FuncTypeBase { +/// Parametrized by the type used to list the inputs and outputs. Exactly two +/// instantiations are used: [Signature] and [FuncValueType]. +pub struct FuncTypeBase { /// Value inputs of the function. - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - pub input: TypeRowBase, + pub input: T, /// Value outputs of the function. - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - pub output: TypeRowBase, + pub output: T, } /// The concept of "signature" in the spec - the edges required to/from a node /// or within a [`FuncDefn`], also the target (value) of a call (static). /// +/// Thus, contains a statically-known number of types. +/// /// [`FuncDefn`]: crate::ops::FuncDefn -pub type Signature = FuncTypeBase; +pub type Signature = FuncTypeBase; -/// A function that may contain [`RowVariable`]s and thus has potentially-unknown arity; +/// A function that may contain row variables and thus has potentially-unknown arity; /// used for [`OpDef`]'s and passable as a value round a Hugr (see [`Type::new_function`]) /// but not a valid node type. /// /// [`OpDef`]: crate::extension::OpDef -pub type FuncValueType = FuncTypeBase; - -impl FuncTypeBase { - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { - Self { - input: self.input.substitute(tr), - output: self.output.substitute(tr), - } - } +pub type FuncValueType = FuncTypeBase; +impl FuncTypeBase { /// Create a new signature with specified inputs and outputs. - pub fn new(input: impl Into>, output: impl Into>) -> Self { + pub fn new(input: impl Into, output: impl Into) -> Self { Self { input: input.into(), output: output.into(), } } - /// Create a new signature with the same input and output types (signature of an endomorphic - /// function). - pub fn new_endo(row: impl Into>) -> Self { - let row = row.into(); - Self::new(row.clone(), row) - } - - /// True if both inputs and outputs are necessarily empty. - /// (For [`FuncValueType`], even after any possible substitution of row variables) - #[inline(always)] - #[must_use] - pub fn is_empty(&self) -> bool { - self.input.is_empty() && self.output.is_empty() - } - #[inline] /// Returns a row of the value inputs of the function. #[must_use] - pub fn input(&self) -> &TypeRowBase { + pub fn input(&self) -> &T { &self.input } #[inline] /// Returns a row of the value outputs of the function. #[must_use] - pub fn output(&self) -> &TypeRowBase { + pub fn output(&self) -> &T { &self.output } #[inline] /// Returns a tuple with the input and output rows of the function. #[must_use] - pub fn io(&self) -> (&TypeRowBase, &TypeRowBase) { + pub fn io(&self) -> (&T, &T) { (&self.input, &self.output) } +} - pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - self.input.validate(var_decls)?; - self.output.validate(var_decls) +impl FuncTypeBase { + /// Create a new signature with the same input and output types. + pub fn new_endo(io: impl Into) -> Self { + let io = io.into(); + Self { + input: io.clone(), + output: io, + } } +} +impl Signature { /// Returns a registry with the concrete extensions used by this signature. pub fn used_extensions(&self) -> Result { let mut used = WeakExtensionRegistry::default(); @@ -124,16 +98,47 @@ impl FuncTypeBase { Err(ExtensionCollectionError::dropped_signature(self, missing)) } } + + /// True if both inputs and outputs are necessarily empty. + /// (For [`FuncValueType`], even after any possible substitution of row variables) + #[inline(always)] + #[must_use] + pub fn is_empty(&self) -> bool { + self.input.is_empty() && self.output.is_empty() + } } -impl Transformable for FuncTypeBase { - fn transform(&mut self, tr: &T) -> Result { +impl FuncValueType { + /// True if both inputs and outputs are necessarily empty + /// (even after any possible substitution of row variables) + #[inline(always)] + #[must_use] + pub fn is_empty(&self) -> bool { + self.input.is_empty_list() && self.output.is_empty_list() + } +} + +impl Transformable for FuncTypeBase { + fn transform(&mut self, tr: &U) -> Result { // TODO handle extension sets? Ok(self.input.transform(tr)? | self.output.transform(tr)?) } } -impl FuncValueType { +impl Substitutable for FuncTypeBase { + fn substitute(&self, subst: &Substitution) -> Self { + Self { + input: self.input.substitute(subst), + output: self.output.substitute(subst), + } + } + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { + self.input.validate(var_decls)?; + self.output.validate(var_decls) + } +} + +/*impl FuncValueType { /// If this `FuncValueType` contains any row variables, return one. #[must_use] pub fn find_rowvar(&self) -> Option { @@ -142,17 +147,7 @@ impl FuncValueType { .chain(self.output.iter()) .find_map(|t| Type::try_from(t.clone()).err()) } -} - -// deriving Default leads to an impl that only applies for RV: Default -impl Default for FuncTypeBase { - fn default() -> Self { - Self { - input: Default::default(), - output: Default::default(), - } - } -} +}*/ impl Signature { /// Returns the type of a value [`Port`]. Returns `None` if the port is out @@ -274,7 +269,7 @@ impl Signature { } } -impl Display for FuncTypeBase { +impl Display for FuncTypeBase { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.input.fmt(f)?; f.write_str(" -> ")?; @@ -301,31 +296,40 @@ impl From for FuncValueType { } } -impl PartialEq> for FuncTypeBase { - fn eq(&self, other: &FuncTypeBase) -> bool { +impl PartialEq for Signature { + fn eq(&self, other: &FuncValueType) -> bool { self.input == other.input && self.output == other.output } } -impl PartialEq>> for FuncTypeBase { - fn eq(&self, other: &Cow<'_, FuncTypeBase>) -> bool { - self.eq(other.as_ref()) - } -} - -impl PartialEq> for Cow<'_, FuncTypeBase> { - fn eq(&self, other: &FuncTypeBase) -> bool { - self.as_ref().eq(other) - } -} - #[cfg(test)] mod test { use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::type_row; - use crate::types::{CustomType, TypeEnum, test::FnTransformer}; + use crate::types::{CustomType, Term, test::FnTransformer}; use super::*; + + mod proptest { + use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any_with}; + + use super::FuncTypeBase; + use crate::proptest::RecursionDepth; + + impl + 'static> Arbitrary for FuncTypeBase { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + // We want to generate a random number of type parameters, and then generate a body that can refer to those parameters. + // To do this, we first generate the type parameters, and then pass them as parameters to the body strategy. + (any_with::(params), any_with::(params)) + .prop_map(|(input, output)| Self::new(input, output)) + .boxed() + } + } + } + #[test] fn test_function_type() { let mut f_type = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]); @@ -354,7 +358,7 @@ mod test { #[test] fn test_transform() { - let TypeEnum::Extension(usz_t) = usize_t().as_type_enum().clone() else { + let Term::RuntimeExtension(usz_t) = usize_t().into() else { panic!() }; let tr = FnTransformer(|ct: &CustomType| (ct == &usz_t).then_some(bool_t())); diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index c9ddcc81b9..7f5a148307 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -4,6 +4,7 @@ //! //! [`TypeDef`]: crate::extension::TypeDef +use itertools::Itertools as _; use ordered_float::OrderedFloat; #[cfg(test)] use proptest_derive::Arbitrary; @@ -14,12 +15,9 @@ use std::sync::Arc; use thiserror::Error; use tracing::warn; -use super::row_var::MaybeRV; -use super::{ - NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer, - check_typevar_decl, -}; +use super::{Substitutable, Substitution, Transformable, Type, TypeBound, TypeTransformer}; use crate::extension::SignatureError; +use crate::types::{CustomType, FuncValueType, SumType, TypeRow}; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) @@ -95,9 +93,22 @@ pub enum Term { /// The type of static tuples. #[display("TupleType[{_0}]")] TupleType(Box), - /// A runtime type as a term. Instance of [`Term::RuntimeType`]. + /// The type of runtime values defined by an extension type. + /// Instance of [Self::RuntimeType] for some bound. + // + // TODO optimise with `Box`? + // or some static version of this? + RuntimeExtension(CustomType), + /// The type of runtime values that are function pointers. + /// Instance of [Self::RuntimeType]`(`[TypeBound::Copyable]`)`. + /// Function values may be passed around without knowing their arity + /// (i.e. with row vars) as long as they are not called. #[display("{_0}")] - Runtime(Type), + RuntimeFunction(Box), + /// The type of runtime values that are sums of products (ADTs) + /// Instance of [Self::RuntimeType]`(bound)` for `bound` calculated from each variant's elements. + #[display("{_0}")] + RuntimeSum(SumType), /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`]. #[display("{_0}")] BoundedNat(u64), @@ -111,9 +122,14 @@ pub enum Term { #[display("{}", _0.into_inner())] Float(OrderedFloat), /// A list of static terms. Instance of [`Term::ListType`]. + /// Note, not a [TypeRow] because `impl Arbitrary for TypeRow` generates only types. + /// + /// [TypeRow]: super::TypeRow #[display("[{}]", { use itertools::Itertools as _; - _0.iter().map(|t|t.to_string()).join(",") + //_0.iter().map(|t|t.to_string()).join(",") + // extra space matching old Display for Type(Row) + _0.iter().map(|t|t.to_string()).join(", ") })] List(Vec), /// Instance of [`TypeParam::List`] defined by a sequence of concatenated lists of the same type. @@ -152,6 +168,8 @@ pub enum Term { } impl Term { + pub(crate) const EMPTY_LIST: Term = Term::List(Vec::new()); + pub(crate) const EMPTY_LIST_REF: &'static Term = &Self::EMPTY_LIST; /// Creates a [`Term::BoundedNatType`] with the maximum bound (`u64::MAX` + 1). #[must_use] pub const fn max_nat_type() -> Self { @@ -166,7 +184,7 @@ impl Term { /// Creates a new [`Term::List`] given a sequence of its items. pub fn new_list(items: impl IntoIterator) -> Self { - Self::List(items.into_iter().collect()) + Self::List(items.into_iter().map_into().collect()) } /// Creates a new [`Term::ListType`] given the type of its elements. @@ -197,24 +215,57 @@ impl Term { (Term::StringType, Term::StringType) => true, (Term::StaticType, Term::StaticType) => true, (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), + // The term inside a TupleType is a list of types, so this is ok as long as + // supertype holds element-wise (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2), (Term::BytesType, Term::BytesType) => true, (Term::FloatType, Term::FloatType) => true, - (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2, + // Needed for TupleType, does not make a great deal of sense otherwise: + (Term::List(es1), Term::List(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) + } + // The following are not types (they have no instances), so these are just to + // maintain reflexivity of the relation: + (Term::RuntimeSum(t1), Term::RuntimeSum(t2)) => t1 == t2, + (Term::RuntimeFunction(f1), Term::RuntimeFunction(f2)) => f1 == f2, + (Term::RuntimeExtension(c1), Term::RuntimeExtension(c2)) => c1 == c2, (Term::BoundedNat(n1), Term::BoundedNat(n2)) => n1 == n2, (Term::String(s1), Term::String(s2)) => s1 == s2, (Term::Bytes(v1), Term::Bytes(v2)) => v1 == v2, (Term::Float(f1), Term::Float(f2)) => f1 == f2, (Term::Variable(v1), Term::Variable(v2)) => v1 == v2, - (Term::List(es1), Term::List(es2)) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) - } (Term::Tuple(es1), Term::Tuple(es2)) => { es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) } _ => false, } } + + /// Returns true if this term is an empty list (contains no elements) + pub fn is_empty_list(&self) -> bool { + match self { + Term::List(v) => v.is_empty(), + // We probably don't need to be this thorough in dealing with unnormalized forms but it's easy enough + Term::ListConcat(v) => v.iter().all(Term::is_empty_list), + _ => false, + } + } + + /// Returns the inner [`CustomType`] if this `Term` is a [Self::RuntimeExtension] + pub fn as_extension(&self) -> Option<&CustomType> { + match self { + Term::RuntimeExtension(ct) => Some(ct), + _ => None, + } + } + + /// Returns the inner [`SumType`] if this `Term` is a [Self::RuntimeSum]. + pub fn as_sum(&self) -> Option<&SumType> { + match self { + Term::RuntimeSum(s) => Some(s), + _ => None, + } + } } impl From for Term { @@ -229,15 +280,6 @@ impl From for Term { } } -impl From> for Term { - fn from(value: TypeBase) -> Self { - match value.try_into_type() { - Ok(ty) => Term::Runtime(ty), - Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)), - } - } -} - impl From for Term { fn from(n: u64) -> Self { Self::BoundedNat(n) @@ -262,14 +304,37 @@ impl From> for Term { } } +impl From> for Term { + fn from(value: Vec) -> Self { + TypeRow::from(value).into() + } +} + impl From<[Term; N]> for Term { fn from(value: [Term; N]) -> Self { Self::new_list(value) } } -/// Variable in a [`Term`], that is not a single runtime type (i.e. not a [`Type::new_var_use`] -/// - it might be a [`Type::new_row_var_use`]). +impl From<[Type; N]> for Term { + fn from(value: [Type; N]) -> Self { + TypeRow::from(value).into() + } +} + +impl From for Term { + fn from(value: SumType) -> Self { + Self::RuntimeSum(value) + } +} + +impl From for Term { + fn from(value: CustomType) -> Self { + Self::RuntimeExtension(value) + } +} + +/// Variable in a [`Term`] #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, )] @@ -280,24 +345,23 @@ pub struct TermVar { } impl Term { - /// [`Type::UNIT`] as a [`Term::Runtime`] - pub const UNIT: Self = Self::Runtime(Type::UNIT); - /// Makes a `TypeArg` representing a use (occurrence) of the type variable /// with the specified index. + /// /// `decl` must be exactly that with which the variable was declared. #[must_use] - pub fn new_var_use(idx: usize, decl: Term) -> Self { - match decl { - // Note a TypeParam::List of TypeParam::Type *cannot* be represented - // as a TypeArg::Type because the latter stores a Type i.e. only a single type, - // not a RowVariable. - Term::RuntimeType(b) => Type::new_var_use(idx, b).into(), - _ => Term::Variable(TermVar { - idx, - cached_decl: Box::new(decl), - }), - } + pub fn new_var_use(idx: usize, decl: impl Into) -> Self { + Term::Variable(TermVar { + idx, + cached_decl: Box::new(decl.into()), + }) + } + + /// Makes a `Term` representing a use (occurrence) of a variable whose + /// kind is a [Term::ListType] of [Term::RuntimeType]. + #[must_use] + pub fn new_row_var_use(idx: usize, b: TypeBound) -> Self { + Self::new_var_use(idx, Term::new_list_type(b)) } /// Creates a new string literal. @@ -308,8 +372,11 @@ impl Term { /// Creates a new concatenated list. #[inline] - pub fn new_list_concat(lists: impl IntoIterator) -> Self { - Self::ListConcat(lists.into_iter().collect()) + pub fn concat_lists(lists: impl IntoIterator) -> Self { + match lists.into_iter().exactly_one() { + Ok(list) => list, + Err(e) => Self::ListConcat(e.collect()), + } } /// Creates a new tuple from its items. @@ -333,122 +400,42 @@ impl Term { } } - /// Returns a [`Type`] if the [`Term`] is a runtime type. - #[must_use] - pub fn as_runtime(&self) -> Option> { + pub(crate) fn least_upper_bound(&self) -> Option { match self { - TypeArg::Runtime(ty) => Some(ty.clone()), + Self::RuntimeExtension(ct) => Some(ct.bound()), + Self::RuntimeSum(st) => Some(st.bound()), + Self::RuntimeFunction(_) => Some(TypeBound::Copyable), + Self::Variable(v) => match &*v.cached_decl { + TypeParam::RuntimeType(b) => Some(*b), + _ => None, + }, _ => None, } } - /// Returns a string if the [`Term`] is a string literal. - #[must_use] - pub fn as_string(&self) -> Option { - match self { - TypeArg::String(arg) => Some(arg.clone()), - _ => None, + /// Report if this is a copyable runtime type, i.e. an instance + /// of [Self::RuntimeType]`(`[TypeBound::Copyable]`)` + // - i.e.the least upper bound of the type is contained by the copyable bound. + pub(crate) fn copyable(&self) -> bool { + match self.least_upper_bound() { + Some(b) => TypeBound::Copyable.contains(b), + None => false, } } - /// Much as [`Type::validate`], also checks that the type of any [`TypeArg::Opaque`] - /// is valid and closed. - pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - match self { - Term::Runtime(ty) => ty.validate(var_decls), - Term::List(elems) => { - // TODO: Full validation would check that the type of the elements agrees - elems.iter().try_for_each(|a| a.validate(var_decls)) - } - Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), - Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), - TypeArg::ListConcat(lists) => { - // TODO: Full validation would check that each of the lists is indeed a - // list or list variable of the correct types. - lists.iter().try_for_each(|a| a.validate(var_decls)) - } - TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), - Term::Variable(TermVar { idx, cached_decl }) => { - assert!( - !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), - "Malformed TypeArg::Variable {cached_decl} - should be inconstructible" - ); - - check_typevar_decl(var_decls, *idx, cached_decl) - } - Term::RuntimeType { .. } => Ok(()), - Term::BoundedNatType { .. } => Ok(()), - Term::StringType => Ok(()), - Term::BytesType => Ok(()), - Term::FloatType => Ok(()), - Term::ListType(item_type) => item_type.validate(var_decls), - Term::TupleType(item_types) => item_types.validate(var_decls), - Term::StaticType => Ok(()), - Term::ConstType(ty) => ty.validate(var_decls), - } + /// Repot if this is a runtime type, i.e. an instance of [Self::RuntimeType] for some bound. + /// + /// If so, [Type::try_from(Type)] will succeed and can be followed by [Type::least_upper_bound] to get the bound. + pub fn is_runtime_type(&self) -> bool { + self.least_upper_bound().is_some() } - pub(crate) fn substitute(&self, t: &Substitution) -> Self { + /// Returns a string if the [`Term`] is a string literal. + #[must_use] + pub fn as_string(&self) -> Option { match self { - Term::Runtime(ty) => { - // RowVariables are represented as Term::Variable - ty.substitute1(t).into() - } - TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { - self.clone() - } // We do not allow variables as bounds on BoundedNat's - TypeArg::List(elems) => { - // NOTE: This implements a hack allowing substitutions to - // replace `TypeArg::Variable`s representing "row variables" - // with a list that is to be spliced into the containing list. - // We won't need this code anymore once we stop conflating types - // with lists of types. - - fn is_type(type_arg: &TypeArg) -> bool { - match type_arg { - TypeArg::Runtime(_) => true, - TypeArg::Variable(v) => v.bound_if_row_var().is_some(), - _ => false, - } - } - - let are_types = elems.first().map(is_type).unwrap_or(false); - - Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) { - list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list), - list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list), - elem => SeqPart::Item(elem), - })) - } - TypeArg::ListConcat(lists) => { - // When a substitution instantiates spliced list variables, we - // may be able to merge the concatenated lists. - Self::new_list_from_parts( - lists.iter().map(|list| SeqPart::Splice(list.substitute(t))), - ) - } - Term::Tuple(elems) => { - Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) - } - TypeArg::TupleConcat(tuples) => { - // When a substitution instantiates spliced tuple variables, - // we may be able to merge the concatenated tuples. - Self::new_tuple_from_parts( - tuples - .iter() - .map(|tuple| SeqPart::Splice(tuple.substitute(t))), - ) - } - TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), - Term::RuntimeType(_) => self.clone(), - Term::BoundedNatType(_) => self.clone(), - Term::StringType => self.clone(), - Term::BytesType => self.clone(), - Term::FloatType => self.clone(), - Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), - Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)), - Term::StaticType => self.clone(), - Term::ConstType(ty) => Term::new_const(ty.substitute1(t)), + TypeArg::String(arg) => Some(arg.clone()), + _ => None, } } @@ -488,7 +475,7 @@ impl Term { Self::new_seq_from_parts( parts.into_iter().flat_map(ListPartIter::new), TypeArg::List, - TypeArg::ListConcat, + TypeArg::concat_lists, ) } @@ -518,7 +505,7 @@ impl Term { /// # let b = Term::new_string("b"); /// # let c = Term::new_string("c"); /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); - /// let term = Term::new_list_concat([ + /// let term = Term::concat_lists([ /// Term::new_list([a.clone(), b.clone()]), /// var.clone(), /// Term::new_list([c.clone()]) @@ -537,8 +524,8 @@ impl Term { /// # let a = Term::new_string("a"); /// # let b = Term::new_string("b"); /// # let c = Term::new_string("c"); - /// let term = Term::new_list_concat([ - /// Term::new_list_concat([ + /// let term = Term::concat_lists([ + /// Term::concat_lists([ /// Term::new_list([a.clone()]), /// Term::new_list([b.clone()]) /// ]), @@ -565,7 +552,7 @@ impl Term { /// ); /// ``` #[inline] - pub fn into_list_parts(self) -> ListPartIter { + pub fn into_list_parts(self) -> impl Iterator> { ListPartIter::new(SeqPart::Splice(self)) } @@ -584,15 +571,56 @@ impl Term { /// /// Analogous to [`TypeArg::into_list_parts`]. #[inline] - pub fn into_tuple_parts(self) -> TuplePartIter { + pub fn into_tuple_parts(self) -> impl Iterator> { TuplePartIter::new(SeqPart::Splice(self)) } } +fn check_typevar_decl( + decls: &[TypeParam], + idx: usize, + cached_decl: &TypeParam, +) -> Result<(), SignatureError> { + match decls.get(idx) { + None => Err(SignatureError::FreeTypeVar { + idx, + num_decls: decls.len(), + }), + Some(actual) => { + // The cache here just mirrors the declaration. The typevar can be used + // anywhere expecting a kind *containing* the decl - see `check_type_arg`. + if actual == cached_decl { + Ok(()) + } else { + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + cached: Box::new(cached_decl.clone()), + actual: Box::new(actual.clone()), + }) + } + } + } +} + impl Transformable for Term { fn transform(&mut self, tr: &T) -> Result { match self { - Term::Runtime(ty) => ty.transform(tr), + Term::RuntimeExtension(custom_type) => { + if let Some(nt) = tr.apply_custom(custom_type)? { + *self = nt.0; + Ok(true) + } else { + let args_changed = custom_type.args_mut().transform(tr)?; + if args_changed { + *self = custom_type + .get_type_def(&custom_type.get_extension()?)? + .instantiate(custom_type.args())? + .into(); + } + Ok(args_changed) + } + } + Term::RuntimeFunction(fty) => fty.transform(tr), + Term::RuntimeSum(sum_type) => sum_type.transform(tr), Term::List(elems) => elems.transform(tr), Term::Tuple(elems) => elems.transform(tr), Term::BoundedNat(_) @@ -615,6 +643,94 @@ impl Transformable for Term { } } +impl Substitutable for Term { + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { + match self { + Term::RuntimeSum(SumType::General { rows }) => { + rows.iter().try_for_each(|row| row.validate(var_decls))?; + Ok(()) + } + Term::RuntimeSum(SumType::Unit { .. }) => Ok(()), // No leaves there + Term::RuntimeExtension(custy) => custy.validate(var_decls), + Term::RuntimeFunction(ft) => ft.validate(var_decls), + Term::List(elems) => { + // Full validation might check that the type of the elements agrees. + // However we will leave this to a separate check_term_type which knows + // the required element type. + elems.iter().try_for_each(|a| a.validate(var_decls)) + } + Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), + Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), + TypeArg::ListConcat(lists) => { + // Full validation might check that each of the lists is indeed a list or + // list variable of the correct types. However we will leave this to a + // separate check_term_type which knows the required element type. + lists.iter().try_for_each(|a| a.validate(var_decls)) + } + TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), + Term::Variable(TermVar { idx, cached_decl }) => { + check_typevar_decl(var_decls, *idx, cached_decl) + } + Term::RuntimeType { .. } => Ok(()), + Term::BoundedNatType { .. } => Ok(()), + Term::StringType => Ok(()), + Term::BytesType => Ok(()), + Term::FloatType => Ok(()), + Term::ListType(item_type) => item_type.validate(var_decls), + Term::TupleType(item_types) => item_types.validate(var_decls), + Term::StaticType => Ok(()), + Term::ConstType(ty) => ty.validate(var_decls), + } + } + + fn substitute(&self, t: &Substitution) -> Self { + match self { + TypeArg::RuntimeSum(SumType::Unit { .. }) => self.clone(), + TypeArg::RuntimeSum(SumType::General { rows }) => { + // A substitution of a row variable for an empty list, + // could make the general case into a unary SumType. + Term::RuntimeSum(SumType::new(rows.iter().map(|r| r.substitute(t)))) + } + TypeArg::RuntimeExtension(cty) => Term::RuntimeExtension(cty.substitute(t)), + TypeArg::RuntimeFunction(bf) => Term::RuntimeFunction(Box::new(bf.substitute(t))), + + TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { + self.clone() + } // We do not allow variables as bounds on BoundedNat's + TypeArg::List(elems) => Self::List(elems.iter().map(|e| e.substitute(t)).collect()), + TypeArg::ListConcat(lists) => { + // When a substitution instantiates spliced list variables, we + // may be able to merge the concatenated lists. + Self::new_list_from_parts( + lists.iter().map(|list| SeqPart::Splice(list.substitute(t))), + ) + } + Term::Tuple(elems) => { + Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) + } + TypeArg::TupleConcat(tuples) => { + // When a substitution instantiates spliced tuple variables, + // we may be able to merge the concatenated tuples. + Self::new_tuple_from_parts( + tuples + .iter() + .map(|tuple| SeqPart::Splice(tuple.substitute(t))), + ) + } + TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), + Term::RuntimeType(_) => self.clone(), + Term::BoundedNatType(_) => self.clone(), + Term::StringType => self.clone(), + Term::BytesType => self.clone(), + Term::FloatType => self.clone(), + Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), + Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)), + Term::StaticType => self.clone(), + Term::ConstType(ty) => Term::new_const(ty.substitute(t)), + } + } +} + impl TermVar { /// Return the index. #[must_use] @@ -641,24 +757,17 @@ pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> { (Term::Variable(TermVar { cached_decl, .. }), _) if type_.is_supertype(cached_decl) => { Ok(()) } - (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => { + (Term::RuntimeSum(st), Term::RuntimeType(bound)) if bound.contains(st.bound()) => Ok(()), + (Term::RuntimeFunction(_), Term::RuntimeType(_)) => Ok(()), // Function pointers are always Copyable so fit any bound + (Term::RuntimeExtension(cty), Term::RuntimeType(bound)) if bound.contains(cty.bound()) => { Ok(()) } - (Term::List(elems), Term::ListType(item_type)) => { - elems.iter().try_for_each(|term| { - // Also allow elements that are RowVars if fitting into a List of Types - if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) - && v.bound_if_row_var() - .is_some_and(|arg_bound| param_bound.contains(arg_bound)) - { - return Ok(()); - } - check_term_type(term, item_type) - }) - } - (Term::ListConcat(lists), Term::ListType(item_type)) => lists + (Term::List(elems), Term::ListType(item_type)) => elems .iter() - .try_for_each(|list| check_term_type(list, item_type)), + .try_for_each(|elem| check_term_type(elem, item_type)), + (Term::ListConcat(lists), Term::ListType(_)) => lists + .iter() + .try_for_each(|list| check_term_type(list, type_)), // ALAN this used the element type, which seems very wrong (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => { let term_parts: Vec<_> = term.clone().into_tuple_parts().collect(); let type_parts: Vec<_> = item_types.clone().into_list_parts().collect(); @@ -762,7 +871,7 @@ pub enum SeqPart { /// Iterator created by [`TypeArg::into_list_parts`]. #[derive(Debug, Clone)] -pub struct ListPartIter { +pub(crate) struct ListPartIter { parts: SmallVec<[SeqPart; 1]>, } @@ -797,7 +906,7 @@ impl FusedIterator for ListPartIter {} /// Iterator created by [`TypeArg::into_tuple_parts`]. #[derive(Debug, Clone)] -pub struct TuplePartIter { +pub(crate) struct TuplePartIter { parts: SmallVec<[SeqPart; 1]>, } @@ -836,9 +945,9 @@ mod test { use super::{Substitution, TypeArg, TypeParam, check_term_type}; use crate::extension::prelude::{bool_t, usize_t}; - use crate::types::Term; + use crate::types::Substitutable; use crate::types::type_param::SeqPart; - use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; + use crate::types::{Term, Type, TypeBound, TypeRow, type_param::TermTypeError}; #[test] fn new_list_from_parts_items() { @@ -868,13 +977,13 @@ mod test { let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); let parts = [ SeqPart::Splice(Term::new_list([a.clone(), b.clone()])), - SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])), + SeqPart::Splice(Term::concat_lists([Term::new_list([c.clone()])])), SeqPart::Item(d.clone()), SeqPart::Splice(var.clone()), ]; assert_eq!( Term::new_list_from_parts(parts), - Term::new_list_concat([Term::new_list([a, b, c, d]), var]) + Term::concat_lists([Term::new_list([a, b, c, d]), var]) ); } @@ -899,7 +1008,7 @@ mod test { #[test] fn type_arg_fits_param() { - let rowvar = TypeRV::new_row_var_use; + let rowvar = Term::new_row_var_use; fn check(arg: impl Into, param: &TypeParam) -> Result<(), TermTypeError> { check_term_type(&arg.into(), param) } @@ -910,42 +1019,45 @@ mod test { let arg = args.iter().cloned().map_into().collect_vec().into(); check_term_type(&arg, param) } - // Simple cases: a Term::Type is a Term::RuntimeType but singleton sequences are lists + // Simple cases: Term::RuntimeXXXs are Term::RuntimeType's check(usize_t(), &TypeBound::Copyable.into()).unwrap(); let seq_param = TypeParam::new_list_type(TypeBound::Copyable); check(usize_t(), &seq_param).unwrap_err(); + // ...but singleton sequences thereof are lists check_seq(&[usize_t()], &TypeBound::Linear.into()).unwrap_err(); // Into a list of type, we can fit a single row var check(rowvar(0, TypeBound::Copyable), &seq_param).unwrap(); - // or a list of (types or row vars) - check(vec![], &seq_param).unwrap(); - check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap(); - check_seq( - &[ + // or a list of types, or a "concat" of row vars + check([usize_t()], &seq_param).unwrap(); + check( + Term::ListConcat(vec![rowvar(0, TypeBound::Copyable); 2]), + &seq_param, + ) + .unwrap(); + // but a *list* of the rowvar is a list of list of types, which is wrong + check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap_err(); + check( + Term::concat_lists([ rowvar(1, TypeBound::Linear), - usize_t().into(), + vec![usize_t()].into(), rowvar(0, TypeBound::Copyable), - ], + ]), &TypeParam::new_list_type(TypeBound::Linear), ) .unwrap(); - // Next one fails because a list of Eq is required - check_seq( - &[ + // Next one fails because a list of Copyable is required + check( + Term::concat_lists([ rowvar(1, TypeBound::Linear), - usize_t().into(), + vec![usize_t()].into(), rowvar(0, TypeBound::Copyable), - ], + ]), &seq_param, ) .unwrap_err(); // seq of seq of types is not allowed - check( - vec![usize_t().into(), vec![usize_t().into()].into()], - &seq_param, - ) - .unwrap_err(); + check(vec![Term::from(usize_t()), [usize_t()].into()], &seq_param).unwrap_err(); // Similar for nats (but no equivalent of fancy row vars) check(5, &TypeParam::max_nat_type()).unwrap(); @@ -963,7 +1075,7 @@ mod test { // `Term::TupleType` requires a `Term::Tuple` of the same number of elems let usize_and_ty = - TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]); + TypeParam::new_tuple_type([TypeParam::max_nat_type(), Term::from(TypeBound::Copyable)]); check( TypeArg::Tuple(vec![5.into(), usize_t().into()]), &usize_and_ty, @@ -974,10 +1086,9 @@ mod test { &usize_and_ty, ) .unwrap_err(); // Wrong way around - let two_types = TypeParam::new_tuple_type(Term::new_list([ - TypeBound::Linear.into(), - TypeBound::Linear.into(), - ])); + + let two_types = Term::new_list([TypeBound::Linear.into(), TypeBound::Linear.into()]); + let two_types = TypeParam::new_tuple_type(two_types); check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); // not a Row Var which could have any number of elems check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); @@ -986,23 +1097,20 @@ mod test { #[test] fn type_arg_subst_row() { let row_param = Term::new_list_type(TypeBound::Copyable); - let row_arg: Term = vec![bool_t().into(), Term::UNIT].into(); + let row_arg: Term = vec![bool_t(), Type::UNIT].into(); check_term_type(&row_arg, &row_param).unwrap(); // Now say a row variable referring to *that* row was used // to instantiate an outer "row parameter" (list of type). let outer_param = Term::new_list_type(TypeBound::Linear); - let outer_arg = Term::new_list([ - TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), - usize_t().into(), + let outer_arg = Term::concat_lists([ + Term::new_row_var_use(0, TypeBound::Copyable), + Term::new_list([usize_t().into()]), ]); check_term_type(&outer_arg, &outer_param).unwrap(); let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); - assert_eq!( - outer_arg2, - vec![bool_t().into(), Term::UNIT, usize_t().into()].into() - ); + assert_eq!(outer_arg2, vec![bool_t(), Type::UNIT, usize_t()].into()); // Of course this is still valid (as substitution is guaranteed to preserve validity) check_term_type(&outer_arg2, &outer_param).unwrap(); @@ -1015,9 +1123,9 @@ mod test { let row_var_use = Term::new_var_use(0, row_var_decl.clone()); let good_arg = Term::new_list([ // The row variables here refer to `row_var_decl` above - vec![usize_t().into()].into(), + vec![usize_t()].into(), row_var_use.clone(), - vec![row_var_use, usize_t().into()].into(), + Term::concat_lists([row_var_use, Term::new_list([usize_t().into()])]), ]); check_term_type(&good_arg, &outer_param).unwrap(); @@ -1025,7 +1133,8 @@ mod test { let Term::List(mut elems) = good_arg.clone() else { panic!() }; - elems.push(usize_t().into()); + let t: Term = usize_t().into(); + elems.push(t); assert_eq!( check_term_type(&Term::new_list(elems), &outer_param), Err(TermTypeError::TypeMismatch { @@ -1036,20 +1145,33 @@ mod test { ); // Now substitute a list of two types for that row-variable - let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); + let row_var_arg = vec![usize_t(), bool_t()].into(); check_term_type(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = good_arg.substitute(&Substitution(std::slice::from_ref(&row_var_arg))); check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, Term::new_list([ - Term::new_list([usize_t().into()]), + [usize_t()].into(), row_var_arg, - Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()]) + [usize_t(), bool_t(), usize_t()].into() ]) ); } + #[test] + fn test_try_into_list_elements() { + // Test successful conversion with List + let types = vec![Type::new_unit_sum(1), bool_t()]; + let term = TypeArg::new_list(types.iter().cloned().map_into()); + let result = TypeRow::try_from(term); + assert_eq!(result, Ok(TypeRow::from(types))); + + // Test failure with non-list + let result = TypeRow::try_from(Term::from(Type::UNIT)); + assert!(result.is_err()); + } + #[test] fn bytes_json_roundtrip() { let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into()); @@ -1058,13 +1180,41 @@ mod test { assert_eq!(deserialized, bytes_arg); } - mod proptest { + #[test] + fn list_from_single_part_item() { + // arbitrary, not but worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!( + Term::List(vec![term.clone()]), + Term::new_list_from_parts(std::iter::once(SeqPart::Item(term))) + ); + } + + #[test] + fn list_from_single_part_splice() { + // arbitrary, not but worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!( + term.clone(), + Term::new_list_from_parts(std::iter::once(SeqPart::Splice(term))) + ); + } + + #[test] + fn list_concat_single_item() { + // arbitrary, not but worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!(term.clone(), Term::concat_lists([term])); + } + mod proptest { + use prop::{collection::vec, strategy::Union}; use proptest::prelude::*; use super::super::{TermVar, UpperBound}; use crate::proptest::RecursionDepth; - use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param}; + use crate::types::proptest_utils::any_serde_type_param; + use crate::types::{Term, Type, TypeBound}; impl Arbitrary for TermVar { type Parameters = RecursionDepth; @@ -1083,9 +1233,7 @@ mod test { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - use prop::collection::vec; - use prop::strategy::Union; - let mut strat = Union::new([ + let strat = Union::new([ Just(Self::StringType).boxed(), Just(Self::BytesType).boxed(), Just(Self::FloatType).boxed(), @@ -1100,32 +1248,29 @@ mod test { any::() .prop_map(|value| Self::Float(value.into())) .boxed(), - any_with::(depth).prop_map(Self::from).boxed(), + any_with::(depth).prop_map_into().boxed(), ]); - if !depth.leaf() { - // we descend here because we these constructors contain Terms - strat = strat - .or( - // TODO this is a bit dodgy, TypeArgVariables are supposed - // to be constructed from TypeArg::new_var_use. We are only - // using this instance for serialization now, but if we want - // to generate valid TypeArgs this will need to change. - any_with::(depth.descend()) - .prop_map(Self::Variable) - .boxed(), - ) - .or(any_with::(depth.descend()) - .prop_map(Self::new_list_type) - .boxed()) - .or(any_with::(depth.descend()) - .prop_map(Self::new_tuple_type) - .boxed()) - .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(Self::new_list) - .boxed()); + if depth.leaf() { + return strat.boxed(); } - - strat.boxed() + // we descend here because we these constructors contain Terms + let depth = depth.descend(); + strat + .or( + // TODO this means we have two ways to create variables of type + // `RuntimeType`, so we probably get more of them than we should` + any_with::(depth).prop_map(Self::Variable).boxed(), + ) + .or(any_with::(depth) + .prop_map(Self::new_list_type) + .boxed()) + .or(any_with::(depth) + .prop_map(Self::new_tuple_type) + .boxed()) + .or(vec(any_with::(depth), 0..3) + .prop_map(Self::new_list) + .boxed()) + .boxed() } } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index db9314ff66..8c80a0c7a7 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -8,42 +8,34 @@ use std::{ }; use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Term, Transformable, Type, TypeArg, TypeBase, TypeRV, - TypeTransformer, type_param::TypeParam, + Substitutable, Substitution, Term, Transformable, Type, TypeTransformer, type_param::TypeParam, +}; +use crate::{ + extension::SignatureError, + types::{ + TypeBound, + type_param::{TermTypeError, check_term_type}, + }, + utils::display_list, }; -use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; +use derive_more::Display; use itertools::Itertools; /// List of types, used for function signatures. -/// The `ROWVARS` parameter controls whether this may contain [`RowVariable`]s -#[derive(Clone, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] +/// +/// Also allows sharing via `Cow` and static allocation via [type_row!]. +/// +/// [type_row!]: crate::type_row +#[derive(Clone, PartialEq, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[serde(transparent)] -pub struct TypeRowBase { +pub struct TypeRow { /// The datatypes in the row. - types: Cow<'static, [TypeBase]>, -} - -/// Row of single types i.e. of known length, for node inputs/outputs -pub type TypeRow = TypeRowBase; - -/// Row of types and/or row variables, the number of actual types is thus -/// unknown -pub type TypeRowRV = TypeRowBase; - -impl PartialEq> for TypeRowBase { - fn eq(&self, other: &TypeRowBase) -> bool { - self.types.len() == other.types.len() - && self - .types - .iter() - .zip(other.types.iter()) - .all(|(s, o)| s == o) - } + types: Cow<'static, [Type]>, } -impl Display for TypeRowBase { +impl Display for TypeRow { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_char('[')?; display_list(self.types.as_ref(), f)?; @@ -51,7 +43,7 @@ impl Display for TypeRowBase { } } -impl TypeRowBase { +impl TypeRow { /// Create a new empty row. #[must_use] pub const fn new() -> Self { @@ -61,48 +53,47 @@ impl TypeRowBase { } /// Returns a new `TypeRow` with `xs` concatenated onto `self`. - pub fn extend<'a>(&'a self, rest: impl IntoIterator>) -> Self { + pub fn extend<'a>(&'a self, rest: impl IntoIterator) -> Self { self.iter().chain(rest).cloned().collect_vec().into() } /// Returns a reference to the types in the row. #[must_use] - pub fn as_slice(&self) -> &[TypeBase] { + pub fn as_slice(&self) -> &[Type] { &self.types } - /// Applies a substitution to the row. - /// For `TypeRowRV`, note this may change the length of the row. - /// For `TypeRow`, guaranteed not to change the length of the row. - pub(crate) fn substitute(&self, s: &Substitution) -> Self { - self.iter() - .flat_map(|ty| ty.substitute(s)) - .collect::>() - .into() - } - delegate! { to self.types { /// Iterator over the types in the row. - pub fn iter(&self) -> impl Iterator>; + pub fn iter(&self) -> impl Iterator; /// Mutable vector of the types in the row. - pub fn to_mut(&mut self) -> &mut Vec>; + pub fn to_mut(&mut self) -> &mut Vec; /// Allow access (consumption) of the contained elements - #[must_use] pub fn into_owned(self) -> Vec>; + #[must_use] pub fn into_owned(self) -> Vec; /// Returns `true` if the row contains no types. - #[must_use] pub fn is_empty(&self) -> bool ; + #[must_use] pub fn is_empty(&self) -> bool; } } +} - pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { +impl Substitutable for TypeRow { + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { self.iter().try_for_each(|t| t.validate(var_decls)) } + + fn substitute(&self, s: &Substitution) -> Self { + self.iter() + .map(|ty| ty.substitute(s)) + .collect::>() + .into() + } } -impl Transformable for TypeRowBase { +impl Transformable for TypeRow { fn transform(&mut self, tr: &T) -> Result { self.to_mut().transform(tr) } @@ -127,259 +118,289 @@ impl TypeRow { } } -impl TryFrom for TypeRow { - type Error = SignatureError; - - fn try_from(value: TypeRowRV) -> Result { - Ok(Self::from( - value - .into_owned() - .into_iter() - .map(std::convert::TryInto::try_into) - .collect::, _>>() - .map_err(|var| SignatureError::RowVarWhereTypeExpected { var })?, - )) - } -} - -impl Default for TypeRowBase { +impl Default for TypeRow { fn default() -> Self { Self::new() } } -impl From>> for TypeRowBase { - fn from(types: Vec>) -> Self { +impl From> for TypeRow { + fn from(types: Vec) -> Self { Self { types: types.into(), } } } -impl From> for TypeRowRV { - fn from(types: Vec) -> Self { - Self { - types: types.into_iter().map(Type::into_).collect(), - } - } -} +impl TryFrom> for TypeRow { + type Error = TermTypeError; -impl From for TypeRowRV { - fn from(value: TypeRow) -> Self { - Self { - types: value.into_owned().into_iter().map(Type::into_).collect(), - } + fn try_from(value: Vec) -> Result { + value + .into_iter() + .map(Type::try_from) + .collect::, _>>() + .map(Self::from) } } -impl From<[TypeBase; N]> for TypeRowBase { - fn from(types: [TypeBase; N]) -> Self { - Self::from(Vec::from(types)) +impl TryFrom for TypeRow { + type Error = TermTypeError; + + fn try_from(value: TypeRowRV) -> Result { + value.0.try_into() } } -impl From<[Type; N]> for TypeRowRV { +impl From<[Type; N]> for TypeRow { fn from(types: [Type; N]) -> Self { Self::from(Vec::from(types)) } } -impl From<&'static [TypeBase]> for TypeRowBase { - fn from(types: &'static [TypeBase]) -> Self { +impl From<&'static [Type]> for TypeRow { + fn from(types: &'static [Type]) -> Self { Self { types: types.into(), } } } -// Fallibly convert a [Term] to a [TypeRV]. -// -// This will fail if `arg` is of non-type kind (e.g. String). -impl TryFrom for TypeRV { - type Error = SignatureError; - - fn try_from(value: Term) -> Result { - match value { - TypeArg::Runtime(ty) => Ok(ty.into()), - TypeArg::Variable(v) => Ok(TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )), - _ => Err(SignatureError::InvalidTypeArgs), +impl PartialEq for TypeRow { + fn eq(&self, other: &Term) -> bool { + let Term::List(items) = other else { + return false; + }; + if self.types.len() != items.len() { + return false; } + self.types.iter().zip_eq(items).all(|(ty, tm)| &**ty == tm) } } -// Fallibly convert a [Term] to a [TypeRow]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type) -// or if the sequence contains row variables. -impl TryFrom for TypeRow { - type Error = SignatureError; - - fn try_from(value: TypeArg) -> Result { - match value { - TypeArg::List(elems) => elems - .into_iter() - .map(|ta| ta.as_runtime().ok_or(SignatureError::InvalidTypeArgs)) - .collect::, _>>() - .map(TypeRow::from), - _ => Err(SignatureError::InvalidTypeArgs), - } +impl PartialEq for TypeRow { + fn eq(&self, other: &TypeRowRV) -> bool { + self == &other.0 } } -// Fallibly convert a [TypeArg] to a [TypeRowRV]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type). -impl TryFrom for TypeRowRV { - type Error = SignatureError; +/// Fallibly convert a [Term] to a [TypeRow]. +/// +/// This will fail if `arg` is not a [Term::List] or any of the elements are not [Type]s +impl TryFrom for TypeRow { + type Error = TermTypeError; fn try_from(value: Term) -> Result { match value { - TypeArg::List(elems) => elems + Term::List(elems) => Ok(elems .into_iter() - .map(TypeRV::try_from) - .collect::, _>>() - .map(|vec| vec.into()), - TypeArg::Variable(v) => Ok(vec![TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )] - .into()), - _ => Err(SignatureError::InvalidTypeArgs), + .map(Type::try_from) + .collect::, _>>()? + .into()), + v => Err(TermTypeError::InvalidValue(Box::new(v))), } } } impl From for Term { fn from(value: TypeRow) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) + Term::new_list(value.into_owned().into_iter().map_into()) } } -impl From for Term { - fn from(value: TypeRowRV) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) +impl From for TypeRowRV { + fn from(value: TypeRow) -> Self { + Self(Term::from(value)) } } -impl Deref for TypeRowBase { - type Target = [TypeBase]; +impl Deref for TypeRow { + type Target = [Type]; fn deref(&self) -> &Self::Target { self.as_slice() } } -impl DerefMut for TypeRowBase { +impl DerefMut for TypeRow { fn deref_mut(&mut self) -> &mut Self::Target { self.types.to_mut() } } +/// Row of types and/or row variables, the number of actual types is thus +/// unknown. +/// +/// A [Term] that `check_term_type`s against [Term::ListType] of [Term::RuntimeType] +/// (of a [TypeBound]), i.e. one of +/// * A [Term::Variable] of type [Term::ListType] (of [Term::RuntimeType]...) +/// * A [Term::List], each of whose elements is of type some [Term::RuntimeType] +/// * A [Term::ListConcat], each of whose sublists is one of these three +/// +/// [TypeBound]: crate::types::TypeBound +#[derive(Clone, Debug, Display, PartialEq, Eq, Hash)] +#[display("{_0}")] +pub struct TypeRowRV(pub(super) Term); + +impl TypeRowRV { + const EMPTY: TypeRowRV = Self(Term::List(vec![])); + pub(super) const EMPTY_REF: &TypeRowRV = &Self::EMPTY; + + /// Create a new empty row. + pub const fn new() -> Self { + Self::EMPTY + } + + /// Wraps the given Term, without checking its type. + pub fn new_unchecked(t: impl Into) -> Self { + Self(t.into()) + } + + /// Creates a singleton row with just a row variable + /// (a variable ranging over lists of types of any length) + pub fn just_row_var(idx: usize, b: TypeBound) -> Self { + Self(Term::new_row_var_use(idx, b)) + } + + /// Concatenates another TypeRowRV onto the end of this one + pub fn concat(self, other: impl Into) -> Self { + Self(Term::concat_lists([self.0, other.into().0])) + } +} + +impl Substitutable for TypeRowRV { + /// Checks that this is indeed a list of runtime types; + /// and that all variables are as declared in the supplied list of params. + fn validate(&self, vars: &[TypeParam]) -> Result<(), SignatureError> { + check_term_type(&self.0, &Term::new_list_type(TypeBound::Linear))?; + self.0.validate(vars) + } + + fn substitute(&self, s: &Substitution) -> Self { + // Substitution cannot make this invalid if it was valid previously + Self::new_unchecked(self.0.substitute(s)) + } +} + +impl Default for TypeRowRV { + /// Makes a new empty list + fn default() -> Self { + Self::EMPTY + } +} + +impl Transformable for TypeRowRV { + fn transform(&mut self, t: &T) -> Result { + self.0.transform(t) + } +} + +impl std::ops::Deref for TypeRowRV { + type Target = Term; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl TryFrom for TypeRowRV { + type Error = TermTypeError; + + fn try_from(t: Term) -> Result { + check_term_type(&t, &Term::new_list_type(TypeBound::Linear))?; + Ok(Self(t)) + } +} + +impl From for Term { + fn from(value: TypeRowRV) -> Self { + value.0 + } +} + +// This allows an easy syntax for building TypeRowRV's which are all Types +impl> From for TypeRowRV { + fn from(value: T) -> Self { + Self(Term::new_list(value.into_iter().map_into())) + } +} + +/*impl FromIterator for TypeRowRV { + fn from_iter>(iter: T) -> Self { + Self(Term::new_list(iter.into_iter().map_into())) + } +}*/ + #[cfg(test)] mod test { use super::*; - use crate::{ - extension::prelude::bool_t, - types::{Type, TypeArg, TypeRV}, - }; + use crate::{extension::prelude::bool_t, types::Type}; mod proptest { - use crate::proptest::RecursionDepth; - use crate::types::{MaybeRV, TypeBase, TypeRowBase}; + use super::{TypeRow, TypeRowRV}; + use crate::{proptest::RecursionDepth, types::Type}; use ::proptest::prelude::*; - impl Arbitrary for super::super::TypeRowBase { + impl Arbitrary for TypeRow { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { use proptest::collection::vec; if depth.leaf() { - Just(TypeRowBase::new()).boxed() + Just(TypeRow::new()).boxed() } else { - vec(any_with::>(depth), 0..4) + vec(any_with::(depth.descend()), 0..4) .prop_map(|ts| ts.clone().into()) .boxed() } } } - } - - #[test] - fn test_try_from_term_to_typerv() { - // Test successful conversion with Runtime type - let runtime_type = Type::UNIT; - let term = TypeArg::Runtime(runtime_type.clone()); - let result = TypeRV::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRV::from(runtime_type)); - // Test failure with non-type kind - let term = Term::String("test".to_string()); - let result = TypeRV::try_from(term); - assert!(result.is_err()); + impl Arbitrary for TypeRowRV { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + use proptest::collection::vec; + if depth.leaf() { + Just(TypeRowRV::default()).boxed() + } else { + // TODO ALAN include row variables here too! + vec(any_with::(depth.descend()), 0..4) + .prop_map(|ts| ts.clone().into()) + .boxed() + } + } + } } #[test] fn test_try_from_term_to_typerow() { // Test successful conversion with List let types = vec![Type::new_unit_sum(1), bool_t()]; - let type_args = types.iter().map(|t| TypeArg::Runtime(t.clone())).collect(); - let term = TypeArg::List(type_args); + let term = Term::new_list(types.iter().cloned().map_into()); let result = TypeRow::try_from(term); assert!(result.is_ok()); assert_eq!(result.unwrap(), TypeRow::from(types)); // Test failure with non-list - let term = TypeArg::Runtime(Type::UNIT); + let term = Term::from(Type::UNIT); let result = TypeRow::try_from(term); assert!(result.is_err()); } - #[test] - fn test_try_from_term_to_typerowrv() { - // Test successful conversion with List - let types = [TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_args = types.iter().map(|t| t.clone().into()).collect(); - let term = TypeArg::List(type_args); - let result = TypeRowRV::try_from(term); - assert!(result.is_ok()); - - // Test failure with non-sequence kind - let term = Term::String("test".to_string()); - let result = TypeRowRV::try_from(term); - assert!(result.is_err()); - } - #[test] fn test_from_typerow_to_term() { let types = vec![Type::UNIT, bool_t()]; let type_row = TypeRow::from(types); - let term = Term::from(type_row); + let term = Term::from(type_row.clone()); - match term { + match &term { Term::List(elems) => { assert_eq!(elems.len(), 2); } _ => panic!("Expected Term::List"), } - } - - #[test] - fn test_from_typerowrv_to_term() { - let types = vec![TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_row_rv = TypeRowRV::from(types); - let term = Term::from(type_row_rv); - match term { - TypeArg::List(elems) => { - assert_eq!(elems.len(), 2); - } - _ => panic!("Expected Term::List"), - } + assert_eq!(term.try_into(), Ok(type_row)); } } diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 435ab6ff92..0dff19dc4b 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -7,7 +7,7 @@ use hugr_core::ops::{ }; use hugr_core::{ HugrView, NodeIndex, - types::{SumType, Type, TypeEnum}, + types::{SumType, Type}, }; use inkwell::types::BasicTypeEnum; use inkwell::values::BasicValueEnum; @@ -101,11 +101,11 @@ where } fn get_exactly_one_sum_type(ts: impl IntoIterator) -> Result { - let Some(TypeEnum::Sum(sum_type)) = ts + let Some(sum_type) = ts .into_iter() - .map(|t| t.as_type_enum().clone()) .exactly_one() .ok() + .and_then(|t| t.as_sum().cloned()) else { Err(anyhow!("Not exactly one SumType"))? }; diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 7ca58b6e74..3992845038 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -25,7 +25,7 @@ use hugr_core::ops::DataflowOpTrait; use hugr_core::std_extensions::collections::array::{ self, ArrayClone, ArrayDiscard, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::TypeArg; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::Builder; @@ -498,7 +498,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -559,7 +559,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -621,7 +621,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/borrow_array.rs b/hugr-llvm/src/extension/collections/borrow_array.rs index 0ce14e6ef4..c726281598 100644 --- a/hugr-llvm/src/extension/collections/borrow_array.rs +++ b/hugr-llvm/src/extension/collections/borrow_array.rs @@ -32,7 +32,7 @@ use hugr_core::std_extensions::collections::borrow_array::{ BArrayRepeat, BArrayScan, BArrayToArray, BArrayToArrayDef, BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::TypeArg; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::Builder; @@ -1083,7 +1083,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("BArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -1149,7 +1149,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("BArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -1216,7 +1216,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("BArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index 37e173ecac..decfbf40b6 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -4,7 +4,7 @@ use hugr_core::{ extension::simple_op::MakeExtensionOp as _, ops::ExtensionOp, std_extensions::collections::list::{self, ListOp, ListValue}, - types::{SumType, Type, TypeArg}, + types::{SumType, Type}, }; use inkwell::values::FunctionValue; use inkwell::{ @@ -202,7 +202,7 @@ fn emit_list_op<'c, H: HugrView>( op: ListOp, ) -> Result<()> { let hugr_elem_ty = match args.node().args() { - [TypeArg::Runtime(ty)] => ty.clone(), + [ty] => ty.clone().try_into().expect("List elements not a type"), _ => { bail!("Collections: invalid type args for list op"); } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap index 38d3385687..2f13d9a6cd 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap @@ -5,24 +5,24 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.6acc1b76.0, ptr @sa.inner.e637bb5.0, ptr @sa.inner.2b6593f.0, ptr @sa.inner.1b9ad7c.0, ptr @sa.inner.e67fbfa4.0, ptr @sa.inner.15dc27f6.0, ptr @sa.inner.c43a2bb2.0, ptr @sa.inner.7f5d5e16.0, ptr @sa.inner.a0bc9c53.0, ptr @sa.inner.1e8aada3.0] } +@sa.inner.ac73413c.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.3334f213.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.9447a20a.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.dfbce68f.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.5712c1c3.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.fc8747b9.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.aaa0b715.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.7aa729b2.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.66390c92.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.4d7c0c80.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.f1be8bcf.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.ac73413c.0, ptr @sa.inner.3334f213.0, ptr @sa.inner.9447a20a.0, ptr @sa.inner.dfbce68f.0, ptr @sa.inner.5712c1c3.0, ptr @sa.inner.fc8747b9.0, ptr @sa.inner.aaa0b715.0, ptr @sa.inner.7aa729b2.0, ptr @sa.inner.66390c92.0, ptr @sa.inner.4d7c0c80.0] } define internal i64 @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = getelementptr inbounds { i64, [0 x ptr] }, ptr @sa.outer.e55b610a.0, i32 0, i32 0 + %0 = getelementptr inbounds { i64, [0 x ptr] }, ptr @sa.outer.f1be8bcf.0, i32 0, i32 0 %1 = load i64, ptr %0, align 4 ret i64 %1 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap index 950cd1316d..0373d2a471 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap @@ -5,17 +5,17 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.inner.ac73413c.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.3334f213.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.9447a20a.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.dfbce68f.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.5712c1c3.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.fc8747b9.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.aaa0b715.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.7aa729b2.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.66390c92.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.4d7c0c80.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.f1be8bcf.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.ac73413c.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.3334f213.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.9447a20a.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.dfbce68f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.5712c1c3.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.fc8747b9.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.aaa0b715.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7aa729b2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.66390c92.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.4d7c0c80.0 to { i64, [0 x i64] }*)] } define internal i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 + store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.f1be8bcf.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i64, [0 x i64] }*] }*, { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* %"5_01", i32 0, i32 0 %1 = load i64, i64* %0, align 4 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap index 65ffc76d07..ecafb18565 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap @@ -5,17 +5,17 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.6acc1b76.0, ptr @sa.inner.e637bb5.0, ptr @sa.inner.2b6593f.0, ptr @sa.inner.1b9ad7c.0, ptr @sa.inner.e67fbfa4.0, ptr @sa.inner.15dc27f6.0, ptr @sa.inner.c43a2bb2.0, ptr @sa.inner.7f5d5e16.0, ptr @sa.inner.a0bc9c53.0, ptr @sa.inner.1e8aada3.0] } +@sa.inner.ac73413c.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.3334f213.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.9447a20a.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.dfbce68f.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.5712c1c3.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.fc8747b9.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.aaa0b715.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.7aa729b2.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.66390c92.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.4d7c0c80.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.f1be8bcf.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.ac73413c.0, ptr @sa.inner.3334f213.0, ptr @sa.inner.9447a20a.0, ptr @sa.inner.dfbce68f.0, ptr @sa.inner.5712c1c3.0, ptr @sa.inner.fc8747b9.0, ptr @sa.inner.aaa0b715.0, ptr @sa.inner.7aa729b2.0, ptr @sa.inner.66390c92.0, ptr @sa.inner.4d7c0c80.0] } define internal i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.outer.e55b610a.0, ptr %"5_0", align 8 + store ptr @sa.outer.f1be8bcf.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x ptr] }, ptr %"5_01", i32 0, i32 0 %1 = load i64, ptr %0, align 4 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap index 59bbf8007b..c775c0323e 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.64db7f63.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define internal ptr @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret ptr @sa.a.97cb22bf.0 + ret ptr @sa.a.64db7f63.0 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap index 0401cf0b50..dab0859bcc 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.d797f156.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define internal ptr @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret ptr @sa.c.d2dddd66.0 + ret ptr @sa.c.d797f156.0 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap index e682c3b6ce..81eea756ce 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } +@sa.d.e9aebfdb.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } define internal ptr @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret ptr @sa.d.eee08a59.0 + ret ptr @sa.d.e9aebfdb.0 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap index 480fd4e9c8..682e879acc 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.64db7f63.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define internal { i64, [0 x i64] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.97cb22bf.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }** %"5_0", align 8 + store { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.64db7f63.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }** %"5_0", align 8 %"5_01" = load { i64, [0 x i64] }*, { i64, [0 x i64] }** %"5_0", align 8 store { i64, [0 x i64] }* %"5_01", { i64, [0 x i64] }** %"0", align 8 %"02" = load { i64, [0 x i64] }*, { i64, [0 x i64] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap index 8f5dd5efb6..8594f22b58 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.d797f156.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define internal { i64, [0 x i1] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d2dddd66.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8 + store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d797f156.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8 %"5_01" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"5_0", align 8 store { i64, [0 x i1] }* %"5_01", { i64, [0 x i1] }** %"0", align 8 %"02" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap index 6f341ffa4d..c246496191 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.64db7f63.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define internal ptr @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.a.97cb22bf.0, ptr %"5_0", align 8 + store ptr @sa.a.64db7f63.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 store ptr %"5_01", ptr %"0", align 8 %"02" = load ptr, ptr %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap index 3fb020a531..82f98439a6 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.d797f156.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define internal ptr @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.c.d2dddd66.0, ptr %"5_0", align 8 + store ptr @sa.c.d797f156.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 store ptr %"5_01", ptr %"0", align 8 %"02" = load ptr, ptr %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap index 869b5a847f..3bea34aa6c 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } +@sa.d.e9aebfdb.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } define internal ptr @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.d.eee08a59.0, ptr %"5_0", align 8 + store ptr @sa.d.e9aebfdb.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 store ptr %"5_01", ptr %"0", align 8 %"02" = load ptr, ptr %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index 285a1ba3ec..fc1459f7e0 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -14,7 +14,7 @@ use hugr_core::ops::DataflowOpTrait; use hugr_core::std_extensions::collections::array::{ self, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::TypeArg; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::{Builder, BuilderError}; @@ -135,10 +135,11 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), ty] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; - let elem_ty = ts.llvm_type(ty)?; + let ty = ty.clone().try_into().expect("Array elements not a type"); + let elem_ty = ts.llvm_type(&ty)?; Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum()) } }) @@ -357,7 +358,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -420,7 +421,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -494,7 +495,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 882fbad9f1..de7936d37b 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -339,9 +339,10 @@ impl CodegenExtension for StaticArrayCodegenE { move |ts, custom_type| { // check the arg type, even though the return is always ptr - let _ = custom_type.args()[0] - .as_runtime() - .expect("Type argument for static array must be a type"); + assert!( + custom_type.args()[0].is_runtime_type(), + "Type argument for static array must be a type" + ); Ok(ts.llvm_ptr_type().into()) } }, diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index c620d0926c..c2cf0011b8 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -8,7 +8,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, constant::Value, custom::ExtensionOp}, std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, - types::{TypeEnum, TypeRow}, + types::TypeRow, }; use inkwell::{FloatPredicate, IntPredicate, types::IntType, values::BasicValue}; @@ -189,12 +189,12 @@ fn emit_conversion_op<'c, H: HugrView>( .typing_session() .llvm_type(&INT_TYPES[0])? .into_int_type(); - let sum_ty = context - .typing_session() - .llvm_sum_type(match bool_t().as_type_enum() { - TypeEnum::Sum(st) => st.clone(), - _ => panic!("Hugr prelude bool_t() not a Sum"), - })?; + let sum_ty = context.typing_session().llvm_sum_type( + bool_t() + .as_sum() + .expect("Hugr prelude bool_t() not a Sum") + .clone(), + )?; emit_custom_unary_op(context, args, |ctx, arg, _| { let res = if conversion_op == ConvertOpDef::itobool { diff --git a/hugr-llvm/src/utils/type_map.rs b/hugr-llvm/src/utils/type_map.rs index c8129c2578..2cf006f95b 100644 --- a/hugr-llvm/src/utils/type_map.rs +++ b/hugr-llvm/src/utils/type_map.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use hugr_core::{ extension::ExtensionId, - types::{CustomType, TypeEnum, TypeName, TypeRow}, + types::{CustomType, Term, TypeName, TypeRow}, }; use anyhow::{Result, bail}; @@ -115,18 +115,18 @@ impl<'a, TM: TypeMapping + 'a> TypeMap<'a, TM> { /// Map `hugr_type` using the [`TypeMapping`] `TM`, the registered callbacks, /// and the auxiliary data `inv`. pub fn map_type<'c>(&self, hugr_type: &HugrType, inv: TM::InV<'c>) -> Result> { - match hugr_type.as_type_enum() { - TypeEnum::Extension(custom_type) => { + match &**hugr_type { + Term::RuntimeExtension(custom_type) => { let key = (custom_type.extension().clone(), custom_type.name().clone()); let Some(handler) = self.custom_hooks.get(&key) else { return self.type_map.default_out(inv, &custom_type.clone().into()); }; handler.map_type(inv, custom_type) } - TypeEnum::Sum(sum_type) => self + Term::RuntimeSum(sum_type) => self .map_sum_type(sum_type, inv) .map(|x| self.type_map.sum_into_out(x)), - TypeEnum::Function(function_type) => self + Term::RuntimeFunction(function_type) => self .map_function_type(&function_type.as_ref().clone().try_into()?, inv) .map(|x| self.type_map.func_into_out(x)), _ => self.type_map.default_out(inv, hugr_type), diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 694e426253..edbfeb872c 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -27,7 +27,7 @@ use hugr_core::std_extensions::arithmetic::{ int_types::{ConstInt, INT_TYPES}, }; use hugr_core::std_extensions::logic::LogicOp; -use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; +use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row}; use crate::{ComposablePass as _, composable::ValidatingPass}; @@ -809,7 +809,7 @@ fn test_fold_idivmod_checked_u() { // x0, x1 := int_u<5>(20), int_u<5>(0) // x2 := idivmod_checked_u(x0, x1) // output x2 == error - let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); + let intpair: TypeRow = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); let elem_type = Type::new_tuple(intpair); let sum_type = sum_with_error([elem_type.clone()]); let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); @@ -857,7 +857,7 @@ fn test_fold_idivmod_checked_s() { // x0, x1 := int_s<5>(-20), int_u<5>(0) // x2 := idivmod_checked_s(x0, x1) // output x2 == error - let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); + let intpair: TypeRow = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); let elem_type = Type::new_tuple(intpair); let sum_type = sum_with_error([elem_type.clone()]); let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 78547943fe..9b79fb9e98 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::Lattice; use ascent::lattice::BoundedLattice; use hugr_core::Node; -use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeRow}; use itertools::{Itertools, zip_eq}; use std::cmp::Ordering; use std::collections::HashMap; @@ -199,9 +199,11 @@ impl PartialSum { /// /// # Errors /// - /// If this `PartialSum` had multiple possible tags; or if `typ` was not a [`TypeEnum::Sum`] + /// If this `PartialSum` had multiple possible tags; or if `typ` was not a [`RuntimeSum`] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [`PartialValue::try_into_concrete`]. + /// + /// [`RuntimeSum`]: hugr_core::types::Term::RuntimeSum #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases pub fn try_into_sum>( self, @@ -211,7 +213,7 @@ impl PartialSum { return Err(ExtractValueError::MultipleVariants(self)); } let (tag, v) = self.0.into_iter().exactly_one().unwrap(); - if let TypeEnum::Sum(st) = typ.as_type_enum() + if let Some(st) = typ.as_sum() && let Some(r) = st.get_variant(tag) && let Ok(r) = TypeRow::try_from(r.clone()) && v.len() == r.len() diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 402292a527..629a98955a 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -251,7 +251,16 @@ fn escape_dollar(str: impl AsRef) -> String { fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match arg { - TypeArg::Runtime(ty) => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), + TypeArg::RuntimeExtension(cty) => { + // ALAN make this "e({})" to better distinguish? Or make function "t({})" also? + f.write_fmt(format_args!("t({})", escape_dollar(cty.to_string()))) + } + TypeArg::RuntimeSum(sty) => { + f.write_fmt(format_args!("t({})", escape_dollar(sty.to_string()))) + } + TypeArg::RuntimeFunction(fty) => { + f.write_fmt(format_args!("f({})", escape_dollar(fty.to_string()))) + } TypeArg::BoundedNat(n) => f.write_fmt(format_args!("n({n})")), TypeArg::String(arg) => f.write_fmt(format_args!("s({})", escape_dollar(arg))), TypeArg::List(elems) => f.write_fmt(format_args!("list({})", TypeArgsSeq(elems))), @@ -302,7 +311,7 @@ mod test { use hugr_core::extension::prelude::{ConstUsize, UnpackTuple, UnwrapBuilder, usize_t}; use hugr_core::ops::handle::FuncID; use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; - use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeEnum}; + use hugr_core::types::{PolyFuncType, Signature, Term, Type, TypeArg, TypeBound}; use hugr_core::{Hugr, HugrView, Node, Visibility}; use rstest::rstest; @@ -548,9 +557,7 @@ mod test { let popleft = BArrayOpDef::pop_left.to_concrete(arr2u(), n); let ar2 = outer.add_dataflow_op(popleft.clone(), [arr2]).unwrap(); let sig = popleft.to_extension_op().unwrap().signature().into_owned(); - let TypeEnum::Sum(st) = sig.output().get(0).unwrap().as_type_enum() else { - panic!() - }; + let st = sig.output().get(0).unwrap().as_sum().unwrap(); let [left_arr, ar2_unwrapped] = outer .build_unwrap_sum(1, st.clone(), ar2.out_wire(0)) .unwrap(); @@ -663,7 +670,7 @@ mod test { #[case::type_int(vec![INT_TYPES[2].clone().into()], "$foo$$t(int(2))")] #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] - #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$list($n(0)$t(Unit))")] + #[case::sequence(vec![vec![0.into(), Term::from(Type::UNIT)].into()], "$foo$$list($n(0)$t(Unit))")] #[case::sequence(vec![TypeArg::Tuple(vec![0.into(),Type::UNIT.into()])], "$foo$$tuple($n(0)$t(Unit))")] #[should_panic] #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::StringType)], diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index dbe18c47ca..e0c490799e 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -24,8 +24,7 @@ use hugr_core::ops::{ ExtensionOp, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, - TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeRow, TypeTransformer, }; use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Visibility, Wire}; @@ -867,16 +866,16 @@ impl ReplaceTypes { Ok(any_change) } Value::Extension { e } => Ok({ - let new_const = match e.get_type().as_type_enum() { - TypeEnum::Extension(exty) => match self.consts.get(exty) { - Some(const_fn) => Some(const_fn(e, self)), - None => self - .param_consts - .get(&exty.into()) - .and_then(|const_fn| const_fn(e, self).transpose()), - }, - _ => None, - }; + let new_const = + e.get_type() + .as_extension() + .and_then(|exty| match self.consts.get(exty) { + Some(const_fn) => Some(const_fn(e, self)), + None => self + .param_consts + .get(&exty.into()) + .and_then(|const_fn| const_fn(e, self).transpose()), + }); if let Some(new_const) = new_const { *value = new_const?; true @@ -1037,11 +1036,11 @@ mod test { ExtensionOp::new(ext.get_op(READ).unwrap().clone(), [t.into()]).unwrap() } - fn just_elem_type(args: &[TypeArg]) -> &Type { - let [TypeArg::Runtime(ty)] = args else { + fn just_elem_type(args: &[TypeArg]) -> Type { + let [ty] = args else { panic!("Expected just elem type") }; - ty + Type::try_from(ty.clone()).unwrap() } fn ext() -> Arc { @@ -1109,7 +1108,7 @@ mod test { .unwrap() .outputs_arr(); let [res] = dfb - .build_unwrap_sum(1, option_type([Type::from(elem_ty)]), opt) + .build_unwrap_sum(1, option_type([elem_ty]), opt) .unwrap(); dfb.set_outputs([res]).unwrap(); dfb @@ -1290,7 +1289,7 @@ mod test { // 1. Lower List to BArray<10, T> UNLESS T is usize_t() or i64_t lowerer.set_replace_parametrized_type(list_type_def(), |args| { let ty = just_elem_type(args); - (![usize_t(), i64_t()].contains(ty)).then_some(borrow_array_type(10, ty.clone())) + (![usize_t(), i64_t()].contains(&ty)).then_some(borrow_array_type(10, ty.clone())) }); { let mut h = backup.clone(); @@ -1356,13 +1355,7 @@ mod test { h.get_optype(pred.node()) .as_load_constant() .map(hugr_core::ops::LoadConstant::constant_type), - Some(&Type::new_sum(vec![ - [Type::from(borrow_array_type( - 4, - i64_t() - ))]; - 2 - ])) + Some(&Type::new_sum(vec![[borrow_array_type(4, i64_t())]; 2])) ); } @@ -1384,10 +1377,10 @@ mod test { .unwrap(); }, ); - fn option_contents(ty: &Type) -> Option { + fn option_contents(ty: Type) -> Option { let row = ty.as_sum()?.get_variant(1).unwrap().clone(); - let elem = row.into_owned().into_iter().exactly_one().unwrap(); - Some(elem.try_into_type().unwrap()) + let elems = TypeRow::try_from(row).unwrap(); + Some(elems.into_owned().into_iter().exactly_one().unwrap()) } let i32_t = || INT_TYPES[5].clone(); let opt_i32 = Type::from(option_type([i32_t()])); @@ -1584,11 +1577,12 @@ mod test { let mut lw = lowerer(&e); lw.set_replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| { Ok(Some({ - let [Term::Runtime(ty)] = args else { + let [tm] = args else { return Err(SignatureError::InvalidTypeArgs.into()); }; + let ty = Type::try_from(tm.clone()).map_err(|_| SignatureError::InvalidTypeArgs)?; - let defn_hugr = lowered_read(ty.clone(), |sig| { + let defn_hugr = lowered_read(ty, |sig| { FunctionBuilder::new_vis( mangle_name("lowered_read", args), sig, @@ -1704,9 +1698,10 @@ mod test { .unwrap() .as_ref(), move |args, _| { - let [sz, Term::Runtime(ty)] = args else { + let [sz, ty] = args else { panic!("Expected two args to array-get") }; + let ty = Type::try_from(ty.clone()).unwrap(); if sz != &Term::BoundedNat(64) { return Ok(None); } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 21cd4541ef..447993a325 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -110,9 +110,10 @@ pub fn linearize_generic_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat(n), ty] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; + let ty = Type::try_from(ty.clone()).unwrap(); if num_outports == 0 { // "Simple" discard let array_scan = GenericArrayScan::::new(ty.clone(), Type::UNIT, vec![], *n); @@ -131,7 +132,7 @@ pub fn linearize_generic_array( ) .unwrap(); let [to_discard] = fb.input_wires_arr(); - let disc = lin.copy_discard_op(ty, 0)?; + let disc = lin.copy_discard_op(&ty, 0)?; disc.add(&mut fb, [to_discard]).map_err(|e| { LinearizeError::NestedTemplateError(Box::new(ty.clone()), Box::new(e)) })?; @@ -217,7 +218,7 @@ pub fn linearize_generic_array( .unwrap() .outputs_arr(); let mut copies = lin - .copy_discard_op(ty, num_outports)? + .copy_discard_op(&ty, num_outports)? .add(&mut fb, [elem]) .map_err(|e| LinearizeError::NestedTemplateError(Box::new(ty.clone()), Box::new(e)))? .outputs(); @@ -332,9 +333,10 @@ pub fn copy_discard_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat(n), ty] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; + let ty = Type::try_from(ty.clone()).expect("Illegal array element type"); if ty.copyable() { // For arrays with copyable elements, we can just use the clone/discard ops if num_outports == 0 { @@ -379,9 +381,10 @@ pub fn copy_discard_borrow_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat(n), ty] = args else { panic!("Illegal TypeArgs to borrow array: {args:?}") }; + let ty = Type::try_from(ty.clone()).expect("Illegal BorrowArray element type"); if ty.copyable() { // For arrays with copyable elements, we can just use the clone/discard ops if num_outports == 0 { @@ -411,7 +414,7 @@ pub fn copy_discard_borrow_array( } } else if num_outports == 0 { // Override "generic" array discard to only discard non-borrowed elements. - let elem_discard = lin.copy_discard_op(ty, 0)?; + let elem_discard = lin.copy_discard_op(&ty, 0)?; let array_ty = || borrow_array_type(*n, ty.clone()); let i64_t = || INT_TYPES[6].clone(); let mut dfb = DFGBuilder::new(inout_sig([array_ty()], [])).unwrap(); diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 4540342770..33d23e312e 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -7,7 +7,7 @@ use hugr_core::builder::{ use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::borrow_array::borrow_array_type_def; -use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::types::{CustomType, Signature, Term, Type, TypeArg, TypeRow}; use hugr_core::{HugrView, IncomingPort, Node, Wire, hugr::hugrmut::HugrMut, ops::Tag}; use itertools::Itertools; @@ -106,7 +106,7 @@ pub trait Linearizer { /// A configuration for implementing [Linearizer] by delegating to /// type-specific callbacks, and by composing them in order to handle compound types -/// such as [`TypeEnum::Sum`]s. +/// such as [`Term::RuntimeSum`]s. #[derive(Clone)] pub struct DelegatingLinearizer { // Keyed by lowered type, as only needed when there is an op outputting such @@ -165,8 +165,7 @@ pub enum LinearizeError { #[error(transparent)] SignatureError(#[from] SignatureError), /// We cannot linearize (insert copy and discard functions) for - /// [Variable](TypeEnum::Variable)s, [Row variables](TypeEnum::RowVar), - /// or [Alias](TypeEnum::Alias)es. + /// [Variable](Term::Variable)s. #[error("Cannot linearize type {_0}")] UnsupportedType(Box), /// Neither does linearization make sense for copyable types @@ -191,7 +190,7 @@ impl DelegatingLinearizer { /// Configures this instance that the specified monomorphic type can be copied and/or /// discarded via the provided [`NodeTemplate`]s - directly or as part of a compound type - /// e.g. [`TypeEnum::Sum`]. + /// e.g. [`Term::RuntimeSum`]. /// `copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports. /// @@ -272,12 +271,13 @@ impl Linearizer for DelegatingLinearizer { } assert!(num_outports != 1); - match typ.as_type_enum() { - TypeEnum::Sum(sum_type) => { + match &**typ { + Term::RuntimeSum(sum_type) => { let variants = sum_type .variants() .map(|trv| trv.clone().try_into()) - .collect::, _>>()?; + .collect::, _>>() + .map_err(SignatureError::from)?; let mut cb = ConditionalBuilder::new( variants.clone(), vec![], @@ -319,7 +319,7 @@ impl Linearizer for DelegatingLinearizer { cb.finish_hugr().unwrap(), ))) } - TypeEnum::Extension(cty) => { + Term::RuntimeExtension(cty) => { if let Some((copy, discard)) = self.copy_discard.get(cty) { Ok(if num_outports == 0 { discard.clone() @@ -352,7 +352,7 @@ impl Linearizer for DelegatingLinearizer { Ok(tmpl) } } - TypeEnum::Function(_) => panic!("Ruled out above as copyable"), + Term::RuntimeFunction(_) => panic!("Ruled out above as copyable"), _ => Err(LinearizeError::UnsupportedType(Box::new(typ.clone()))), } } @@ -844,10 +844,11 @@ mod test { ); let drop_op = drop_ext.get_op("drop").unwrap(); lowerer.set_replace_parametrized_op(drop_op, |args, rt| { - let [TypeArg::Runtime(ty)] = args else { + let [ty] = args else { panic!("Expected just one type") }; - Ok(Some(rt.get_linearizer().copy_discard_op(ty, 0)?)) + let ty = Type::try_from(ty.clone()).unwrap(); + Ok(Some(rt.get_linearizer().copy_discard_op(&ty, 0)?)) }); let build_hugr = |ty: Type| { diff --git a/hugr/benches/benchmarks/types.rs b/hugr/benches/benchmarks/types.rs index d05896f01b..1b89fd7d9f 100644 --- a/hugr/benches/benchmarks/types.rs +++ b/hugr/benches/benchmarks/types.rs @@ -1,8 +1,7 @@ // Required for black_box uses #![allow(clippy::unit_arg)] use hugr::extension::prelude::{qb_t, usize_t}; -use hugr::ops::AliasDecl; -use hugr::types::{Signature, Type, TypeBound}; +use hugr::types::{Signature, Type}; use criterion::{AxisScale, Criterion, PlotConfiguration, criterion_group}; use std::hint::black_box; @@ -13,8 +12,8 @@ fn make_complex_type() -> Type { let int = usize_t(); let q_register = Type::new_tuple(vec![qb; 8]); let b_register = Type::new_tuple(vec![int; 8]); - let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); - let sum = Type::new_sum([[q_register], [q_alias]]); + //let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); + let sum = Type::new_sum([[q_register], [b_register.clone()]]); Type::new_function(Signature::new(vec![sum], vec![b_register])) }