diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index f41e3db24..c29c215c1 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -469,15 +469,15 @@ pub(crate) mod test { BuilderWiringError, CFGBuilder, DataflowSubContainer, ModuleBuilder, TailLoopBuilder, endo_sig, inout_sig, }; - use crate::extension::SignatureError; + use crate::extension::prelude::{Noop, bool_t, qb_t, usize_t}; use crate::hugr::linking::{NameLinkingPolicy, NodeLinkingDirective, OnMultiDefn}; use crate::hugr::validate::InterGraphEdgeError; use crate::metadata::Metadata; use crate::ops::{FuncDecl, FuncDefn, OpParent, OpTag, OpTrait, Value, handle::NodeHandle}; use crate::std_extensions::logic::test::and_op; - use crate::types::type_param::TypeParam; - use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV}; + use crate::types::type_param::{TermTypeError, TypeParam}; + use crate::types::{EdgeKind, FuncValueType, Signature, Type, TypeBound, TypeRV}; use crate::utils::test_quantum_extension::h_gate; use crate::{Wire, builder::test::n_identity, type_row}; @@ -930,7 +930,7 @@ pub(crate) mod test { Signature::new( [Type::new_function(FuncValueType::new( [usize_t()], - [tv.clone()], + tv.clone(), ))], [], ), @@ -938,15 +938,15 @@ pub(crate) mod test { )?; // But cannot eval it... - let ev = e.instantiate_extension_op( - "eval", - [vec![usize_t().into()].into(), vec![tv.into()].into()], - ); + let ev = + e.instantiate_extension_op("eval", [vec![usize_t()].into(), vec![tv.clone()].into()]); assert_eq!( ev, - Err(SignatureError::RowVarWhereTypeExpected { - var: RowVariable(0, TypeBound::Copyable) - }) + Err(TermTypeError::TypeMismatch { + term: Box::new(tv), + type_: Box::new(TypeBound::Linear.into()) + } + .into()) ); Ok(()) } diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index e31cddd2e..91813ad39 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -9,12 +9,10 @@ use crate::hugr::{ ValidationError, hugrmut::InsertedForest, internal::HugrMutInternals, views::HugrView, }; use crate::ops; -use crate::ops::handle::{AliasID, FuncID, NodeHandle}; -use crate::types::{PolyFuncType, Type, TypeBound}; +use crate::ops::handle::{FuncID, NodeHandle}; +use crate::types::PolyFuncType; use crate::{Hugr, Node, Visibility, ops::FuncDefn}; -use smol_str::SmolStr; - /// Builder for a HUGR module. #[derive(Debug, Default, Clone, PartialEq)] pub struct ModuleBuilder(pub(super) T); @@ -179,49 +177,6 @@ impl + AsRef> ModuleBuilder { self.define_function_op(FuncDefn::new(name, signature)) } - /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias. - /// - /// # Errors - /// - /// Error in adding [`crate::ops::OpType::AliasDefn`] child node. - pub fn add_alias_def( - &mut self, - name: impl Into, - typ: Type, - ) -> Result, BuildError> { - // TODO: add AliasDefn in other containers - // This is currently tricky as they are not connected to anything so do - // not appear in topological traversals. - // Could be fixed by removing single-entry requirement and sorting from - // every 0-input node. - let name: SmolStr = name.into(); - let bound = typ.least_upper_bound(); - let node = self.add_child_node(ops::AliasDefn { - name: name.clone(), - definition: typ, - }); - - Ok(AliasID::new(node, name, bound)) - } - - /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias. - /// # Errors - /// - /// Error in adding [`crate::ops::OpType::AliasDecl`] child node. - pub fn add_alias_declare( - &mut self, - name: impl Into, - bound: TypeBound, - ) -> Result, BuildError> { - let name: SmolStr = name.into(); - let node = self.add_child_node(ops::AliasDecl { - name: name.clone(), - bound, - }); - - Ok(AliasID::new(node, name, bound)) - } - /// Add some module-children of another Hugr to this module, with /// linking directives specified explicitly by [Node]. /// @@ -260,7 +215,7 @@ mod test { use cool_asserts::assert_matches; use crate::builder::test::dfg_calling_defn_decl; - use crate::builder::{Dataflow, DataflowSubContainer, test::n_identity}; + use crate::builder::{Dataflow, DataflowSubContainer}; use crate::extension::prelude::usize_t; use crate::{hugr::linking::NodeLinkingDirective, ops::OpType, types::Signature}; @@ -285,28 +240,6 @@ mod test { Ok(()) } - #[test] - fn simple_alias() -> Result<(), BuildError> { - let build_result = { - let mut module_builder = ModuleBuilder::new(); - - let qubit_state_type = - module_builder.add_alias_declare("qubit_state", TypeBound::Linear)?; - - let f_build = module_builder.define_function( - "main", - Signature::new( - vec![qubit_state_type.get_alias_type()], - vec![qubit_state_type.get_alias_type()], - ), - )?; - n_identity(f_build)?; - module_builder.finish_hugr() - }; - assert_matches!(build_result, Ok(_)); - Ok(()) - } - #[test] fn builder_from_existing() -> Result<(), BuildError> { let hugr = Hugr::new(); diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 4469feb7a..13bceedbb 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -2,7 +2,6 @@ use crate::Visibility; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; -use crate::types::type_param::Term; use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port, extension::{ExtensionId, OpDef, SignatureFunc}, @@ -14,10 +13,8 @@ use crate::{ arithmetic::{float_types::ConstF64, int_types::ConstInt}, collections::array::ArrayValue, }, - types::{ - CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase, - }, + types::type_param::{Term, TermVar}, + types::{CustomType, EdgeKind, FuncValueType, Signature, SumType, TypeBound, TypeRow}, }; use hugr_model::v0::bumpalo; @@ -342,10 +339,12 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let symbol_name = this.export_func_name(node, &mut meta); - let symbol = this.export_poly_func_type( + let sig = func.signature(); + let symbol = this.export_symbol_params( symbol_name, Some(func.visibility().clone().into()), - func.signature(), + sig.params(), + |this| this.export_signature(sig.body()), ); regions = this.bump.alloc_slice_copy(&[this.export_dfg( node, @@ -358,11 +357,12 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let symbol_name = this.export_func_name(node, &mut meta); - - let symbol = this.export_poly_func_type( + let sig = func.signature(); + let symbol = this.export_symbol_params( symbol_name, Some(func.visibility().clone().into()), - func.signature(), + sig.params(), + |this| this.export_signature(sig.body()), ); table::Operation::DeclareFunc(symbol) }), @@ -381,7 +381,7 @@ impl<'a> Context<'a> { }), OpType::AliasDefn(alias) => self.with_local_scope(node_id, |this| { - let value = this.export_type(&alias.definition); + let value = this.export_term(&alias.definition, None); // TODO: We should support aliases with different types and with parameters let signature = this.make_term_apply(model::CORE_TYPE, &[]); let symbol = this.bump.alloc(table::Symbol { @@ -507,7 +507,7 @@ impl<'a> Context<'a> { Some(signature) => { let num_inputs = signature.input_types().len(); let num_outputs = signature.output_types().len(); - let signature = self.export_func_type(signature); + let signature = self.export_signature(signature); (Some(signature), num_inputs, num_outputs) } None => (None, 0, 0), @@ -559,7 +559,9 @@ impl<'a> Context<'a> { let symbol = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); - this.export_poly_func_type(name, None, poly_func_type) + this.export_symbol_params(name, None, poly_func_type.params(), |this| { + this.export_func_type(poly_func_type.body()) + }) }); let meta = { @@ -816,60 +818,48 @@ impl<'a> Context<'a> { } /// Exports a polymorphic function type. - pub fn export_poly_func_type( + pub fn export_symbol_params( &mut self, name: &'a str, visibility: Option, - t: &PolyFuncTypeBase, + params: &[Term], + export_body: impl FnOnce(&mut Self) -> table::TermId, ) -> &'a table::Symbol<'a> { - let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let mut param_vec = BumpVec::with_capacity_in(params.len(), self.bump); let scope = self .local_scope .expect("exporting poly func type outside of local scope"); let visibility = self.bump.alloc(visibility); - for (i, param) in t.params().iter().enumerate() { + for (i, param) in params.iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); let r#type = self.export_term(param, Some((scope, i as _))); let param = table::Param { name, r#type }; - params.push(param); + param_vec.push(param); } let constraints = self.bump.alloc_slice_copy(&self.local_constraints); - let body = self.export_func_type(t.body()); + let body = export_body(self); self.bump.alloc(table::Symbol { visibility, name, - params: params.into_bump_slice(), + params: param_vec.into_bump_slice(), constraints, signature: body, }) } - pub fn export_type(&mut self, t: &TypeBase) -> table::TermId { - self.export_type_enum(t.as_type_enum()) - } - - pub fn export_type_enum(&mut self, t: &TypeEnum) -> table::TermId { - match t { - TypeEnum::Extension(ext) => self.export_custom_type(ext), - TypeEnum::Alias(alias) => { - let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); - self.make_term(table::Term::Apply(symbol, &[])) - } - TypeEnum::Function(func) => self.export_func_type(func), - TypeEnum::Variable(index, _) => { - let node = self.local_scope.expect("local variable out of scope"); - self.make_term(table::Term::Var(table::VarId(node, *index as _))) - } - TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), - TypeEnum::Sum(sum) => self.export_sum_type(sum), - } - } - - pub fn export_func_type(&mut self, t: &FuncTypeBase) -> table::TermId { + pub fn export_signature(&mut self, t: &Signature) -> table::TermId { let inputs = self.export_type_row(t.input()); let outputs = self.export_type_row(t.output()); + // Ok to use CORE_FN here: the elements of the row will be exported inside a List + self.make_term_apply(model::CORE_FN, &[inputs, outputs]) + } + + pub fn export_func_type(&mut self, t: &FuncValueType) -> table::TermId { + let inputs = self.export_term(t.input(), None); + let outputs = self.export_term(t.output(), None); + // Ok to use CORE_FN here: the input/output should each be a core List or ListConcat self.make_term_apply(model::CORE_FN, &[inputs, outputs]) } @@ -888,28 +878,12 @@ impl<'a> Context<'a> { self.make_term(table::Term::Var(table::VarId(node, var.index() as _))) } - pub fn export_row_var(&mut self, t: &RowVariable) -> table::TermId { - let node = self.local_scope.expect("local variable out of scope"); - self.make_term(table::Term::Var(table::VarId(node, t.0 as _))) - } - pub fn export_sum_variants(&mut self, t: &SumType) -> table::TermId { - match t { - SumType::Unit { size } => { - let parts = self.bump.alloc_slice_fill_iter( - (0..*size) - .map(|_| table::SeqPart::Item(self.make_term(table::Term::List(&[])))), - ); - self.make_term(table::Term::List(parts)) - } - SumType::General { rows } => { - let parts = self.bump.alloc_slice_fill_iter( - rows.iter() - .map(|row| table::SeqPart::Item(self.export_type_row(row))), - ); - self.make_term(table::Term::List(parts)) - } - } + // Sadly we cannot use alloc_slice_fill_iter because SumType::variants is not an ExactSizeIterator. + let parts = self.bump.alloc_slice_fill_with(t.num_variants(), |i| { + table::SeqPart::Item(self.export_term(t.get_variant(i).unwrap(), None)) + }); + self.make_term(table::Term::List(parts)) } pub fn export_sum_type(&mut self, t: &SumType) -> table::TermId { @@ -918,27 +892,20 @@ impl<'a> Context<'a> { } #[inline] - pub fn export_type_row(&mut self, row: &TypeRowBase) -> table::TermId { + pub fn export_type_row(&mut self, row: &TypeRow) -> table::TermId { self.export_type_row_with_tail(row, None) } - pub fn export_type_row_with_tail( + pub fn export_type_row_with_tail( &mut self, - row: &TypeRowBase, + row: &TypeRow, tail: Option, ) -> table::TermId { let mut parts = BumpVec::with_capacity_in(row.len() + usize::from(tail.is_some()), self.bump); for t in row.iter() { - match t.as_type_enum() { - TypeEnum::RowVar(var) => { - parts.push(table::SeqPart::Splice(self.export_row_var(var.as_rv()))); - } - _ => { - parts.push(table::SeqPart::Item(self.export_type(t))); - } - } + parts.push(table::SeqPart::Item(self.export_term(t, None))); } if let Some(tail) = tail { @@ -982,7 +949,14 @@ impl<'a> Context<'a> { let item_types = self.export_term(item_types, None); self.make_term_apply(model::CORE_TUPLE_TYPE, &[item_types]) } - Term::Runtime(ty) => self.export_type(ty), + Term::RuntimeExtension(ext) => self.export_custom_type(ext), + /*TypeEnum::Alias(alias) => { + let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); + self.make_term(table::Term::Apply(symbol, &[])) + }*/ + Term::RuntimeFunction(func) => self.export_func_type(func), + Term::RuntimeSum(sum) => self.export_sum_type(sum), + Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()), Term::Float(value) => self.make_term(model::Literal::Float(*value).into()), @@ -1022,7 +996,7 @@ impl<'a> Context<'a> { Term::Variable(v) => self.export_type_arg_var(v), Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]), Term::ConstType(ty) => { - let ty = self.export_type(ty); + let ty = self.export_term(ty, None); self.make_term_apply(model::CORE_CONST, &[ty]) } } @@ -1037,7 +1011,7 @@ impl<'a> Context<'a> { if let Some(array) = e.value().downcast_ref::() { let len = self .make_term(model::Literal::Nat(array.get_contents().len() as u64).into()); - let element_type = self.export_type(array.get_element_type()); + let element_type = self.export_term(array.get_element_type(), None); let mut contents = BumpVec::with_capacity_in(array.get_contents().len(), self.bump); @@ -1076,7 +1050,7 @@ impl<'a> Context<'a> { }; let json = self.make_term(model::Literal::Str(json.into()).into()); - let runtime_type = self.export_type(&e.get_type()); + let runtime_type = self.export_term(&e.get_type(), None); let args = self.bump.alloc_slice_copy(&[runtime_type, json]); let symbol = self.resolve_symbol(model::COMPAT_CONST_JSON); self.make_term(table::Term::Apply(symbol, args)) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 8fe6a21b1..5c138f9ff 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -21,7 +21,6 @@ use thiserror::Error; use crate::hugr::IdentList; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; -use crate::types::RowVariable; use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; @@ -414,9 +413,10 @@ pub enum SignatureError { /// A type variable that was used has not been declared #[error("Type variable {idx} was not declared ({num_decls} in scope)")] FreeTypeVar { idx: usize, num_decls: usize }, - /// A row variable was found outside of a variable-length row - #[error("Expected a single type, but found row variable {var}")] - RowVarWhereTypeExpected { var: RowVariable }, + // ALAN this is now just another TypeArgMismatch + // A row variable was found outside of a variable-length row + //#[error("Expected a single type, but found row variable {var}")] + //RowVarWhereTypeExpected { var: RowVariable }, /// The result of the type application stored in a [Call] /// is not what we get by applying the type-args to the polymorphic function /// diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 25cb1e58b..2168dedab 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -714,10 +714,10 @@ pub(super) mod test { reg.validate()?; let e = reg.get(&EXT_ID).unwrap(); - let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?); + let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t()])?); let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?; let rev = dfg.add_dataflow_op( - e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()]) + e.instantiate_extension_op(&OP_NAME, vec![usize_t()]) .unwrap(), dfg.input_wires(), )?; @@ -748,7 +748,7 @@ pub(super) mod test { .collect(); Ok(PolyFuncTypeRV::new( vec![TP.clone()], - Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]), + Signature::new(tvs.clone(), vec![Type::new_runtime_tuple(tvs)]), )) } @@ -762,12 +762,12 @@ pub(super) mod test { ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?; // Base case, no type variables: - let args = [TypeArg::BoundedNat(3), usize_t().into()]; + let args = [TypeArg::BoundedNat(3), usize_t()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( vec![usize_t(); 3], - vec![Type::new_tuple(vec![usize_t(); 3])] + vec![Type::new_runtime_tuple(vec![usize_t(); 3])] )) ); assert_eq!(def.validate_args(&args, &[]), Ok(())); @@ -775,12 +775,12 @@ pub(super) mod test { // Second arg may be a variable (substitutable) let tyvar = Type::new_var_use(0, TypeBound::Copyable); let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat(3), tyvar.clone().into()]; + let args = [TypeArg::BoundedNat(3), tyvar.clone()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( tyvars.clone(), - vec![Type::new_tuple(tyvars)] + vec![Type::new_runtime_tuple(tyvars)] )) ); def.validate_args(&args, &[TypeBound::Copyable.into()]) @@ -797,7 +797,7 @@ pub(super) mod test { // First arg must be concrete, not a variable let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap()); - let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()]; + let args = [TypeArg::new_var_use(0, kind.clone()), usize_t()]; // We can't prevent this from getting into our compute_signature implementation: assert_eq!( def.compute_signature(&args), @@ -833,12 +833,12 @@ pub(super) mod test { extension_ref, )?; let tv = Type::new_var_use(0, TypeBound::Copyable); - let args = [tv.clone().into()]; + let args = [tv.clone()]; let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo([tv]))); // But not with an external row variable - let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); + let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable); assert_eq!( def.compute_signature(std::slice::from_ref(&arg)), Err(SignatureError::TypeArgMismatch( diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index cfd09b30d..fbbb582ec 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -15,7 +15,7 @@ use crate::extension::{ use crate::ops::OpName; use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::{NamedOp, Value}; -use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_type}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, @@ -26,7 +26,7 @@ use crate::{Extension, type_row}; use strum::{EnumIter, EnumString, IntoStaticStr}; use super::ExtensionRegistry; -use super::resolution::{ExtensionResolutionError, WeakExtensionRegistry, resolve_type_extensions}; +use super::resolution::{ExtensionResolutionError, WeakExtensionRegistry, resolve_term_extensions}; mod unwrap_builder; @@ -117,11 +117,11 @@ pub static PRELUDE: LazyLock> = LazyLock::new(|| { TypeParam::new_list_type(TypeBound::Linear), ], FuncValueType::new( - vec![ - TypeRV::new_extension(error_type.clone()), + Term::concat_lists([ + Term::new_list([TypeRV::new_extension(error_type.clone())]), TypeRV::new_row_var_use(0, TypeBound::Linear), - ], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], + ]), + TypeRV::new_row_var_use(1, TypeBound::Linear), ), ), extension_ref, @@ -137,11 +137,11 @@ pub static PRELUDE: LazyLock> = LazyLock::new(|| { TypeParam::new_list_type(TypeBound::Linear), ], FuncValueType::new( - vec![ - TypeRV::new_extension(error_type), + Term::concat_lists([ + Term::new_list([Type::new_extension(error_type)]), TypeRV::new_row_var_use(0, TypeBound::Linear), - ], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], + ]), + TypeRV::new_row_var_use(1, TypeBound::Linear), ), ), extension_ref, @@ -592,7 +592,7 @@ impl CustomConst for ConstExternalSymbol { &mut self, extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - resolve_type_extensions(&mut self.typ, extensions) + resolve_term_extensions(&mut self.typ, extensions) } fn validate(&self) -> Result<(), CustomCheckFailure> { @@ -643,15 +643,15 @@ impl MakeOpDef for TupleOpDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { let rv = TypeRV::new_row_var_use(0, TypeBound::Linear); - let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); + let tuple_type = TypeRV::new_runtime_tuple(rv.clone()); let param = TypeParam::new_list_type(TypeBound::Linear); match self { TupleOpDef::MakeTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new([rv], [tuple_type])) + PolyFuncTypeRV::new([param], FuncValueType::new(rv, [tuple_type])) } TupleOpDef::UnpackTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new([tuple_type], [rv])) + PolyFuncTypeRV::new([param], FuncValueType::new([tuple_type], rv)) } } .into() @@ -711,18 +711,14 @@ impl MakeExtensionOp for MakeTuple { let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self(tys?.into())) + for e in elems { + check_term_type(e, &TypeBound::Linear.into()).map_err(SignatureError::from)?; + } + Ok(Self(elems.clone().into())) } fn type_args(&self) -> Vec { - vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] + vec![Term::new_list(self.0.iter().cloned())] } } @@ -766,18 +762,14 @@ impl MakeExtensionOp for UnpackTuple { let [Term::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - Term::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self(tys?.into())) + for e in elems { + check_term_type(e, &TypeBound::Linear.into()).map_err(SignatureError::from)?; + } + Ok(Self(elems.clone().into())) } fn type_args(&self) -> Vec { - vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] + vec![Term::new_list(self.0.iter().cloned())] } } @@ -881,14 +873,15 @@ impl MakeExtensionOp for Noop { Self: Sized, { let _def = NoopDef::from_def(ext_op.def())?; - let [TypeArg::Runtime(ty)] = ext_op.args() else { + let [ty] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; Ok(Self(ty.clone())) } fn type_args(&self) -> Vec { - vec![self.0.clone().into()] + vec![self.0.clone()] } } @@ -929,7 +922,7 @@ impl MakeOpDef for BarrierDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( vec![TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo([TypeRV::new_row_var_use(0, TypeBound::Linear)]), + FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Linear)), ) .into() } @@ -990,22 +983,16 @@ impl MakeExtensionOp for Barrier { let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); + for e in elems { + check_term_type(e, &TypeBound::Linear.into()).map_err(SignatureError::from)?; + } Ok(Self { - type_row: tys?.into(), + type_row: elems.clone().into(), }) } fn type_args(&self) -> Vec { - vec![TypeArg::new_list( - self.type_row.iter().map(|t| t.clone().into()), - )] + vec![TypeArg::new_list(self.type_row.iter().cloned())] } } @@ -1046,7 +1033,7 @@ mod test { optype.dataflow_signature().unwrap().io(), ( &type_row![Type::UNIT], - &vec![Type::new_tuple(type_row![Type::UNIT])].into(), + &vec![Type::new_runtime_tuple(type_row![Type::UNIT])].into(), ) ); @@ -1061,7 +1048,7 @@ mod test { assert_eq!( optype.dataflow_signature().unwrap().io(), ( - &vec![Type::new_tuple(type_row![Type::UNIT])].into(), + &vec![Type::new_runtime_tuple(type_row![Type::UNIT])].into(), &type_row![Type::UNIT], ) ); @@ -1182,8 +1169,7 @@ mod test { /// test the panic operation with input and output wires fn test_panic_with_io() { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: Term = qb_t().into(); - let type_arg_2q: Term = Term::new_list([type_arg_q.clone(), type_arg_q]); + let type_arg_2q: Term = Term::new_list([qb_t(), qb_t()]); let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index f73b5ce60..2b4177422 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -21,17 +21,9 @@ pub trait UnwrapBuilder: Dataflow { inputs: impl IntoIterator, ) -> Result, BuildError> { let (input_wires, input_types): (Vec<_>, Vec<_>) = inputs.into_iter().unzip(); - let input_arg: TypeArg = input_types - .into_iter() - .map(>::from) - .collect_vec() - .into(); - let output_arg: TypeArg = output_row - .into_iter() - .map(>::from) - .collect_vec() - .into(); - let op = PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?; + let output_arg: TypeArg = output_row.into_iter().collect_vec().into(); + let op = + PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_types.into(), output_arg])?; let err = self.add_load_value(err); self.add_dataflow_op(op, iter::once(err).chain(input_wires)) } @@ -69,11 +61,9 @@ pub trait UnwrapBuilder: Dataflow { input: Wire, mut error: impl FnMut(usize) -> T, ) -> Result<[Wire; N], BuildError> { - let variants: Vec = (0..sum_type.num_variants()) - .map(|i| { - let tr_rv = sum_type.get_variant(i).unwrap().to_owned(); - TypeRow::try_from(tr_rv) - }) + let variants: Vec = sum_type + .variants() + .map(|t| t.clone().try_into()) .collect::>()?; // TODO don't panic if tag >= num_variants diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 097284f24..eb51036ab 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -25,11 +25,9 @@ mod weak_registry; pub use weak_registry::WeakExtensionRegistry; pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; -pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_type_exts}; +pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_term_exts}; pub(crate) use types_mut::resolve_op_types_extensions; -use types_mut::{ - resolve_custom_type_exts, resolve_term_exts, resolve_type_exts, resolve_value_exts, -}; +use types_mut::{resolve_custom_type_exts, resolve_term_exts, resolve_value_exts}; use derive_more::{Display, Error, From}; @@ -39,15 +37,15 @@ use crate::core::HugrNode; use crate::ops::constant::ValueName; use crate::ops::custom::OpaqueOpError; use crate::ops::{NamedOp, OpName, OpType, Value}; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, TypeArg, TypeBase, TypeName}; +use crate::types::{CustomType, Signature, Term, TypeArg, TypeName}; -/// Update all weak Extension pointers inside a type. -pub fn resolve_type_extensions( - typ: &mut TypeBase, +/// Update all weak Extension pointers inside a [Term]. +pub fn resolve_term_extensions( + typ: &mut Term, extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = WeakExtensionRegistry::default(); - resolve_type_exts(None, typ, extensions, &mut used_extensions) + resolve_term_exts(None, typ, extensions, &mut used_extensions) } /// Update all weak Extension pointers in a custom type. @@ -242,8 +240,8 @@ impl ExtensionCollectionError { } /// Create a new error when signature extensions have been dropped. - pub fn dropped_signature( - signature: &FuncTypeBase, + pub fn dropped_signature( + signature: &Signature, missing_extension: impl IntoIterator, ) -> Self { Self::DroppedSignatureExtensions { @@ -253,8 +251,8 @@ impl ExtensionCollectionError { } /// Create a new error when signature extensions have been dropped. - pub fn dropped_type( - typ: &TypeBase, + pub fn dropped_type( + typ: &Term, missing_extension: impl IntoIterator, ) -> Self { Self::DroppedTypeExtensions { diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 05c0faf69..27ee3798e 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef}; -use super::types_mut::resolve_signature_exts; +use super::types_mut::resolve_func_type_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -155,5 +155,5 @@ pub(super) fn resolve_signature_func_exts( return Ok(()); } }; - resolve_signature_exts(None, signature_body, extensions, used_extensions) + resolve_func_type_exts(None, signature_body, extensions, used_extensions) } diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 732d05901..68972c37f 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -337,8 +337,8 @@ fn resolve_call() { Signature::new(vec![], vec![bool_t()]), ); - let generic_type_1 = float64_type().into(); - let generic_type_2 = int_type(6).into(); + let generic_type_1 = float64_type(); + let generic_type_2 = int_type(6); let expected_exts = [ float_types::EXTENSION_ID.clone(), int_types::EXTENSION_ID.clone(), diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 221ecbeb9..700a42f53 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -10,8 +10,7 @@ use super::{ExtensionCollectionError, WeakExtensionRegistry}; use crate::Node; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::{DataflowOpTrait, OpType, Value}; -use crate::types::type_row::TypeRowBase; -use crate::types::{FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{Signature, SumType, Term, TypeRow}; /// Collects every extension used to define the types in an operation. /// @@ -59,7 +58,7 @@ pub(crate) fn collect_op_types_extensions( } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), - OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing), + OpType::LoadConstant(lc) => collect_term_exts(&lc.datatype, &mut used, &mut missing), OpType::LoadFunction(lf) => { collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); @@ -121,7 +120,7 @@ pub(crate) fn collect_op_types_extensions( } } -/// Collect the Extension pointers in the [`CustomType`]s inside a signature. +/// Collect the Extension pointers in the [`CustomType`]s inside a [Signature]. /// /// # Attributes /// @@ -129,8 +128,8 @@ pub(crate) fn collect_op_types_extensions( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(crate) fn collect_signature_exts( - signature: &FuncTypeBase, +pub(crate) fn collect_signature_exts( + signature: &Signature, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { @@ -146,31 +145,31 @@ pub(crate) fn collect_signature_exts( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -fn collect_type_row_exts( - row: &TypeRowBase, +fn collect_type_row_exts( + row: &TypeRow, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { for ty in row.iter() { - collect_type_exts(ty, used_extensions, missing_extensions); + collect_term_exts(ty, used_extensions, missing_extensions); } } -/// Collect the Extension pointers in the [`CustomType`]s inside a type. +/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// # Attributes /// -/// - `typ`: The type to collect the extensions from. +/// - `term`: The term argument to collect the extensions from. /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(crate) fn collect_type_exts( - typ: &TypeBase, +pub(crate) fn collect_term_exts( + term: &Term, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - match typ.as_type_enum() { - TypeEnum::Extension(custom) => { + match term { + Term::RuntimeExtension(custom) => { for arg in custom.args() { collect_term_exts(arg, used_extensions, missing_extensions); } @@ -185,39 +184,16 @@ pub(crate) fn collect_type_exts( } } } - TypeEnum::Function(f) => { - collect_type_row_exts(&f.input, used_extensions, missing_extensions); - collect_type_row_exts(&f.output, used_extensions, missing_extensions); + Term::RuntimeFunction(f) => { + collect_term_exts(&f.input, used_extensions, missing_extensions); + collect_term_exts(&f.output, used_extensions, missing_extensions); } - TypeEnum::Sum(SumType::General { rows }) => { - for row in rows { - collect_type_row_exts(row, used_extensions, missing_extensions); + Term::RuntimeSum(g @ SumType::General(_)) => { + for row in g.variants() { + collect_term_exts(row, used_extensions, missing_extensions); } } - // Other types do not store extensions. - TypeEnum::Alias(_) - | TypeEnum::RowVar(_) - | TypeEnum::Variable(_, _) - | TypeEnum::Sum(SumType::Unit { .. }) => {} - } -} - -/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. -/// -/// # Attributes -/// -/// - `term`: The term argument to collect the extensions from. -/// - `used_extensions`: A The registry where to store the used extensions. -/// - `missing_extensions`: A set of `ExtensionId`s of which the -/// `Weak` pointer has been invalidated. -pub(super) fn collect_term_exts( - term: &Term, - used_extensions: &mut WeakExtensionRegistry, - missing_extensions: &mut ExtensionSet, -) { - match term { - Term::Runtime(ty) => collect_type_exts(ty, used_extensions, missing_extensions), - Term::ConstType(ty) => collect_type_exts(ty, used_extensions, missing_extensions), + Term::ConstType(ty) => collect_term_exts(ty, used_extensions, missing_extensions), Term::List(elems) => { for elem in elems.iter() { collect_term_exts(elem, used_extensions, missing_extensions); @@ -254,7 +230,8 @@ pub(super) fn collect_term_exts( | Term::BoundedNat(_) | Term::String(_) | Term::Bytes(_) - | Term::Float(_) => {} + | Term::Float(_) + | Term::RuntimeSum(SumType::Unit { .. }) => {} } } @@ -274,16 +251,16 @@ fn collect_value_exts( match value { Value::Extension { e } => { let typ = e.get_type(); - collect_type_exts(&typ, used_extensions, missing_extensions); + collect_term_exts(&typ, used_extensions, missing_extensions); } #[expect(deprecated)] // remove when Value::Function removed Value::Function { hugr: _ } => { // The extensions used by nested hugrs do not need to be counted for the root hugr. } Value::Sum(s) => { - if let SumType::General { rows } = &s.sum_type { - for row in rows { - collect_type_row_exts(row, used_extensions, missing_extensions); + if matches!(s.sum_type, SumType::General(_)) { + for row in s.sum_type.variants() { + collect_term_exts(row, used_extensions, missing_extensions); } } s.values diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 16ad96af6..581d32534 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -5,12 +5,11 @@ use std::sync::Weak; -use super::types::collect_type_exts; +use super::types::collect_term_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; -use crate::types::type_row::TypeRowBase; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{CustomType, FuncValueType, Signature, SumType, Term, TypeRow}; use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an @@ -68,7 +67,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, &mut c.signature, extensions, used_extensions)?; } OpType::LoadConstant(lc) => { - resolve_type_exts(node, &mut lc.datatype, extensions, used_extensions)?; + resolve_term_exts(node, &mut lc.datatype, extensions, used_extensions)?; } OpType::LoadFunction(lf) => { resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; @@ -125,12 +124,12 @@ pub fn resolve_op_types_extensions( Ok(used.into_iter()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [Signature]. /// /// Adds the extensions used in the signature to the `used_extensions` registry. -pub(super) fn resolve_signature_exts( +pub(super) fn resolve_signature_exts( node: Option, - signature: &mut FuncTypeBase, + signature: &mut Signature, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { @@ -139,48 +138,31 @@ pub(super) fn resolve_signature_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [FuncValueType]. /// -/// Adds the extensions used in the row to the `used_extensions` registry. -pub(super) fn resolve_type_row_exts( +/// Adds the extensions used in the signature to the `used_extensions` registry. +pub(super) fn resolve_func_type_exts( node: Option, - row: &mut TypeRowBase, + signature: &mut FuncValueType, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - for ty in row.iter_mut() { - resolve_type_exts(node, ty, extensions, used_extensions)?; - } + resolve_term_exts(node, &mut signature.input, extensions, used_extensions)?; + resolve_term_exts(node, &mut signature.output, extensions, used_extensions)?; Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. /// -/// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_type_exts( +/// Adds the extensions used in the row to the `used_extensions` registry. +pub(super) fn resolve_type_row_exts( node: Option, - typ: &mut TypeBase, + row: &mut TypeRow, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - match typ.as_type_enum_mut() { - TypeEnum::Extension(custom) => { - resolve_custom_type_exts(node, custom, extensions, used_extensions)?; - } - TypeEnum::Function(f) => { - resolve_type_row_exts(node, &mut f.input, extensions, used_extensions)?; - resolve_type_row_exts(node, &mut f.output, extensions, used_extensions)?; - } - TypeEnum::Sum(SumType::General { rows }) => { - for row in rows.iter_mut() { - resolve_type_row_exts(node, row, extensions, used_extensions)?; - } - } - // Other types do not store extensions. - TypeEnum::Alias(_) - | TypeEnum::RowVar(_) - | TypeEnum::Variable(_, _) - | TypeEnum::Sum(SumType::Unit { .. }) => {} + for ty in row.iter_mut() { + resolve_term_exts(node, ty, extensions, used_extensions)?; } Ok(()) } @@ -214,15 +196,26 @@ pub(super) fn resolve_custom_type_exts( /// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_term_exts( +pub(crate) fn resolve_term_exts( node: Option, term: &mut Term, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { match term { - Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, - Term::ConstType(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, + Term::RuntimeExtension(custom) => { + resolve_custom_type_exts(node, custom, extensions, used_extensions)?; + } + Term::RuntimeFunction(f) => { + resolve_term_exts(node, &mut f.input, extensions, used_extensions)?; + resolve_term_exts(node, &mut f.output, extensions, used_extensions)?; + } + Term::RuntimeSum(SumType::General(gs)) => { + for row in gs.iter_mut() { + resolve_term_exts(node, row, extensions, used_extensions)?; + } + } + Term::ConstType(ty) => resolve_term_exts(node, ty, extensions, used_extensions)?, Term::List(children) | Term::ListConcat(children) | Term::Tuple(children) @@ -247,7 +240,8 @@ pub(super) fn resolve_term_exts( | Term::BoundedNat(_) | Term::String(_) | Term::Bytes(_) - | Term::Float(_) => {} + | Term::Float(_) + | Term::RuntimeSum(SumType::Unit { .. }) => {} } Ok(()) } @@ -269,7 +263,7 @@ pub(super) fn resolve_value_exts( // return types with valid extensions after we call `update_extensions`. let typ = e.get_type(); let mut missing = ExtensionSet::new(); - collect_type_exts(&typ, used_extensions, &mut missing); + collect_term_exts(&typ, used_extensions, &mut missing); if !missing.is_empty() { return Err(ExtensionResolutionError::InvalidConstTypes { value: e.name(), @@ -286,9 +280,9 @@ pub(super) fn resolve_value_exts( } } Value::Sum(s) => { - if let SumType::General { rows } = &mut s.sum_type { - for row in rows.iter_mut() { - resolve_type_row_exts(node, row, extensions, used_extensions)?; + if let SumType::General(gs) = &mut s.sum_type { + for row in gs.iter_mut() { + resolve_term_exts(node, row, extensions, used_extensions)?; } } s.values diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index b848c7528..c247a1463 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -4,7 +4,7 @@ use std::sync::Weak; use super::{CustomConcrete, ExtensionBuildError}; use super::{Extension, ExtensionId, SignatureError}; -use crate::types::{CustomType, TypeName, least_upper_bound}; +use crate::types::{CustomType, Term, TypeName, least_upper_bound}; use crate::types::type_param::{TypeArg, check_term_types}; @@ -144,13 +144,13 @@ impl TypeDef { // Assume most general case return TypeBound::Linear; } - least_upper_bound(indices.iter().map(|i| { - let ta = args.get(*i); - match ta { - Some(TypeArg::Runtime(s)) => s.least_upper_bound(), - _ => panic!("TypeArg index does not refer to a type."), - } - })) + let bounds = indices.iter().map(|i| { + args.get(*i) + .copied() + .and_then(Term::least_upper_bound) + .expect("TypeArg index does not refer to a type.") + }); + least_upper_bound(bounds) } } } @@ -258,21 +258,19 @@ mod test { bound: TypeDefBound::FromParams { indices: vec![0] }, }; let typ = Type::new_extension( - def.instantiate(vec![ - Type::new_function(Signature::new(vec![], vec![])).into(), - ]) - .unwrap(), + def.instantiate(vec![Type::new_function(Signature::new(vec![], vec![]))]) + .unwrap(), ); - assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); - let typ2 = Type::new_extension(def.instantiate([usize_t().into()]).unwrap()); - assert_eq!(typ2.least_upper_bound(), TypeBound::Copyable); + assert_eq!(typ.least_upper_bound(), Some(TypeBound::Copyable)); + let typ2 = Type::new_extension(def.instantiate([usize_t()]).unwrap()); + assert_eq!(typ2.least_upper_bound(), Some(TypeBound::Copyable)); // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate([qb_t().into()]), + def.instantiate([qb_t()]), Err(SignatureError::TypeArgMismatch( TermTypeError::TypeMismatch { - term: Box::new(qb_t().into()), + term: Box::new(qb_t()), type_: Box::new(TypeBound::Copyable.into()) } )) @@ -284,7 +282,7 @@ mod test { ); // Too many arguments: assert_eq!( - def.instantiate([float64_type().into(), float64_type().into(),]) + def.instantiate([float64_type(), float64_type(),]) .unwrap_err(), SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) ); diff --git a/hugr-core/src/hugr/patch/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs index 23ccdbb14..b5f053034 100644 --- a/hugr-core/src/hugr/patch/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -287,7 +287,7 @@ mod test { #[test] fn test_polymorphic() -> Result<(), Box> { - let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); + let tuple_ty = Type::new_runtime_tuple(vec![usize_t(); 2]); let mut fb = FunctionBuilder::new("mkpair", Signature::new([usize_t()], [tuple_ty.clone()]))?; let helper = { @@ -302,10 +302,10 @@ mod test { let inps = fb2.input_wires(); fb2.finish_with_outputs(inps)? }; - let call1 = fb.call(helper.handle(), &[usize_t().into()], fb.input_wires())?; + let call1 = fb.call(helper.handle(), &[usize_t()], fb.input_wires())?; let [call1_out] = call1.outputs_arr(); let tup = fb.make_tuple([call1_out, call1_out])?; - let call2 = fb.call(helper.handle(), &[tuple_ty.into()], [tup])?; + let call2 = fb.call(helper.handle(), &[tuple_ty], [tup])?; let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap(); assert_eq!( diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 1aa438474..f47a71700 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -60,7 +60,7 @@ impl SimpleReplacement { node: replacement.entrypoint(), op: Box::new(replacement.get_optype(replacement.entrypoint()).to_owned()), })?; - if subgraph_sig != repl_sig { + if &subgraph_sig != repl_sig.as_ref() { return Err(InvalidReplacement::InvalidSignature { expected: Box::new(subgraph_sig), actual: Some(Box::new(repl_sig.into_owned())), diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 87a7e5507..72aae2cb6 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -26,8 +26,8 @@ use crate::std_extensions::std_reg; use crate::test_file; use crate::types::type_param::TypeParam; use crate::types::{ - FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, - TypeRV, + FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, TypeArg, + TypeBound, TypeRV, }; use crate::{OutgoingPort, Visibility, type_row}; use std::fs::File; @@ -39,6 +39,7 @@ use itertools::Itertools; use jsonschema::{Draft, Validator}; use portgraph::{Hierarchy, LinkMut, PortMut, UnmanagedDenseMap, multiportgraph::MultiPortGraph}; use rstest::rstest; +use serde_with::serde_as; /// A serde-serializable hugr. Used for testing. #[derive(Debug, serde::Serialize)] @@ -50,8 +51,10 @@ pub(super) struct HugrSer<'h>(#[serde(serialize_with = "Hugr::serde_serialize")] pub(super) struct HugrDeser(#[serde(deserialize_with = "Hugr::serde_deserialize")] pub Hugr); /// Version 1 of the Testing HUGR serialization format, see `testing_hugr.py`. +#[serde_as] #[derive(Serialize, Deserialize, PartialEq, Debug, Default)] struct SerTestingLatest { + #[serde_as(as = "Option")] typ: Option, sum_type: Option, poly_func_type: Option, @@ -142,7 +145,7 @@ macro_rules! impl_sertesting_from { }; } -impl_sertesting_from!(crate::types::TypeRV, typ); +impl_sertesting_from!(crate::types::Type, typ); impl_sertesting_from!(crate::types::SumType, sum_type); impl_sertesting_from!(crate::types::PolyFuncTypeRV, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); @@ -156,13 +159,6 @@ impl From for SerTestingLatest { } } -impl From for SerTestingLatest { - fn from(v: Type) -> Self { - let t: TypeRV = v.into(); - t.into() - } -} - #[test] fn empty_hugr_serialize() { check_hugr_json_roundtrip(&Hugr::default(), true); @@ -522,7 +518,7 @@ fn serialize_types_roundtrip() { check_testing_roundtrip(g.clone()); // A Simple tuple - let t = Type::new_tuple(vec![usize_t(), g]); + let t = Type::new_runtime_tuple(vec![usize_t(), g]); check_testing_roundtrip(t); // A Classic sum @@ -537,9 +533,8 @@ fn serialize_types_roundtrip() { #[case(bool_t())] #[case(usize_t())] #[case(INT_TYPES[2].clone())] -#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Linear)))] #[case(Type::new_var_use(2, TypeBound::Copyable))] -#[case(Type::new_tuple(vec![bool_t(),qb_t()]))] +#[case(Type::new_runtime_tuple(vec![bool_t(),qb_t()]))] #[case(Type::new_sum([vec![bool_t(),qb_t()], vec![Type::new_unit_sum(4)]]))] #[case(Type::new_function(Signature::new_endo([qb_t(),bool_t(),usize_t()])))] fn roundtrip_type(#[case] typ: Type) { @@ -575,11 +570,14 @@ fn polyfunctype2() -> PolyFuncTypeRV { let tv0 = TypeRV::new_row_var_use(0, TypeBound::Linear); let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); let params = [TypeBound::Linear, TypeBound::Copyable].map(TypeParam::new_list_type); - let inputs = vec![ - TypeRV::new_function(FuncValueType::new([tv0.clone()], [tv1.clone()])), + let inputs = Term::concat_lists([ + Term::new_list([TypeRV::new_function(FuncValueType::new( + tv0.clone(), + tv1.clone(), + ))]), tv0, - ]; - let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, [tv1])); + ]); + let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, tv1)); // Just check we've got the arguments the right way round // (not that it really matters for the serialization schema we have) res.validate().unwrap(); @@ -595,7 +593,7 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[case(PolyFuncType::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( [TypeParam::new_list_type(TypeBound::Linear)], - Signature::new_endo([Type::new_tuple([TypeRV::new_row_var_use(0, TypeBound::Linear)])])))] + Signature::new_endo([Type::new_runtime_tuple(TypeRV::new_row_var_use(0, TypeBound::Linear))])))] fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type); } @@ -608,7 +606,7 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[case(PolyFuncTypeRV::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( [TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo([TypeRV::new_row_var_use(0, TypeBound::Linear)])))] + FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Linear))))] #[case(polyfunctype2())] fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { check_testing_roundtrip(poly_func_type); @@ -670,7 +668,8 @@ mod proptest { use super::check_testing_roundtrip; use super::{NodeSer, SimpleOpDef}; use crate::ops::{OpType, OpaqueOp, Value}; - use crate::types::{PolyFuncTypeRV, Type}; + use crate::proptest::RecursionDepth; + use crate::types::{PolyFuncTypeRV, proptest_utils::any_type}; use proptest::prelude::*; impl Arbitrary for NodeSer { @@ -698,7 +697,7 @@ mod proptest { proptest! { #[test] - fn prop_roundtrip_type(t: Type) { + fn prop_roundtrip_type(t in any_type(RecursionDepth::default())) { check_testing_roundtrip(t); } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 7cec7d6e6..a9408ac52 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -324,7 +324,7 @@ fn invalid_types() { let valid = Type::new_extension(CustomType::new( "MyContainer", - vec![usize_t().into()], + vec![usize_t()], EXT_ID, TypeBound::Linear, &Arc::downgrade(&ext), @@ -336,7 +336,7 @@ fn invalid_types() { // valid is Any, so is not allowed as an element of an outer MyContainer. let element_outside_bound = CustomType::new( "MyContainer", - vec![valid.clone().into()], + vec![valid.clone()], EXT_ID, TypeBound::Linear, &Arc::downgrade(&ext), @@ -345,13 +345,13 @@ fn invalid_types() { validate_to_sig_error(element_outside_bound), SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { type_: Box::new(TypeBound::Copyable.into()), - term: Box::new(valid.into()) + term: Box::new(valid) }) ); let bad_bound = CustomType::new( "MyContainer", - vec![usize_t().into()], + vec![usize_t()], EXT_ID, TypeBound::Copyable, &Arc::downgrade(&ext), @@ -367,7 +367,7 @@ fn invalid_types() { // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. let nested = CustomType::new( "MyContainer", - vec![Type::new_extension(bad_bound).into()], + vec![Type::new_extension(bad_bound)], EXT_ID, TypeBound::Linear, &Arc::downgrade(&ext), @@ -382,7 +382,7 @@ fn invalid_types() { let too_many_type_args = CustomType::new( "MyContainer", - vec![usize_t().into(), 3u64.into()], + vec![usize_t(), 3u64.into()], EXT_ID, TypeBound::Linear, &Arc::downgrade(&ext), @@ -497,11 +497,10 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { Extension::new_test_arc(EXT_ID, |ext, extension_ref| { let inputs = TypeRV::new_row_var_use(0, TypeBound::Linear); let outputs = TypeRV::new_row_var_use(1, TypeBound::Linear); - let evaled_fn = - TypeRV::new_function(FuncValueType::new([inputs.clone()], [outputs.clone()])); + let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone()], - FuncValueType::new([evaled_fn, inputs], [outputs]), + FuncValueType::new(Term::concat_lists([[evaled_fn].into(), inputs]), outputs), ); ext.add_op("eval".into(), String::new(), pf, extension_ref) .unwrap(); @@ -510,13 +509,13 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], Signature::new( - vec![ - Type::new_function(FuncValueType::new([rv(0)], [rv(2)])), - Type::new_function(FuncValueType::new([rv(1)], [rv(3)])), + [ + Type::new_function(FuncValueType::new(rv(0), rv(2))), + Type::new_function(FuncValueType::new(rv(1), rv(3))), ], [Type::new_function(FuncValueType::new( - [rv(0), rv(1)], - [rv(2), rv(3)], + Term::concat_lists([rv(0), rv(1)]), + Term::concat_lists([rv(2), rv(3)]), ))], ), ); @@ -528,7 +527,7 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { #[test] fn instantiate_row_variables() -> Result<(), Box> { fn uint_seq(i: usize) -> Term { - vec![usize_t().into(); i].into() + vec![usize_t(); i].into() } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( @@ -552,16 +551,15 @@ fn instantiate_row_variables() -> Result<(), Box> { Ok(()) } -fn list1ty(t: TypeRV) -> Term { - Term::new_list([t.into()]) -} - #[test] fn row_variables() -> Result<(), Box> { let e = extension_with_eval_parallel(); let tv = TypeRV::new_row_var_use(0, TypeBound::Linear); - let inner_ft = Type::new_function(FuncValueType::new_endo([tv.clone()])); - let ft_usz = Type::new_function(FuncValueType::new_endo([tv.clone(), usize_t().into()])); + let inner_ft = Type::new_function(FuncValueType::new_endo(tv.clone())); + let ft_usz = Type::new_function(FuncValueType::new_endo(Term::concat_lists([ + tv.clone(), + [usize_t()].into(), + ]))); let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( @@ -580,7 +578,12 @@ fn row_variables() -> Result<(), Box> { }; let par = e.instantiate_extension_op( "parallel", - [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(list1ty), + [ + tv.clone(), + [usize_t()].into(), + tv.clone(), + [usize_t()].into(), + ], )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs(par_func.outputs())?; @@ -602,7 +605,7 @@ fn test_polymorphic_load() -> Result<(), Box> { vec![Type::new_function(Signature::new_endo([usize_t()]))], ); let mut f = m.define_function("main", sig)?; - let l = f.load_func(&id, &[usize_t().into()])?; + let l = f.load_func(&id, &[usize_t()])?; f.finish_with_outputs([l])?; let _ = m.finish_hugr()?; Ok(()) diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index abdbacf74..eb68a8ac9 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -12,7 +12,7 @@ use crate::{ OpParent, OpTrait, OpType, handle::{DataflowParentID, DfgID}, }, - types::{NoRV, Signature, Type, TypeBase}, + types::{Signature, Type}, }; use super::RootChecked; @@ -262,7 +262,7 @@ fn update_signature(hugr: &mut H, node: H::Node, new_sig: &Signature fn check_valid_inputs( old_ports: &[Vec], - old_sig: &[TypeBase], + old_sig: &[Type], map_sig: &[usize], ) -> Result<(), InvalidSignature> { if let Some(old_pos) = map_sig @@ -291,10 +291,7 @@ fn check_valid_inputs( Ok(()) } -fn check_valid_outputs( - old_sig: &[TypeBase], - map_sig: &[usize], -) -> Result<(), InvalidSignature> { +fn check_valid_outputs(old_sig: &[Type], map_sig: &[usize]) -> Result<(), InvalidSignature> { if let Some(old_pos) = map_sig .iter() .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos)) @@ -684,8 +681,8 @@ mod test { let new_inputs = vec![bool_t(), float64_type()]; dfg_view.extend_inputs(&new_inputs).unwrap(); assert_eq!( - dfg_view.hugr().inner_function_type().unwrap(), - Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()]) + dfg_view.hugr().inner_function_type().unwrap().as_ref(), + &Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()]) ); let new_inputs_fail = vec![qb_t()]; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 6d826131c..b9ae9c79b 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -1970,7 +1970,7 @@ mod tests { assert_eq!(subg.nodes().len(), 1); assert_eq!( subg.signature(&h).io(), - Signature::new(type_row![], vec![Type::new_tuple(type_row![])]).io() + Signature::new(type_row![], vec![Type::new_runtime_tuple(type_row![])]).io() ); // `from_nodes` is different, is it only uses incoming and outgoing edges to @@ -1990,7 +1990,7 @@ mod tests { // A hugr with some empty MakeTuple operations. let tuple_op = MakeTuple::new(type_row![]); let untuple_op = UnpackTuple::new(type_row![]); - let tuple_t = Type::new_tuple(type_row![]); + let tuple_t = Type::new_runtime_tuple(type_row![]); let mut b = DFGBuilder::new(Signature::new(type_row![], vec![tuple_t.clone()])).unwrap(); let mk_tuple_1 = b.add_dataflow_op(tuple_op.clone(), []).unwrap(); diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 1e62bc699..d6b6f4bf7 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use crate::envelope::description::GeneratorDesc; use crate::metadata::{self, Metadata}; +use crate::types::FuncValueType; use crate::{ Direction, Hugr, HugrView, Node, Port, envelope::description::{ExtensionDesc, ModuleDesc}, @@ -27,10 +28,8 @@ use crate::{ collections::array::ArrayValue, }, types::{ - CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, - Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, + CustomType, PolyFuncType, Signature, Term, Type, TypeArg, TypeBound, TypeName, TypeRow, type_param::{SeqPart, TypeParam}, - type_row::TypeRowBase, }, }; use hugr_model::v0::table; @@ -314,7 +313,7 @@ impl<'a> Context<'a> { let signature = node_data .signature .ok_or_else(|| error_uninferred!("node signature"))?; - self.import_func_type(signature) + self.import_signature(signature) } /// Get the node with the given `NodeId`, or return an error if it does not exist. @@ -687,7 +686,7 @@ impl<'a> Context<'a> { } let signature = self - .import_func_type( + .import_signature( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -839,12 +838,15 @@ impl<'a> Context<'a> { let sum_rows: Vec<_> = { let [variants] = self.expect_symbol(*first, model::CORE_ADT)?; - self.import_type_rows(variants)? + self.import_closed_list(variants)? + .into_iter() + .map(|term_id| self.import_type_row(term_id)) + .collect::>()? }; let rest = rest .iter() - .map(|term| self.import_type(*term)) + .map(|term| self.import_term(*term)) .collect::, _>>()? .into(); @@ -933,7 +935,7 @@ impl<'a> Context<'a> { for region in node_data.regions { let region_data = self.get_region(*region)?; - let signature = self.import_func_type( + let signature = self.import_signature( region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, @@ -1255,7 +1257,7 @@ impl<'a> Context<'a> { let output = outputs.first().ok_or_else(|| { error_invalid!("`{}` expects a single output", model::CORE_LOAD_CONST) })?; - let datatype = self.import_type(*output)?; + let datatype = self.import_term(*output)?; let imported_value = self.import_value(value, *output)?; @@ -1350,7 +1352,7 @@ impl<'a> Context<'a> { let optype = OpType::AliasDefn(AliasDefn { name: symbol.name.to_smolstr(), - definition: self.import_type(value)?, + definition: self.import_term(value)?, }); let node = self.make_node(node_id, optype, parent)?; @@ -1378,11 +1380,11 @@ impl<'a> Context<'a> { Ok(node) } - fn import_poly_func_type( + fn import_poly_func_type( &mut self, node: table::NodeId, symbol: table::Symbol<'a>, - in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, + in_scope: impl FnOnce(&mut Self, PolyFuncType) -> Result, ) -> Result { (|| { let mut imported_params = Vec::with_capacity(symbol.params.len()); @@ -1425,8 +1427,8 @@ impl<'a> Context<'a> { ); } - let body = self.import_func_type::(symbol.signature)?; - in_scope(self, PolyFuncTypeBase::new(imported_params, body)) + let body = self.import_signature(symbol.signature)?; + in_scope(self, PolyFuncType::new(imported_params, body)) })() .map_err(|err| error_context!(err, "symbol `{}` defined by node {}", symbol.name, node)) } @@ -1472,7 +1474,7 @@ impl<'a> Context<'a> { if let Some([ty]) = self.match_symbol(term_id, model::CORE_CONST)? { let ty = self - .import_type(ty) + .import_term(ty) .map_err(|err| error_context!(err, "type of a constant"))?; return Ok(TypeParam::new_const(ty)); } @@ -1495,6 +1497,23 @@ impl<'a> Context<'a> { return Ok(TypeParam::new_tuple_type(item_types)); } + if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { + let func_type = self.import_func_type(term_id)?; + return Ok(Type::new_function(func_type)); + } + + if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { + let variants = (|| { + self.import_closed_list(variants)? + .iter() + .map(|variant| self.import_term(*variant)) + .collect::, _>>() + })() + .map_err(|err| error_context!(err, "adt variants"))?; + + return Ok(Type::new_sum(variants)); + } + match self.get_term(term_id)? { table::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -1539,51 +1558,6 @@ impl<'a> Context<'a> { table::Term::Literal(model::Literal::Float(value)) => Ok(Term::Float(*value)), table::Term::Func { .. } => Err(error_unsupported!("function constant")), - table::Term::Apply { .. } => { - let ty: Type = self.import_type(term_id)?; - Ok(ty.into()) - } - } - })() - .map_err(|err| error_context!(err, "term {}", term_id)) - } - - fn import_seq_part( - &mut self, - seq_part: &'a table::SeqPart, - ) -> Result, ImportErrorInner> { - Ok(match seq_part { - table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), - table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), - }) - } - - /// Import a `Type` from a term that represents a runtime type. - fn import_type( - &mut self, - term_id: table::TermId, - ) -> Result, ImportErrorInner> { - (|| { - if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { - let func_type = self.import_func_type::(term_id)?; - return Ok(TypeBase::new_function(func_type)); - } - - if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { - let variants = (|| { - self.import_closed_list(variants)? - .iter() - .map(|variant| self.import_type_row::(*variant)) - .collect::, _>>() - })() - .map_err(|err| error_context!(err, "adt variants"))?; - - return Ok(TypeBase::new_sum(variants)); - } - - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), - table::Term::Apply(symbol, args) => { let name = self.get_symbol_name(*symbol)?; @@ -1615,7 +1589,7 @@ impl<'a> Context<'a> { let bound = ext_type.bound(&args); - Ok(TypeBase::new_extension(CustomType::new( + Ok(Term::new_extension(CustomType::new( id, args, extension, @@ -1623,24 +1597,19 @@ impl<'a> Context<'a> { &Arc::downgrade(extension_ref), ))) } - - table::Term::Var(var @ table::VarId(_, index)) => { - let local_var = self - .local_vars - .get(var) - .ok_or(error_invalid!("unknown var {}", var))?; - Ok(TypeBase::new_var_use(*index as _, local_var.bound)) - } - - // The following terms are not runtime types, but the core `Type` only contains runtime types. - // We therefore report a type error here. - table::Term::Literal(_) - | table::Term::List { .. } - | table::Term::Tuple { .. } - | table::Term::Func { .. } => Err(error_invalid!("expected a runtime type")), } })() - .map_err(|err| error_context!(err, "term {} as `Type`", term_id)) + .map_err(|err| error_context!(err, "term {}", term_id)) + } + + fn import_seq_part( + &mut self, + seq_part: &'a table::SeqPart, + ) -> Result, ImportErrorInner> { + Ok(match seq_part { + table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), + table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), + }) } fn get_func_type( @@ -1667,23 +1636,28 @@ impl<'a> Context<'a> { /// /// Function types are not special-cased in `hugr-model` but are represented /// via the `core.fn` term constructor. - fn import_func_type( + fn import_func_type( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { + ) -> Result { (|| { let [inputs, outputs] = self.get_func_type(term_id)?; let inputs = self - .import_type_row(inputs) + .import_term(inputs) .map_err(|err| error_context!(err, "function inputs"))?; let outputs = self - .import_type_row(outputs) + .import_term(outputs) .map_err(|err| error_context!(err, "function outputs"))?; - Ok(FuncTypeBase::new(inputs, outputs)) + Ok(FuncValueType::new(inputs, outputs)) })() .map_err(|err| error_context!(err, "function type")) } + fn import_signature(&mut self, term_id: table::TermId) -> Result { + let fvt = self.import_func_type(term_id)?; + Ok(fvt.try_into()?) + } + /// Import a closed list as a vector of term ids. /// /// This method supports list terms that contain spliced sublists as long as @@ -1778,64 +1752,18 @@ impl<'a> Context<'a> { Ok(types) } - /// Imports a list of lists as a vector of type rows. - /// - /// See [`Self::import_type_row`]. - fn import_type_rows( - &mut self, - term_id: table::TermId, - ) -> Result>, ImportErrorInner> { - self.import_closed_list(term_id)? - .into_iter() - .map(|term_id| self.import_type_row::(term_id)) - .collect() - } - - /// Imports a list as a type row. + /// Imports a closed list as a type row. /// /// This method works to produce a [`TypeRow`] or a [`TypeRowRV`], depending /// on the `RV` type argument. For [`TypeRow`] a closed list is expected. /// For [`TypeRowRV`] we import spliced variables as row variables. - fn import_type_row( - &mut self, - term_id: table::TermId, - ) -> Result, ImportErrorInner> { - fn import_into( - ctx: &mut Context, - term_id: table::TermId, - types: &mut Vec>, - ) -> Result<(), ImportErrorInner> { - match ctx.get_term(term_id)? { - table::Term::List(parts) => { - types.reserve(parts.len()); - - for item in *parts { - match item { - table::SeqPart::Item(term_id) => { - types.push(ctx.import_type::(*term_id)?); - } - table::SeqPart::Splice(term_id) => { - import_into(ctx, *term_id, types)?; - } - } - } - } - table::Term::Var(table::VarId(_, index)) => { - let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Linear)) - .map_err(|_| { - error_invalid!("Expected a closed list.\n{}", CLOSED_LIST_HINT) - })?; - types.push(TypeBase::new(TypeEnum::RowVar(var))); - } - _ => return Err(error_invalid!("expected a list")), - } - - Ok(()) - } - - let mut types = Vec::new(); - import_into(self, term_id, &mut types)?; - Ok(types.into()) + fn import_type_row(&mut self, term_id: table::TermId) -> Result { + let elems = self.import_closed_list(term_id)?; + Ok(elems + .into_iter() + .map(|id| self.import_term(id)) + .collect::, _>>()? + .into()) } fn import_custom_name( @@ -1891,7 +1819,7 @@ impl<'a> Context<'a> { let opaque_value = OpaqueValue::from(value); return Ok(Value::Extension { e: opaque_value }); } else { - let runtime_type = self.import_type(runtime_type)?; + let runtime_type = self.import_term(runtime_type)?; let value: serde_json::Value = serde_json::from_str(json).map_err(|_| { error_invalid!( "unable to parse JSON string for `{}`", @@ -1907,7 +1835,7 @@ impl<'a> Context<'a> { if let Some([_, element_type_term, contents]) = self.match_symbol(term_id, ArrayValue::CTR_NAME)? { - let element_type = self.import_type(element_type_term)?; + let element_type = self.import_term(element_type_term)?; let contents = self.import_closed_list(contents)?; let contents = contents .iter() @@ -1987,13 +1915,8 @@ impl<'a> Context<'a> { .map(|(value, ty)| self.import_value(*value, *ty)) .collect::, _>>()?; - let ty = { - // TODO: Import as a `SumType` directly and avoid the copy. - let ty: Type = self.import_type(type_id)?; - match ty.as_type_enum() { - TypeEnum::Sum(sum) => sum.clone(), - _ => unreachable!(), - } + let Term::RuntimeSum(ty) = self.import_term(type_id)? else { + unreachable!() }; return Ok(Value::sum(*tag as _, items, ty).unwrap()); @@ -2138,7 +2061,7 @@ impl<'a> Context<'a> { struct LocalVar { /// The type of the variable. r#type: table::TermId, - /// The type bound of the variable. + /// The type bound of the variable. Overwritten if a constraint is seen. bound: TypeBound, } diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index b9db8214a..f6d80fd03 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -208,7 +208,7 @@ pub enum Value { /// use serde_json::json; /// /// let expected_json = json!({ -/// "typ": usize_t(), +/// "typ": {"t": "I"}, // No public way to serialize a Term as a (SerSimple)Type... /// "value": {'c': "ConstUsize", 'v': 1} /// }); /// let ev = OpaqueValue::new(ConstUsize::new(1)); @@ -217,7 +217,7 @@ pub enum Value { /// /// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null)); /// let expected_json = json!({ -/// "typ": usize_t(), +/// "typ": {"t": "I"}, // No public way to serialize a Term as a (SerSimple)Type /// "value": null /// }); /// @@ -367,7 +367,9 @@ impl Value { let vs = items.into_iter().collect_vec(); let tys = vs.iter().map(Self::get_type).collect_vec(); - Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid") + let sty = SumType::try_new([tys.clone()]) + .unwrap_or_else(|_| panic!("Values {:?} tys {:?}", vs, tys)); + Self::sum(0, vs, sty).expect("Tuple type is valid") } /// Returns a constant function defined by a Hugr. @@ -757,7 +759,7 @@ pub(crate) mod test { #[case(Value::unit(), Type::UNIT, "const:seq:{}")] #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")] #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")] - #[case(const_tuple(), Type::new_tuple(vec![usize_t(), bool_t()]), "const:seq:{")] + #[case(const_tuple(), Type::new_runtime_tuple(vec![usize_t(), bool_t()]), "const:seq:{")] #[case(const_array_bool(), array_type(2, bool_t()), "const:custom:array")] #[case( const_borrow_array_bool(), @@ -830,7 +832,7 @@ pub(crate) mod test { ); let json_const: Value = CustomSerialized::new(typ_int.clone(), 6.into()).into(); let classic_t = Type::new_extension(typ_int.clone()); - assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable); + assert_matches!(classic_t.least_upper_bound(), Some(TypeBound::Copyable)); assert_eq!(json_const.get_type(), classic_t); let typ_qb = CustomType::new( @@ -885,9 +887,9 @@ pub(crate) mod test { use super::super::{OpaqueValue, Sum}; use crate::{ ops::{Value, constant::CustomSerialized}, - std_extensions::arithmetic::int_types::ConstInt, - std_extensions::collections::list::ListValue, - types::{SumType, Type}, + proptest::RecursionDepth, + std_extensions::{arithmetic::int_types::ConstInt, collections::list::ListValue}, + types::{SumType, proptest_utils::any_type}, }; use ::proptest::{collection::vec, prelude::*}; impl Arbitrary for OpaqueValue { @@ -905,12 +907,14 @@ pub(crate) mod test { 32, // Target around 32 total elements 3, // Each collection is up to 3 elements long |child_strat| { - (any::(), vec(child_strat, 0..3)).prop_map(|(typ, children)| { - Self::new(ListValue::new( - typ, - children.into_iter().map(|e| Value::Extension { e }), - )) - }) + (any_type(RecursionDepth::default()), vec(child_strat, 0..3)).prop_map( + |(typ, children)| { + Self::new(ListValue::new( + typ, + children.into_iter().map(|e| Value::Extension { e }), + )) + }, + ) }, ) .boxed() diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index ac4251b5f..c69449301 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -8,11 +8,12 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use downcast_rs::{Downcast, impl_downcast}; +use serde_with::serde_as; use thiserror::Error; use crate::IncomingPort; use crate::extension::resolution::{ - ExtensionResolutionError, WeakExtensionRegistry, resolve_type_extensions, + ExtensionResolutionError, WeakExtensionRegistry, resolve_term_extensions, }; use crate::macros::impl_box_clone; use crate::types::{CustomCheckFailure, Type}; @@ -171,9 +172,11 @@ fn deserialize_dyn_custom_const( impl_downcast!(CustomConst); impl_box_clone!(CustomConst, CustomConstBoxClone); +#[serde_as] #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// A constant value stored as a serialized blob that can report its own type. pub struct CustomSerialized { + #[serde_as(as = "crate::types::serialize::SerType")] typ: Type, value: serde_json::Value, } @@ -303,7 +306,7 @@ impl CustomConst for CustomSerialized { &mut self, extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - resolve_type_extensions(&mut self.typ, extensions) + resolve_term_extensions(&mut self.typ, extensions) } fn get_type(&self) -> Type { self.typ.clone() @@ -525,14 +528,14 @@ mod proptest { use crate::{ ops::constant::CustomSerialized, proptest::{any_serde_json_value, any_string}, - types::Type, + types::CustomType, }; impl Arbitrary for CustomSerialized { type Parameters = (); type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { - let typ = any::(); + let typ = any::().prop_map_into(); // here we manually construct a serialized `dyn CustomConst`. // The "c" and "v" come from the `typetag::serde` annotation on // `trait CustomConst`. diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 45a06b16f..97b8c0654 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -365,11 +365,11 @@ mod test { other_outputs: vec![tv0.clone()].into(), sum_rows: vec![[usize_t()].into(), [qb_t(), tv0.clone()].into()], }; - let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()])); + let dfb2 = dfb.substitute(&Substitution::new(&[qb_t()])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( - dfb2.inner_signature(), - Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) + dfb2.inner_signature().as_ref(), + &Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) ); } @@ -378,22 +378,22 @@ mod test { let tv1 = Type::new_var_use(1, TypeBound::Linear); let cond = Conditional { sum_rows: vec![[usize_t()].into(), [tv1.clone()].into()], - other_inputs: vec![Type::new_tuple([TypeRV::new_row_var_use( + other_inputs: vec![Type::new_runtime_tuple(TypeRV::new_row_var_use( 0, TypeBound::Linear, - )])] + ))] .into(), outputs: vec![usize_t(), tv1].into(), }; let cond2 = cond.substitute(&Substitution::new(&[ - TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]), - qb_t().into(), + TypeArg::new_list([usize_t(), usize_t(), usize_t()]), + qb_t(), ])); let st = Type::new_sum([[usize_t()], [qb_t()]]); assert_eq!( - cond2.signature(), - Signature::new( - [st, Type::new_tuple(vec![usize_t(); 3])], + cond2.signature().as_ref(), + &Signature::new( + [st, Type::new_runtime_tuple(vec![usize_t(); 3])], [usize_t(), qb_t()] ) ); @@ -407,10 +407,10 @@ mod test { just_outputs: vec![tv0.clone(), qb_t()].into(), rest: vec![tv0.clone()].into(), }; - let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()])); + let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t()])); assert_eq!( - tail2.signature(), - Signature::new( + tail2.signature().as_ref(), + &Signature::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 878fbe04c..a220a1a7e 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -410,11 +410,11 @@ mod test { let op = OpaqueOp::new( "res".try_into().unwrap(), "op", - vec![usize_t().into()], + vec![usize_t()], sig.clone(), ); assert_eq!(op.name(), "OpaqueOp:res.op"); - assert_eq!(op.args(), &[usize_t().into()]); + assert_eq!(op.args(), &[usize_t()]); assert_eq!(op.signature().as_ref(), &sig); let optype: OpType = op.into(); diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 9e4676472..4fba12a83 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -2,6 +2,8 @@ use std::borrow::Cow; +use serde_with::serde_as; + use super::{OpTag, OpTrait, impl_op_name}; use crate::extension::SignatureError; @@ -10,7 +12,11 @@ use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeAr use crate::{IncomingPort, type_row}; #[cfg(test)] -use {crate::types::proptest_utils::any_serde_type_arg_vec, proptest_derive::Arbitrary}; +use { + crate::proptest::RecursionDepth, + crate::types::proptest_utils::{any_serde_type_arg_vec, any_type}, + proptest_derive::Arbitrary, +}; /// Trait implemented by all dataflow operations. pub trait DataflowOpTrait: Sized { @@ -326,10 +332,13 @@ impl DataflowOpTrait for CallIndirect { } /// Load a static constant in to the local dataflow graph. +#[serde_as] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(Arbitrary))] pub struct LoadConstant { /// Constant type + #[cfg_attr(test, proptest(strategy = "any_type(RecursionDepth::default())"))] + #[serde_as(as = "crate::types::serialize::SerType")] pub datatype: Type, } impl_op_name!(LoadConstant); diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index 71955bdc1..b8e54e194 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,12 +1,10 @@ //! Handles to nodes in HUGR. use crate::Node; use crate::core::HugrNode; -use crate::types::{Type, TypeBound}; use derive_more::From as DerFrom; -use smol_str::SmolStr; -use super::{AliasDecl, OpTag}; +use super::OpTag; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. @@ -71,34 +69,6 @@ pub struct ModuleID(N); /// defined or just declared. pub struct FuncID(N); -#[derive(Debug, Clone, PartialEq, Eq)] -/// Handle to an [`AliasDefn`](crate::ops::OpType::AliasDefn) -/// or [`AliasDecl`](crate::ops::OpType::AliasDecl) node. -/// -/// The `DEF` const generic is used to indicate whether the function is -/// defined or just declared. -pub struct AliasID { - node: N, - name: SmolStr, - bound: TypeBound, -} - -impl AliasID { - /// Construct new `AliasID` - pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self { - Self { node, name, bound } - } - - /// Construct new `AliasID` - pub fn get_alias_type(&self) -> Type { - Type::new_alias(AliasDecl::new(self.name.clone(), self.bound)) - } - /// Retrieve the underlying core type - pub fn get_name(&self) -> &SmolStr { - &self.name - } -} - #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. pub struct ConstID(N); @@ -166,14 +136,6 @@ impl NodeHandle for FuncID { } } -impl NodeHandle for AliasID { - const TAG: OpTag = OpTag::Alias; - #[inline] - fn node(&self) -> N { - self.node - } -} - impl NodeHandle for N { const TAG: OpTag = OpTag::Any; #[inline] @@ -202,6 +164,3 @@ impl_containerHandle!(BasicBlockID, DataflowOpID); impl ContainerHandle for FuncID { type ChildrenHandle = DataflowOpID; } -impl ContainerHandle for AliasID { - type ChildrenHandle = DataflowOpID; -} diff --git a/hugr-core/src/ops/module.rs b/hugr-core/src/ops/module.rs index eda121f23..f3745506f 100644 --- a/hugr-core/src/ops/module.rs +++ b/hugr-core/src/ops/module.rs @@ -2,10 +2,12 @@ use std::borrow::Cow; +use serde_with::serde_as; use smol_str::SmolStr; #[cfg(test)] use { - crate::proptest::{any_nonempty_smolstr, any_nonempty_string}, + crate::proptest::{RecursionDepth, any_nonempty_smolstr, any_nonempty_string}, + crate::types::proptest_utils::any_type, ::proptest_derive::Arbitrary, }; @@ -231,6 +233,7 @@ impl OpTrait for FuncDecl { } /// A type alias definition, used only for debug/metadata. +#[serde_as] #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(Arbitrary))] pub struct AliasDefn { @@ -238,8 +241,11 @@ pub struct AliasDefn { #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))] pub name: SmolStr, /// Aliased type + #[serde_as(as = "crate::types::serialize::SerType")] + #[cfg_attr(test, proptest(strategy = "any_type(RecursionDepth::default())"))] pub definition: Type, } + impl_op_name!(AliasDefn); impl StaticTag for AliasDefn { const TAG: OpTag = OpTag::Alias; diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 11c16b14a..268854f5c 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -10,7 +10,7 @@ use crate::extension::simple_op::{ use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs}; use crate::ops::OpName; use crate::ops::custom::ExtensionOp; -use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; +use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRow}; use crate::utils::collect_array; use crate::{ @@ -136,16 +136,16 @@ impl MakeOpDef for IntOpDef { } ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(), idivmod_checked_u | idivmod_checked_s => { - let intpair: TypeRowRV = vec![tv0; 2].into(); + let intpair = vec![tv0; 2]; int_polytype( 1, intpair.clone(), - [sum_ty_with_err(Type::new_tuple(intpair))], + [sum_ty_with_err(Type::new_runtime_tuple(intpair))], ) } .into(), idivmod_u | idivmod_s => { - let intpair: TypeRowRV = vec![tv0; 2].into(); + let intpair = vec![tv0; 2]; int_polytype(1, intpair.clone(), intpair.clone()) } .into(), @@ -227,15 +227,15 @@ impl MakeOpDef for IntOpDef { } } -/// Returns a polytype composed by a function type, and a number of integer width type parameters. +/// Returns a polytype composed by a fixed-arity function type, and a number of integer width type parameters. pub(in crate::std_extensions::arithmetic) fn int_polytype( n_vars: usize, - input: impl Into, - output: impl Into, + input: impl Into, + output: impl Into, ) -> PolyFuncTypeRV { PolyFuncTypeRV::new( vec![LOG_WIDTH_TYPE_PARAM; n_vars], - FuncValueType::new(input, output), + FuncValueType::new(input.into().into_owned(), output.into().into_owned()), ) } diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs index 2df2ceb36..cc0ed9729 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -586,7 +586,7 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { } else { let q_type = INT_TYPES[logwidth0 as usize].clone(); let r_type = q_type.clone(); - let qr_type: Type = Type::new_tuple(vec![q_type, r_type]); + let qr_type: Type = Type::new_runtime_tuple(vec![q_type, r_type]); let err_value = || { ConstError { signal: 0, @@ -647,7 +647,7 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { } else { let q_type = INT_TYPES[logwidth0 as usize].clone(); let r_type = INT_TYPES[logwidth0 as usize].clone(); - let qr_type: Type = Type::new_tuple(vec![q_type, r_type]); + let qr_type: Type = Type::new_runtime_tuple(vec![q_type, r_type]); let err_value = || { ConstError { signal: 0, diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index a731f0e67..85ef5df89 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -357,7 +357,7 @@ pub trait ArrayOpBuilder: GenericArrayOpBuilder { index1: Wire, index2: Wire, ) -> Result { - let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty])?; let [out] = self .add_dataflow_op(op, vec![input, index1, index2])? .outputs_arr(); diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs index 566ee12c7..fbc787f29 100644 --- a/hugr-core/src/std_extensions/collections/array/array_clone.rs +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -157,7 +157,7 @@ impl MakeExtensionOp for GenericArrayClone { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![self.size.into(), self.elem_ty.clone()] } } @@ -180,7 +180,7 @@ impl HasConcrete for GenericArrayCloneDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { + [TypeArg::BoundedNat(n), ty] if ty.copyable() => { Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs index 61b013a06..fbc1a4c49 100644 --- a/hugr-core/src/std_extensions/collections/array/array_conversion.rs +++ b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -10,7 +10,7 @@ use crate::extension::simple_op::{ }; use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; -use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_type}; use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; use super::array_kind::ArrayKind; @@ -202,7 +202,7 @@ impl MakeExtensionOp } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone().into()] + vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone()] } } @@ -231,7 +231,8 @@ impl HasConcrete fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { + [TypeArg::BoundedNat(n), ty] => { + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; Ok(GenericArrayConvert::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs index 17e2be157..44aabd382 100644 --- a/hugr-core/src/std_extensions/collections/array/array_discard.rs +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -141,7 +141,7 @@ impl MakeExtensionOp for GenericArrayDiscard { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![self.size.into(), self.elem_ty.clone()] } } @@ -164,7 +164,7 @@ impl HasConcrete for GenericArrayDiscardDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { + [TypeArg::BoundedNat(n), ty] if ty.copyable() => { Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 26ebb5b5f..b426d1961 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -15,6 +15,7 @@ use crate::extension::{ }; use crate::ops::{ExtensionOp, OpName}; use crate::type_row; +use crate::types::type_param::check_term_type; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Term, Type, TypeBound}; use crate::utils::Never; @@ -290,7 +291,7 @@ impl MakeExtensionOp for GenericArrayOp { use GenericArrayOpDef::{ _phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap, unpack, }; - let ty_arg = self.elem_ty.clone().into(); + let ty_arg = self.elem_ty.clone(); match self.def { discard_empty => { debug_assert_eq!( @@ -326,10 +327,11 @@ impl HasConcrete for GenericArrayOpDef { fn instantiate(&self, type_args: &[Term]) -> Result { let (ty, size) = match (self, type_args) { - (GenericArrayOpDef::discard_empty, [Term::Runtime(ty)]) => (ty.clone(), 0), - (_, [Term::BoundedNat(n), Term::Runtime(ty)]) => (ty.clone(), *n), + (GenericArrayOpDef::discard_empty, [ty]) => (ty.clone(), 0), + (_, [Term::BoundedNat(n), ty]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; + check_term_type(&ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; Ok(self.to_concrete(ty.clone(), size)) } diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 3fb121980..b9e735b2b 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -10,7 +10,7 @@ use crate::extension::simple_op::{ }; use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, OpName}; -use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_type}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; use super::array_kind::ArrayKind; @@ -147,7 +147,7 @@ impl MakeExtensionOp for GenericArrayRepeat { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![self.size.into(), self.elem_ty.clone()] } } @@ -170,7 +170,8 @@ impl HasConcrete for GenericArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { + [TypeArg::BoundedNat(n), ty] => { + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; Ok(GenericArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 5bd62466c..5e4a561b1 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -12,8 +12,8 @@ use crate::extension::simple_op::{ }; use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, OpName}; -use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_type}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound, TypeRV}; use super::array_kind::ArrayKind; @@ -62,29 +62,26 @@ impl GenericArrayScanDef { TypeParam::new_list_type(TypeBound::Linear), ]; let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let t1 = Type::new_var_use(1, TypeBound::Linear); - let t2 = Type::new_var_use(2, TypeBound::Linear); - let s = TypeRV::new_row_var_use(3, TypeBound::Linear); + let src_elem = Type::new_var_use(1, TypeBound::Linear); + let tgt_elem = Type::new_var_use(2, TypeBound::Linear); + let with_rest = |tys: Vec| { + TypeArg::concat_lists([tys.into(), TypeRV::new_row_var_use(3, TypeBound::Linear)]) + }; PolyFuncTypeRV::new( params, - FuncTypeBase::::new( - vec![ - AK::instantiate_ty(array_def, n.clone(), t1.clone()) - .expect("Array type instantiation failed") - .into(), - Type::new_function(FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - )) - .into(), - s.clone(), - ], - vec![ - AK::instantiate_ty(array_def, n, t2) - .expect("Array type instantiation failed") - .into(), - s, - ], + FuncValueType::new( + with_rest(vec![ + AK::instantiate_ty(array_def, n.clone(), src_elem.clone()) + .expect("Array type instantiation failed"), + Type::new_function(FuncValueType::new( + with_rest(vec![src_elem]), + with_rest(vec![tgt_elem.clone()]), + )), + ]), + with_rest(vec![ + AK::instantiate_ty(array_def, n, tgt_elem) + .expect("Array type instantiation failed"), + ]), ), ) .into() @@ -186,8 +183,8 @@ impl MakeExtensionOp for GenericArrayScan { fn type_args(&self) -> Vec { vec![ self.size.into(), - self.src_ty.clone().into(), - self.tgt_ty.clone().into(), + self.src_ty.clone(), + self.tgt_ty.clone(), TypeArg::new_list(self.acc_tys.clone().into_iter().map_into()), ] } @@ -214,21 +211,17 @@ impl HasConcrete for GenericArrayScanDef { match type_args { [ TypeArg::BoundedNat(n), - TypeArg::Runtime(src_ty), - TypeArg::Runtime(tgt_ty), + src_elem_ty, + tgt_elem_ty, TypeArg::List(acc_tys), ] => { - let acc_tys: Result<_, OpLoadError> = acc_tys - .iter() - .map(|acc_ty| match acc_ty { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs.into()), - }) - .collect(); + for ty in [src_elem_ty, tgt_elem_ty].into_iter().chain(acc_tys.iter()) { + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; + } Ok(GenericArrayScan::new( - src_ty.clone(), - tgt_ty.clone(), - acc_tys?, + src_elem_ty.clone(), + tgt_elem_ty.clone(), + acc_tys.clone(), *n, )) } diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs index 33828d9e0..7ab1b7e2a 100644 --- a/hugr-core/src/std_extensions/collections/array/array_value.rs +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -4,13 +4,13 @@ use std::hash::{Hash, Hasher}; use std::marker::PhantomData; use crate::extension::resolution::{ - ExtensionResolutionError, WeakExtensionRegistry, resolve_type_extensions, + ExtensionResolutionError, WeakExtensionRegistry, resolve_term_extensions, resolve_value_extensions, }; use crate::ops::Value; use crate::ops::constant::{TryHash, ValueName, maybe_hash_values}; -use crate::types::type_param::TypeArg; -use crate::types::{CustomCheckFailure, CustomType, Type}; +use crate::types::type_param::{TypeArg, check_term_type}; +use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound}; use super::array_kind::ArrayKind; @@ -94,7 +94,10 @@ impl GenericArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if *n as usize == self.values.len() => { + [TypeArg::BoundedNat(n), ty] + if *n as usize == self.values.len() + && check_term_type(ty, &TypeBound::Linear.into()).is_ok() => + { ty } _ => { @@ -125,7 +128,7 @@ impl GenericArrayValue { for val in &mut self.values { resolve_value_extensions(val, extensions)?; } - resolve_type_extensions(&mut self.typ, extensions) + resolve_term_extensions(&mut self.typ, extensions) } } diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 2a96f563b..53fe62951 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -72,7 +72,7 @@ pub trait GenericArrayOpBuilder: Dataflow { size: u64, input: Wire, ) -> Result, BuildError> { - let op = GenericArrayOpDef::::unpack.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::unpack.instantiate(&[size.into(), elem_ty])?; Ok(self.add_dataflow_op(op, vec![input])?.outputs().collect()) } /// Adds an array clone operation to the dataflow graph and return the wires @@ -148,7 +148,7 @@ pub trait GenericArrayOpBuilder: Dataflow { input: Wire, index: Wire, ) -> Result<(Wire, Wire), BuildError> { - let op = GenericArrayOpDef::::get.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::get.instantiate(&[size.into(), elem_ty])?; let [out, arr] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr(); Ok((out, arr)) } @@ -180,7 +180,7 @@ pub trait GenericArrayOpBuilder: Dataflow { index: Wire, value: Wire, ) -> Result { - let op = GenericArrayOpDef::::set.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::set.instantiate(&[size.into(), elem_ty])?; let [out] = self .add_dataflow_op(op, vec![input, index, value])? .outputs_arr(); @@ -214,7 +214,7 @@ pub trait GenericArrayOpBuilder: Dataflow { index1: Wire, index2: Wire, ) -> Result { - let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty])?; let [out] = self .add_dataflow_op(op, vec![input, index1, index2])? .outputs_arr(); @@ -244,7 +244,7 @@ pub trait GenericArrayOpBuilder: Dataflow { size: u64, input: Wire, ) -> Result { - let op = GenericArrayOpDef::::pop_left.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::pop_left.instantiate(&[size.into(), elem_ty])?; Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) } @@ -271,7 +271,7 @@ pub trait GenericArrayOpBuilder: Dataflow { size: u64, input: Wire, ) -> Result { - let op = GenericArrayOpDef::::pop_right.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::pop_right.instantiate(&[size.into(), elem_ty])?; Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) } @@ -292,7 +292,7 @@ pub trait GenericArrayOpBuilder: Dataflow { ) -> Result<(), BuildError> { self.add_dataflow_op( GenericArrayOpDef::::discard_empty - .instantiate(&[elem_ty.into()]) + .instantiate(&[elem_ty]) .unwrap(), [input], )?; diff --git a/hugr-core/src/std_extensions/collections/borrow_array.rs b/hugr-core/src/std_extensions/collections/borrow_array.rs index 534a28655..78c9ef534 100644 --- a/hugr-core/src/std_extensions/collections/borrow_array.rs +++ b/hugr-core/src/std_extensions/collections/borrow_array.rs @@ -8,7 +8,7 @@ use delegate::delegate; use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; use crate::ops::constant::{CustomConst, ValueName}; use crate::type_row; -use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_type}; use crate::types::{CustomCheckFailure, Term, Type, TypeBound, TypeName}; use crate::{Extension, Wire}; use crate::{ @@ -142,7 +142,7 @@ impl BArrayUnsafeOpDef { let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let elem_ty_var = Type::new_var_use(1, TypeBound::Linear); let array_ty: Type = def - .instantiate(vec![size_var, elem_ty_var.clone().into()]) + .instantiate(vec![size_var, elem_ty_var.clone()]) .unwrap() .into(); @@ -267,7 +267,7 @@ impl MakeExtensionOp for BArrayUnsafeOp { } fn type_args(&self) -> Vec { - vec![self.size.into(), self.elem_ty.clone().into()] + vec![self.size.into(), self.elem_ty.clone()] } } @@ -279,10 +279,11 @@ impl HasConcrete for BArrayUnsafeOpDef { type Concrete = BArrayUnsafeOp; fn instantiate(&self, type_args: &[TypeArg]) -> Result { - match type_args { - [Term::BoundedNat(n), Term::Runtime(ty)] => Ok(self.to_concrete(ty.clone(), *n)), - _ => Err(SignatureError::InvalidTypeArgs.into()), - } + let [Term::BoundedNat(n), ty] = type_args else { + return Err(SignatureError::InvalidTypeArgs.into()); + }; + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; + Ok(self.to_concrete(ty.clone(), *n)) } } @@ -557,8 +558,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { index1: Wire, index2: Wire, ) -> Result { - let op = - GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty])?; let [out] = self .add_dataflow_op(op, vec![input, index1, index2])? .outputs_arr(); @@ -660,7 +660,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { input: Wire, index: Wire, ) -> Result<(Wire, Wire), BuildError> { - let op = BArrayUnsafeOpDef::borrow.instantiate(&[size.into(), elem_ty.into()])?; + let op = BArrayUnsafeOpDef::borrow.instantiate(&[size.into(), elem_ty])?; let [arr, out] = self .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])? .outputs_arr(); @@ -688,7 +688,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { index: Wire, value: Wire, ) -> Result { - let op = BArrayUnsafeOpDef::r#return.instantiate(&[size.into(), elem_ty.into()])?; + let op = BArrayUnsafeOpDef::r#return.instantiate(&[size.into(), elem_ty])?; let [arr] = self .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index, value])? .outputs_arr(); @@ -712,8 +712,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { size: u64, input: Wire, ) -> Result<(), BuildError> { - let op = - BArrayUnsafeOpDef::discard_all_borrowed.instantiate(&[size.into(), elem_ty.into()])?; + let op = BArrayUnsafeOpDef::discard_all_borrowed.instantiate(&[size.into(), elem_ty])?; self.add_dataflow_op(op.to_extension_op().unwrap(), vec![input])?; Ok(()) } @@ -729,7 +728,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { /// /// Returns an error if building the operation fails. fn add_new_all_borrowed(&mut self, elem_ty: Type, size: u64) -> Result { - let op = BArrayUnsafeOpDef::new_all_borrowed.instantiate(&[size.into(), elem_ty.into()])?; + let op = BArrayUnsafeOpDef::new_all_borrowed.instantiate(&[size.into(), elem_ty])?; let [arr] = self .add_dataflow_op(op.to_extension_op().unwrap(), vec![])? .outputs_arr(); @@ -761,7 +760,7 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { input: Wire, index: Wire, ) -> Result<(Wire, Wire), BuildError> { - let op = BArrayUnsafeOpDef::is_borrowed.instantiate(&[size.into(), elem_ty.into()])?; + let op = BArrayUnsafeOpDef::is_borrowed.instantiate(&[size.into(), elem_ty])?; let [arr, is_borrowed] = self .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])? .outputs_arr(); diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 495ab0e00..1d4855356 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -13,13 +13,14 @@ use strum::{EnumIter, EnumString, IntoStaticStr}; use crate::extension::prelude::{either_type, option_type, usize_t}; use crate::extension::resolution::{ - ExtensionResolutionError, WeakExtensionRegistry, resolve_type_extensions, + ExtensionResolutionError, WeakExtensionRegistry, resolve_term_extensions, resolve_value_extensions, }; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc}; use crate::ops::constant::{TryHash, ValueName, maybe_hash_values}; use crate::ops::{OpName, Value}; +use crate::types::type_param::{TermTypeError, check_term_type}; use crate::types::{Term, TypeName, TypeRowRV}; use crate::{ Extension, @@ -111,9 +112,12 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Runtime(ty)] = typ.args() else { + let [ty] = typ.args() else { return Err(error()); }; + if !ty.copyable() { + return Err(error()); + } // check all values are instances of the element type for v in &self.0 { @@ -136,7 +140,7 @@ impl CustomConst for ListValue { for val in &mut self.0 { resolve_value_extensions(val, extensions)?; } - resolve_type_extensions(&mut self.1, extensions) + resolve_term_extensions(&mut self.1, extensions) } } @@ -214,7 +218,10 @@ impl ListOp { input: impl Into, output: impl Into, ) -> PolyFuncTypeRV { - PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) + PolyFuncTypeRV::new( + vec![Self::TP], + FuncValueType::new(input.into(), output.into()), + ) } /// Returns the type of a generic list, associated with the element type parameter at index `idx`. @@ -323,7 +330,7 @@ pub fn list_type_def() -> &'static TypeDef { /// Get the type of a list of `elem_type` as a `CustomType`. #[must_use] pub fn list_custom_type(elem_type: Type) -> CustomType { - list_type_def().instantiate(vec![elem_type.into()]).unwrap() + list_type_def().instantiate(vec![elem_type]).unwrap() } /// Get the `Type` of a list of `elem_type`. @@ -349,9 +356,16 @@ impl MakeExtensionOp for ListOpInst { fn from_extension_op( ext_op: &ExtensionOp, ) -> Result { - let [Term::Runtime(ty)] = ext_op.args() else { - return Err(SignatureError::InvalidTypeArgs.into()); + let [ty] = ext_op.args() else { + return Err( + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs( + ext_op.args().len(), + 1, + )) + .into(), + ); }; + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; let name = ext_op.unqualified_id(); let Ok(op) = ListOp::from_str(name) else { return Err(OpLoadError::NotMember(name.to_string())); @@ -364,7 +378,7 @@ impl MakeExtensionOp for ListOpInst { } fn type_args(&self) -> Vec { - vec![self.elem_type.clone().into()] + vec![self.elem_type.clone()] } } @@ -405,7 +419,7 @@ mod test { fn test_list() { let list_def = list_type_def(); - let list_type = list_def.instantiate([usize_t().into()]).unwrap(); + let list_type = list_def.instantiate([usize_t()]).unwrap(); assert!(list_def.instantiate([3u64.into()]).is_err()); diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 007be5ecc..d46629e48 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -81,7 +81,7 @@ impl StaticArrayValue { typ: Type, contents: impl IntoIterator, ) -> Result { - if !TypeBound::Copyable.contains(typ.least_upper_bound()) { + if !typ.copyable() { return Err(CustomCheckFailure::Message(format!( "Failed to construct a StaticArrayValue with non-Copyable type: {typ}" )) @@ -295,7 +295,7 @@ impl MakeExtensionOp for StaticArrayOp { } fn type_args(&self) -> Vec { - vec![self.elem_ty.clone().into()] + vec![self.elem_ty.clone()] } } @@ -309,17 +309,17 @@ impl HasConcrete for StaticArrayOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { use TypeBound::Copyable; match type_args { - [arg] => { - let elem_ty = arg - .as_runtime() - .filter(|t| Copyable.contains(t.least_upper_bound())) - .ok_or(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { + [elem_ty] => { + let elem_ty = elem_ty.clone(); + if !elem_ty.copyable() { + return Err( + SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { type_: Box::new(Copyable.into()), - term: Box::new(arg.clone()), - }, - ))?; - + term: Box::new(elem_ty), + }) + .into(), + ); + } Ok(StaticArrayOp { def: *self, elem_ty, diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 7816e9b03..cf9285b0d 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -8,6 +8,7 @@ use crate::Wire; use crate::builder::{BuildError, Dataflow}; use crate::extension::TypeDefBound; use crate::ops::OpName; +use crate::types::type_param::{TermTypeError, check_term_type}; use crate::types::{CustomType, PolyFuncType, Signature, Type, TypeBound, TypeName}; use crate::{ Extension, @@ -55,9 +56,8 @@ impl MakeOpDef for PtrOpDef { } fn init_signature(&self, extension_ref: &Weak) -> SignatureFunc { - let ptr_t: Type = - ptr_custom_type(Type::new_var_use(0, TypeBound::Copyable), extension_ref).into(); let inner_t = Type::new_var_use(0, TypeBound::Copyable); + let ptr_t: Type = ptr_custom_type(inner_t.clone(), extension_ref).into(); let body = match self { PtrOpDef::New => Signature::new([inner_t], [ptr_t]), PtrOpDef::Read => Signature::new([ptr_t], [inner_t]), @@ -118,7 +118,7 @@ fn ptr_custom_type(ty: impl Into, extension_ref: &Weak) -> Cust let ty = ty.into(); CustomType::new( PTR_TYPE_ID, - [ty.into()], + [ty], EXTENSION_ID, TypeBound::Copyable, extension_ref, @@ -156,7 +156,7 @@ impl MakeExtensionOp for PtrOp { } fn type_args(&self) -> Vec { - vec![self.ty.clone().into()] + vec![self.ty.clone()] } } @@ -203,12 +203,15 @@ impl HasConcrete for PtrOpDef { type Concrete = PtrOp; fn instantiate(&self, type_args: &[TypeArg]) -> Result { - let ty = match type_args { - [TypeArg::Runtime(ty)] => ty.clone(), - _ => return Err(SignatureError::InvalidTypeArgs.into()), + let [ty] = type_args else { + return Err( + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1)) + .into(), + ); }; + check_term_type(ty, &TypeBound::Linear.into()).map_err(SignatureError::from)?; - Ok(self.with_type(ty)) + Ok(self.with_type(ty.clone())) } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 1f67edf59..18a1a69e7 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -3,41 +3,34 @@ mod check; pub mod custom; mod poly_func; -mod row_var; pub(crate) mod serialize; mod signature; pub mod type_param; pub mod type_row; -pub(crate) use row_var::MaybeRV; -pub use row_var::{NoRV, RowVariable}; -use crate::extension::resolution::{ - ExtensionCollectionError, WeakExtensionRegistry, collect_type_exts, -}; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; use crate::types::type_param::check_term_type; use crate::utils::display_list_with_separator; +use crate::{ + extension::resolution::{ExtensionCollectionError, WeakExtensionRegistry, collect_term_exts}, + types::type_param::TermTypeError, +}; pub use check::SumTypeError; pub use custom::CustomType; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; -pub use signature::{FuncTypeBase, FuncValueType, Signature}; +pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::{Term, TypeArg}; pub use type_row::{TypeRow, TypeRowRV}; -pub(crate) use poly_func::PolyFuncTypeBase; - -use itertools::FoldWhile::{Continue, Done}; use itertools::{Either, Itertools as _}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; -use crate::ops::AliasDecl; use self::type_param::TypeParam; -use self::type_row::TypeRowBase; /// A unique identifier for a type. pub type TypeName = SmolStr; @@ -159,16 +152,13 @@ impl TypeBound { } } -/// Calculate the least upper bound for an iterator of bounds -pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> TypeBound { - tags.fold_while(TypeBound::Copyable, |acc, new| { - if acc == TypeBound::Linear || new == TypeBound::Linear { - Done(TypeBound::Linear) - } else { - Continue(acc.union(new)) +pub(crate) fn least_upper_bound(bounds: impl IntoIterator) -> TypeBound { + for b in bounds { + if b == TypeBound::Linear { + return TypeBound::Linear; } - }) - .into_inner() + } + TypeBound::Copyable } #[derive(Clone, Debug, Eq, Serialize, Deserialize)] @@ -183,7 +173,81 @@ pub enum SumType { Unit { size: u8 }, /// General case of a Sum type. #[allow(missing_docs)] - General { rows: Vec }, + General(GeneralSum), +} + +/// General case of a [SumType]. Prefer using [SumType::new] and friends. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde( + into = "crate::types::serialize::SerGenSum", + from = "crate::types::serialize::SerGenSum" +)] +pub struct GeneralSum { + /// Each term here must be an instance of [Term::ListType]([Term::RuntimeType]), being + /// the elements of exactly one variant. (Thus, this explicitly forbids sums with an + /// unknown number of variants.) + // We could just have a single `rows: Term` here, an instance of + //`Term::ListType(Term::ListType(Term::RuntimeType))`, but then many functions like + // `len` and `variants` would be impossible. (We might want a separate "FixedAritySum" + // rust type supporting those, with try_from(SumType).) + rows: TypeRow, + /// Caches the bound. Falls back to [TypeBound::Linear] if any are not even runtime types + /// (this is checked in validation) + bound: TypeBound, +} + +fn sum_bound<'a>(rows: impl IntoIterator) -> TypeBound { + least_upper_bound(rows.into_iter().map(|t| { + if check_term_type(t, &Term::new_list_type(TypeBound::Copyable)).is_ok() { + TypeBound::Copyable + } else { + TypeBound::Linear + } + })) +} + +impl GeneralSum { + /// Initialize a new general sum type. (Note the number of variants is fixed.) + /// + /// # Panics + /// + /// If any element of `rows` is not a list (perhaps of variable length) of runtime types. + /// See [Self::try_new] or [Self::new_unchecked] for alternatives. + pub fn new(rows: impl Into) -> Self { + Self::try_new(rows).unwrap() + } + + /// Initialize a new general sum type, checking that each variant is a list of runtime types. + /// + /// # Errors + /// + /// If any element of `rows` is not a list (perhaps of variable length) of runtime types. + pub fn try_new(rows: impl Into) -> Result { + let rows = rows.into(); + for row in rows.iter() { + check_term_type(row, &Term::new_list_type(TypeBound::Linear))?; + } + Ok(Self::new_unchecked(rows)) + } + + /// Initialize a new general sum type without checking the variants. + pub fn new_unchecked(rows: impl Into) -> Self { + let rows: TypeRow = rows.into(); + let bound = sum_bound(rows.iter()); + Self { rows, bound } + } + + /// Returns an iterator over the variants, each an instance of [Term::ListType]`(`[Term::RuntimeType]`)` + pub fn iter(&self) -> impl Iterator { + self.rows.iter() + } + + /// Returns a mutable iterator over the variants, each should be an instance + /// of [Term::ListType]`(`[Term::RuntimeType]`)` but of course `iter_mut` allows + /// bypassing such checks. + pub fn iter_mut(&mut self) -> impl Iterator { + self.rows.iter_mut() + } } impl std::hash::Hash for SumType { @@ -210,9 +274,9 @@ impl std::fmt::Display for SumType { SumType::Unit { size } => { display_list_with_separator(itertools::repeat_n("[]", *size as usize), f, "+") } - SumType::General { rows } => match rows.len() { - 1 if rows[0].is_empty() => write!(f, "Unit"), - 2 if rows[0].is_empty() && rows[1].is_empty() => write!(f, "Bool"), + SumType::General(GeneralSum { rows, .. }) => match rows.len() { + 1 if rows[0].is_empty_list() => write!(f, "Unit"), + 2 if rows[0].is_empty_list() && rows[1].is_empty_list() => write!(f, "Bool"), _ => display_list_with_separator(rows.iter(), f, "+"), }, } @@ -221,17 +285,44 @@ impl std::fmt::Display for SumType { impl SumType { /// Initialize a new sum type. + /// + /// # Panics + /// + /// If any element of `variants` is not a list (perhaps of variable length) of runtime types. + /// See [Self::try_new] or [Self::new_unchecked] for alternatives. pub fn new(variants: impl IntoIterator) -> Self where - V: Into, + V: Into, { - let rows = variants.into_iter().map(Into::into).collect_vec(); + Self::try_new(variants).unwrap() + } + + /// Initialize a new sum type, checking that each variant is a list of runtime types. + /// + /// # Errors + /// + /// If any element of `variants` is not a list (perhaps of variable length) of runtime types. + /// See [Self::new_unchecked] for an alternative. + pub fn try_new>( + variants: impl IntoIterator, + ) -> Result { + let variants = variants.into_iter().map(V::into).collect_vec(); + let len = variants.len(); + if u8::try_from(len).is_ok() && variants.iter().all(Term::is_empty_list) { + Ok(Self::new_unary(len as u8)) + } else { + GeneralSum::try_new(variants).map(Self::General) + } + } - let len: usize = rows.len(); - if u8::try_from(len).is_ok() && rows.iter().all(TypeRowRV::is_empty) { + /// Initialize a new sum type without checking the variants. + pub fn new_unchecked(variants: impl Into) -> Self { + let variants = variants.into(); + let len: usize = variants.len(); + if u8::try_from(len).is_ok() && variants.iter().all(Term::is_empty_list) { Self::new_unary(len as u8) } else { - Self::General { rows } + Self::General(GeneralSum::new_unchecked(variants)) } } @@ -253,10 +344,10 @@ impl SumType { /// Report the tag'th variant, if it exists. #[must_use] - pub fn get_variant(&self, tag: usize) -> Option<&TypeRowRV> { + pub fn get_variant(&self, tag: usize) -> Option<&Term> { match self { - SumType::Unit { size } if tag < (*size as usize) => Some(TypeRV::EMPTY_TYPEROW_REF), - SumType::General { rows } => rows.get(tag), + SumType::Unit { size } if tag < (*size as usize) => Some(Type::EMPTY_TYPE_LIST_REF), + SumType::General(GeneralSum { rows, .. }) => rows.get(tag), _ => None, } } @@ -266,192 +357,110 @@ impl SumType { pub fn num_variants(&self) -> usize { match self { SumType::Unit { size } => *size as usize, - SumType::General { rows } => rows.len(), + SumType::General(GeneralSum { rows, .. }) => rows.len(), } } - /// Returns variant row if there is only one variant. + /// Returns variant row if there is only one variant + /// (will be an instance of [Term::ListType]([Term::RuntimeType]). #[must_use] - pub fn as_tuple(&self) -> Option<&TypeRowRV> { + pub fn as_tuple(&self) -> Option<&Term> { match self { - SumType::Unit { size } if *size == 1 => Some(TypeRV::EMPTY_TYPEROW_REF), - SumType::General { rows } if rows.len() == 1 => Some(&rows[0]), + SumType::Unit { size } if *size == 1 => Some(Term::EMPTY_TYPE_LIST_REF), + SumType::General(GeneralSum { rows, .. }) if rows.len() == 1 => Some(&rows[0]), _ => None, } } - /// If the sum matches the convention of `Option[row]`, return the row. + /// If the sum matches the convention of `Option[row]`, return the row + /// (an instance of [Term::ListType]([Term::RuntimeType]). #[must_use] - pub fn as_option(&self) -> Option<&TypeRowRV> { + pub fn as_option(&self) -> Option<&Term> { match self { - SumType::Unit { size } if *size == 2 => Some(TypeRV::EMPTY_TYPEROW_REF), - SumType::General { rows } if rows.len() == 2 && rows[0].is_empty() => Some(&rows[1]), + SumType::Unit { size } if *size == 2 => Some(Term::EMPTY_TYPE_LIST_REF), + SumType::General(GeneralSum { rows, .. }) + if rows.len() == 2 && rows[0].is_empty_list() => + { + Some(&rows[1]) + } _ => None, } } - /// If a sum is an option of a single type, return the type. - #[must_use] - pub fn as_unary_option(&self) -> Option<&TypeRV> { - self.as_option() - .and_then(|row| row.iter().exactly_one().ok()) - } + // ALAN removing as_unary_option. + // "If a sum is an option of a single type, return the type. pub fn as_unary_option(&self) -> Option<&TypeRV>" + // But of course a TypeRV was not necessarily a single type... - /// Returns an iterator over the variants. - pub fn variants(&self) -> impl Iterator { + /// Returns an iterator over the variants, each an instance of [Term::ListType]`(`[Term::RuntimeType]`)` + pub fn variants(&self) -> impl Iterator { match self { SumType::Unit { size } => Either::Left(itertools::repeat_n( - TypeRV::EMPTY_TYPEROW_REF, + Term::EMPTY_TYPE_LIST_REF, *size as usize, )), - SumType::General { rows } => Either::Right(rows.iter()), + SumType::General(gs) => Either::Right(gs.iter()), } } -} -impl Transformable for SumType { - fn transform(&mut self, tr: &T) -> Result { + /// Returns the bound of this sum type. + /// + /// (Cached; will be [TypeBound::Linear] if any variant is not a list of runtime types.) + pub const fn bound(&self) -> TypeBound { match self { - SumType::Unit { .. } => Ok(false), - SumType::General { rows } => rows.transform(tr), - } - } -} - -impl From for TypeBase { - fn from(sum: SumType) -> Self { - match sum { - SumType::Unit { size } => TypeBase::new_unit_sum(size), - SumType::General { rows } => TypeBase::new_sum(rows), + SumType::Unit { .. } => TypeBound::Copyable, + SumType::General(GeneralSum { bound, .. }) => *bound, } } } -#[derive(Clone, Debug, Eq, Hash, derive_more::Display)] -/// Core types -pub enum TypeEnum { - /// An extension type. - // - // TODO optimise with `Box`? - // or some static version of this? - Extension(CustomType), - /// An alias of a type. - #[display("Alias({})", _0.name())] - Alias(AliasDecl), - /// A function type. - #[display("{_0}")] - Function(Box), - /// A type variable, defined by an index into a list of type parameters. - // - // We cache the TypeBound here (checked in validation) - #[display("#{_0}")] - Variable(usize, TypeBound), - /// `RowVariable`. Of course, this requires that `RV` has instances, [`NoRV`] doesn't. - #[display("RowVar({_0})")] - RowVar(RV), - /// Sum of types. - #[display("{_0}")] - Sum(SumType), -} - -impl TypeEnum { - /// The smallest type bound that covers the whole type. - fn least_upper_bound(&self) -> TypeBound { +impl Transformable for SumType { + fn transform(&mut self, tr: &T) -> Result { match self { - TypeEnum::Extension(c) => c.bound(), - TypeEnum::Alias(a) => a.bound, - TypeEnum::Function(_) => TypeBound::Copyable, - TypeEnum::Variable(_, b) => *b, - TypeEnum::RowVar(b) => b.bound(), - TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Copyable, - TypeEnum::Sum(SumType::General { rows }) => least_upper_bound( - rows.iter() - .flat_map(TypeRowRV::iter) - .map(TypeRV::least_upper_bound), - ), + SumType::Unit { .. } => Ok(false), + SumType::General(GeneralSum { rows, bound }) => { + let ch = rows.transform(tr)?; + if ch { + *bound = sum_bound(rows.iter()) + } + Ok(ch) + } } } } -#[derive(Clone, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize)] -#[display("{_0}")] -#[serde( - into = "serialize::SerSimpleType", - try_from = "serialize::SerSimpleType" -)] -/// A HUGR type - the valid types of [`EdgeKind::Value`] and [`EdgeKind::Const`] edges. -/// -/// Such an edge is valid if the ports on either end agree on the [Type]. -/// Types have an optional [`TypeBound`] which places limits on the valid -/// operations on a type. -/// -/// Examples: -/// ``` -/// # use hugr::types::{Type, TypeBound}; -/// # use hugr::type_row; -/// -/// let sum = Type::new_sum([type_row![], type_row![]]); -/// assert_eq!(sum.least_upper_bound(), TypeBound::Copyable); -/// ``` -/// -/// ``` -/// # use hugr::types::{Type, TypeBound, Signature}; -/// -/// let func_type: Type = Type::new_function(Signature::new_endo([])); -/// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable); -/// ``` -pub struct TypeBase(TypeEnum, TypeBound); - -/// The type of a single value, that can be sent down a wire -pub type Type = TypeBase; - -/// One or more types - either a single type, or a row variable -/// standing for multiple types. -pub type TypeRV = TypeBase; - -impl PartialEq> for TypeEnum { - fn eq(&self, other: &TypeEnum) -> bool { - match (self, other) { - (TypeEnum::Extension(e1), TypeEnum::Extension(e2)) => e1 == e2, - (TypeEnum::Alias(a1), TypeEnum::Alias(a2)) => a1 == a2, - (TypeEnum::Function(f1), TypeEnum::Function(f2)) => f1 == f2, - (TypeEnum::Variable(i1, b1), TypeEnum::Variable(i2, b2)) => i1 == i2 && b1 == b2, - (TypeEnum::RowVar(v1), TypeEnum::RowVar(v2)) => v1.as_rv() == v2.as_rv(), - (TypeEnum::Sum(s1), TypeEnum::Sum(s2)) => s1 == s2, - _ => false, - } +impl From for Type { + fn from(sum: SumType) -> Self { + Type::RuntimeSum(sum) } } -impl PartialEq> for TypeBase { - fn eq(&self, other: &TypeBase) -> bool { - self.0 == other.0 && self.1 == other.1 - } -} +/// Legacy alias for Term. Will become deprecated at some point. +pub type Type = Term; +/// Legacy alias for Term. Will become deprecated at some point. +pub type TypeRV = Term; -impl TypeBase { +impl Type { /// An empty `TypeRow` or `TypeRowRV`. Provided here for convenience - pub const EMPTY_TYPEROW: TypeRowBase = TypeRowBase::::new(); - /// Unit type (empty tuple). - pub const UNIT: Self = Self( - TypeEnum::Sum(SumType::Unit { size: 1 }), - TypeBound::Copyable, - ); + pub const EMPTY_TYPEROW: TypeRow = TypeRow::new(); + /// Runtime unit type (empty tuple). + pub const UNIT: Self = Self::RuntimeSum(SumType::Unit { size: 1 }); + + const EMPTY_TYPE_LIST: Term = Term::List(vec![]); // or (EMPTY_TYPEROW)....? ALAN - const EMPTY_TYPEROW_REF: &'static TypeRowBase = &Self::EMPTY_TYPEROW; + const EMPTY_TYPE_LIST_REF: &'static Term = &Self::EMPTY_TYPE_LIST; /// Initialize a new function type. pub fn new_function(fun_ty: impl Into) -> Self { - Self::new(TypeEnum::Function(Box::new(fun_ty.into()))) + Self::RuntimeFunction(Box::new(fun_ty.into())) } /// Initialize a new tuple type by providing the elements. #[inline(always)] - pub fn new_tuple(types: impl Into) -> Self { + pub fn new_runtime_tuple(types: impl Into) -> Self { let row = types.into(); - match row.len() { - 0 => Self::UNIT, - _ => Self::new_sum([row]), + match row.is_empty_list() { + true => Self::UNIT, + false => Self::new_sum([row]), } } @@ -459,135 +468,23 @@ impl TypeBase { #[inline(always)] pub fn new_sum(variants: impl IntoIterator) -> Self where - R: Into, + R: Into, { - Self::new(TypeEnum::Sum(SumType::new(variants))) + Self::RuntimeSum(SumType::new(variants)) } /// Initialize a new custom type. - // TODO remove? Extensions/TypeDefs should just provide `Type` directly + // ALAN TODO remove? Doesn't really do anything now #[must_use] pub const fn new_extension(opaque: CustomType) -> Self { - let bound = opaque.bound(); - TypeBase(TypeEnum::Extension(opaque), bound) - } - - /// Initialize a new alias. - #[must_use] - pub fn new_alias(alias: AliasDecl) -> Self { - Self::new(TypeEnum::Alias(alias)) - } - - pub(crate) fn new(type_e: TypeEnum) -> Self { - let bound = type_e.least_upper_bound(); - Self(type_e, bound) + Type::RuntimeExtension(opaque) } /// New `UnitSum` with empty Tuple variants #[must_use] pub const fn new_unit_sum(size: u8) -> Self { // should be the only way to avoid going through SumType::new - Self(TypeEnum::Sum(SumType::new_unary(size)), TypeBound::Copyable) - } - - /// New use (occurrence) of the type variable with specified index. - /// `bound` must be exactly that with which the variable was declared - /// (i.e. as a [`Term::RuntimeType`]`(bound)`), which may be narrower - /// than required for the use. - #[must_use] - pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::Variable(idx, bound), bound) - } - - /// Report the least upper [`TypeBound`] - #[inline(always)] - pub const fn least_upper_bound(&self) -> TypeBound { - self.1 - } - - /// Report the component `TypeEnum`. - #[inline(always)] - pub const fn as_type_enum(&self) -> &TypeEnum { - &self.0 - } - - /// Report a mutable reference to the component `TypeEnum`. - #[inline(always)] - pub fn as_type_enum_mut(&mut self) -> &mut TypeEnum { - &mut self.0 - } - - /// Returns the inner [`SumType`] if the type is a sum. - pub fn as_sum(&self) -> Option<&SumType> { - match &self.0 { - TypeEnum::Sum(s) => Some(s), - _ => None, - } - } - - /// Returns the inner [`CustomType`] if the type is from an extension. - pub fn as_extension(&self) -> Option<&CustomType> { - match &self.0 { - TypeEnum::Extension(ct) => Some(ct), - _ => None, - } - } - - /// Report if the type is copyable - i.e.the least upper bound of the type - /// is contained by the copyable bound. - pub const fn copyable(&self) -> bool { - TypeBound::Copyable.contains(self.least_upper_bound()) - } - - /// Checks all variables used in the type are in the provided list - /// of bound variables, rejecting any [`RowVariable`]s if `allow_row_vars` is False; - /// and that for each [`CustomType`] the corresponding - /// [`TypeDef`] is in the [`ExtensionRegistry`] and the type arguments - /// [validate] and fit into the def's declared parameters. - /// - /// [RowVariable]: TypeEnum::RowVariable - /// [validate]: crate::types::type_param::TypeArg::validate - /// [TypeDef]: crate::extension::TypeDef - pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - // There is no need to check the components against the bound, - // that is guaranteed by construction (even for deserialization) - match &self.0 { - TypeEnum::Sum(SumType::General { rows }) => { - rows.iter().try_for_each(|row| row.validate(var_decls)) - } - TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there - TypeEnum::Alias(_) => Ok(()), - TypeEnum::Extension(custy) => custy.validate(var_decls), - // Function values may be passed around without knowing their arity - // (i.e. with row vars) as long as they are not called: - TypeEnum::Function(ft) => ft.validate(var_decls), - TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()), - TypeEnum::RowVar(rv) => rv.validate(var_decls), - } - } - - /// Applies a substitution to a type. - /// This may result in a row of types, if this [Type] is not really a single type but actually a row variable - /// Invariants may be confirmed by validation: - /// * If [`Type::validate`]`(false)` returns successfully, this method will return a Vec containing exactly one type - /// * If [`Type::validate`]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self) - /// return a Vec containing any number of [Type]s. These may (or not) pass [`Type::validate`] - fn substitute(&self, t: &Substitution) -> Vec { - match &self.0 { - TypeEnum::RowVar(rv) => rv.substitute(t), - TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], - TypeEnum::Variable(idx, bound) => { - let TypeArg::Runtime(ty) = t.apply_var(*idx, &((*bound).into())) else { - panic!("Variable was not a type - try validate() first") - }; - vec![ty.into_()] - } - TypeEnum::Extension(cty) => vec![TypeBase::new_extension(cty.substitute(t))], - TypeEnum::Function(bf) => vec![TypeBase::new_function(bf.substitute(t))], - TypeEnum::Sum(SumType::General { rows }) => { - vec![TypeBase::new_sum(rows.iter().map(|r| r.substitute(t)))] - } - } + Self::RuntimeSum(SumType::new_unary(size)) } /// Returns a registry with the concrete extensions used by this type. @@ -598,7 +495,7 @@ impl TypeBase { let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); - collect_type_exts(self, &mut used, &mut missing); + collect_term_exts(self, &mut used, &mut missing); if missing.is_empty() { Ok(used.try_into().expect("all extensions are present")) @@ -608,49 +505,16 @@ impl TypeBase { } } -impl Transformable for TypeBase { - fn transform(&mut self, tr: &T) -> Result { - match &mut self.0 { - TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => Ok(false), - TypeEnum::Extension(custom_type) => { - if let Some(nt) = tr.apply_custom(custom_type)? { - *self = nt.into_(); - Ok(true) - } else { - let args_changed = custom_type.args_mut().transform(tr)?; - if args_changed { - *self = Self::new_extension( - custom_type - .get_type_def(&custom_type.get_extension()?)? - .instantiate(custom_type.args())?, - ); - } - Ok(args_changed) - } - } - TypeEnum::Function(fty) => fty.transform(tr), - TypeEnum::Sum(sum_type) => { - let ch = sum_type.transform(tr)?; - self.1 = self.0.least_upper_bound(); - Ok(ch) - } - } - } -} - -impl Type { - fn substitute1(&self, s: &Substitution) -> Self { - let v = self.substitute(s); - let [r] = v.try_into().unwrap(); // No row vars, so every Type produces exactly one - r - } -} - impl TypeRV { /// Tells if this Type is a row variable, i.e. could stand for any number >=0 of Types #[must_use] pub fn is_row_var(&self) -> bool { - matches!(self.0, TypeEnum::RowVar(_)) + if let Term::Variable(var) = self + && let Term::ListType(bx) = &*var.cached_decl + { + return matches!(&**bx, Term::RuntimeType(_)); + } + false } /// New use (occurrence) of the row variable with specified index. @@ -661,61 +525,8 @@ impl TypeRV { /// [OpDef]: crate::extension::OpDef /// [FuncDefn]: crate::ops::FuncDefn #[must_use] - pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::RowVar(RowVariable(idx, bound)), bound) - } -} - -// ====== Conversions ====== -impl TypeBase { - /// (Fallibly) converts a `TypeBase` (parameterized, so may or may not be able - /// to contain [`RowVariable`]s) into a [Type] that definitely does not. - pub fn try_into_type(self) -> Result { - Ok(TypeBase( - match self.0 { - TypeEnum::Extension(e) => TypeEnum::Extension(e), - TypeEnum::Alias(a) => TypeEnum::Alias(a), - TypeEnum::Function(f) => TypeEnum::Function(f), - TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), - TypeEnum::RowVar(rv) => Err(rv.as_rv().clone())?, - TypeEnum::Sum(s) => TypeEnum::Sum(s), - }, - self.1, - )) - } -} - -impl TryFrom for Type { - type Error = RowVariable; - fn try_from(value: TypeRV) -> Result { - value.try_into_type() - } -} - -impl TypeBase { - /// A swiss-army-knife for any safe conversion of the type argument `RV1` - /// to/from [`NoRV`]/RowVariable/rust-type-variable. - fn into_(self) -> TypeBase - where - RV1: Into, - { - TypeBase( - match self.0 { - TypeEnum::Extension(e) => TypeEnum::Extension(e), - TypeEnum::Alias(a) => TypeEnum::Alias(a), - TypeEnum::Function(f) => TypeEnum::Function(f), - TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), - TypeEnum::RowVar(rv) => TypeEnum::RowVar(rv.into()), - TypeEnum::Sum(s) => TypeEnum::Sum(s), - }, - self.1, - ) - } -} - -impl From for TypeRV { - fn from(value: Type) -> Self { - value.into_() + pub fn new_row_var_use(idx: usize, bound: TypeBound) -> Self { + Self::new_var_use(idx, Term::new_list_type(bound)) } } @@ -745,36 +556,6 @@ impl<'a> Substitution<'a> { debug_assert_eq!(check_term_type(arg, decl), Ok(())); arg.clone() } - - fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec { - let arg = self - .0 - .get(idx) - .expect("Undeclared type variable - call validate() ?"); - debug_assert!(check_term_type(arg, &TypeParam::new_list_type(bound)).is_ok()); - match arg { - TypeArg::List(elems) => elems - .iter() - .map(|ta| { - match ta { - Term::Runtime(ty) => return ty.clone().into(), - Term::Variable(v) => { - if let Some(b) = v.bound_if_row_var() { - return TypeRV::new_row_var_use(v.index(), b); - } - } - _ => (), - } - panic!("Not a list of types - call validate() ?") - }) - .collect(), - Term::Runtime(ty) if matches!(ty.0, TypeEnum::RowVar(_)) => { - // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type - vec![ty.clone().into()] - } - _ => panic!("Not a type or list of types - call validate() ?"), - } - } } /// A transformation that can be applied to a [Type] or [`TypeArg`]. @@ -825,31 +606,6 @@ impl Transformable for [E] { } } -pub(crate) fn check_typevar_decl( - decls: &[TypeParam], - idx: usize, - cached_decl: &TypeParam, -) -> Result<(), SignatureError> { - match decls.get(idx) { - None => Err(SignatureError::FreeTypeVar { - idx, - num_decls: decls.len(), - }), - Some(actual) => { - // The cache here just mirrors the declaration. The typevar can be used - // anywhere expecting a kind *containing* the decl - see `check_type_arg`. - if actual == cached_decl { - Ok(()) - } else { - Err(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: Box::new(cached_decl.clone()), - actual: Box::new(actual.clone()), - }) - } - } - } -} - #[cfg(test)] pub(crate) mod test { use std::hash::{Hash, Hasher}; @@ -865,7 +621,7 @@ pub(crate) mod test { #[test] fn construct() { - let t: Type = Type::new_tuple(vec![ + let t: Type = Type::new_runtime_tuple(vec![ usize_t(), Type::new_function(Signature::new_endo([])), Type::new_extension(CustomType::new( @@ -876,12 +632,8 @@ pub(crate) mod test { // Dummy extension reference. &Weak::default(), )), - Type::new_alias(AliasDecl::new("my_alias", TypeBound::Copyable)), ]); - assert_eq!( - &t.to_string(), - "[usize, [] -> [], my_custom, Alias(my_alias)]" - ); + assert_eq!(&t.to_string(), "[usize, [] -> [], my_custom]"); } #[rstest::rstest] @@ -898,22 +650,26 @@ pub(crate) mod test { #[test] fn as_sum() { let t = Type::new_unit_sum(0); - assert!(t.as_sum().is_some()); + assert!(t.as_runtime_sum().is_some()); } #[test] fn as_option() { let opt = option_type([usize_t()]); - assert_eq!(opt.as_unary_option().unwrap().clone(), usize_t()); + assert_eq!(opt.as_option().unwrap(), &Term::new_list([usize_t()])); assert_eq!( - Type::new_unit_sum(2).as_sum().unwrap().as_unary_option(), + Type::new_unit_sum(3).as_runtime_sum().unwrap().as_option(), None ); + assert_eq!( + Type::new_unit_sum(2).as_runtime_sum().unwrap().as_option(), + Some(&Term::EMPTY_TYPE_LIST) // Yes, option of zero types is valid + ); assert_eq!( - Type::new_tuple(vec![usize_t()]) - .as_sum() + Type::new_runtime_tuple(vec![usize_t()]) + .as_runtime_sum() .unwrap() .as_option(), None @@ -931,19 +687,31 @@ pub(crate) mod test { #[test] fn sum_variants() { - let variants: Vec = vec![ + fn into_typerow(t: &Term) -> TypeRow { + t.clone().try_into().unwrap() + } + let variants: Vec = vec![ [TypeRV::UNIT].into(), - vec![TypeRV::new_row_var_use(0, TypeBound::Linear)].into(), + TypeRV::new_row_var_use(0, TypeBound::Linear), ]; let t = SumType::new(variants.clone()); assert_eq!(variants, t.variants().cloned().collect_vec()); let empty_rows = vec![TypeRV::EMPTY_TYPEROW; 3]; let sum_unary = SumType::new_unary(3); - let sum_general = SumType::General { - rows: empty_rows.clone(), - }; - assert_eq!(&empty_rows, &sum_unary.variants().cloned().collect_vec()); + assert_eq!( + &empty_rows, + &sum_unary.variants().map(into_typerow).collect_vec() + ); + + let sum_general = SumType::General(GeneralSum { + rows: empty_rows + .into_iter() + .map(Term::from) + .collect::>() + .into(), + bound: TypeBound::Copyable, + }); assert_eq!(sum_general, sum_unary); let mut hasher_general = std::hash::DefaultHasher::new(); @@ -1023,7 +791,7 @@ pub(crate) mod test { let coln = e.get_type(&COLN).unwrap(); let c_of_cpy = coln - .instantiate([Term::new_list([Type::from(cpy.clone()).into()])]) + .instantiate([Term::new_list([Type::from(cpy.clone())])]) .unwrap(); let mut t = Type::new_extension(c_of_cpy.clone()); @@ -1031,19 +799,19 @@ pub(crate) mod test { t.transform(&cpy_to_qb), Err(SignatureError::from(TermTypeError::TypeMismatch { type_: Box::new(TypeBound::Copyable.into()), - term: Box::new(qb_t().into()) + term: Box::new(qb_t()) })) ); let mut t = Type::new_extension( - coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone())).into()])]) + coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone()))])]) .unwrap(), ); assert_eq!( t.transform(&cpy_to_qb), Err(SignatureError::from(TermTypeError::TypeMismatch { type_: Box::new(TypeBound::Copyable.into()), - term: Box::new(mk_opt(qb_t()).into()) + term: Box::new(mk_opt(qb_t())) })) ); @@ -1053,25 +821,23 @@ pub(crate) mod test { (ct == &c_of_cpy).then_some(usize_t()) }); let mut t = Type::new_extension( - coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()).into(); 2])]) + coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()); 2])]) .unwrap(), ); assert_eq!(t.transform(&cpy_to_qb2), Ok(true)); assert_eq!( t, Type::new_extension( - coln.instantiate([Term::new_list([usize_t().into(), usize_t().into()])]) + coln.instantiate([Term::new_list([usize_t(), usize_t()])]) .unwrap() ) ); } - mod proptest { - + pub(crate) mod proptest { use crate::proptest::RecursionDepth; - use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum}; - use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV}; + use crate::types::{SumType, TypeRow}; use proptest::prelude::*; impl Arbitrary for super::SumType { @@ -1082,45 +848,46 @@ pub(crate) mod test { if depth.leaf() { any::().prop_map(Self::new_unary).boxed() } else { - vec(any_with::(depth), 0..3) + vec(any_with::(depth), 0..3) .prop_map(SumType::new) .boxed() } } } - - impl Arbitrary for TypeBase { - type Parameters = RecursionDepth; - type Strategy = BoxedStrategy; - fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - // We descend here, because a TypeEnum may contain a Type - let depth = depth.descend(); - prop_oneof![ - 1 => any::().prop_map(TypeBase::new_alias), - 1 => any_with::(depth.into()).prop_map(TypeBase::new_extension), - 1 => any_with::(depth).prop_map(TypeBase::new_function), - 1 => any_with::(depth).prop_map(TypeBase::from), - 1 => (any::(), any::()).prop_map(|(i,b)| TypeBase::new_var_use(i,b)), - // proptest_derive::Arbitrary's weight attribute requires a constant, - // rather than this expression, hence the manual impl: - RV::weight() => RV::arb().prop_map(|rv| TypeBase::new(TypeEnum::RowVar(rv))) - ] - .boxed() - } - } } } #[cfg(test)] pub(super) mod proptest_utils { use proptest::collection::vec; - use proptest::prelude::{Strategy, any_with}; - - use super::serialize::{TermSer, TypeArgSer, TypeParamSer}; - use super::type_param::Term; + use proptest::prelude::{BoxedStrategy, Strategy, any, any_with}; + use proptest::strategy::Union; use crate::proptest::RecursionDepth; - use crate::types::serialize::ArrayOrTermSer; + + use super::serialize::{ArrayOrTermSer, TermSer, TypeArgSer, TypeParamSer}; + use super::{CustomType, FuncValueType, SumType, TypeBound, type_param::Term}; + + pub(crate) fn any_type(depth: RecursionDepth) -> BoxedStrategy { + let strat = Union::new([ + (any::(), any::()) + .prop_map(|(i, b)| Term::new_var_use(i, b)) + .boxed(), + any_with::(depth.into()) + .prop_map(Term::new_extension) + .boxed(), + ]); + if depth.leaf() { + return strat.boxed(); + } + let depth = depth.descend(); + strat + .or(any_with::(depth) + .prop_map(Term::new_function) + .boxed()) + .or(any_with::(depth).prop_map(Term::from).boxed()) + .boxed() + } fn term_is_serde_type_arg(t: &Term) -> bool { let TermSer::TypeArg(arg) = TermSer::from(t.clone()) else { @@ -1132,13 +899,9 @@ pub(super) mod proptest_utils { | TypeArgSer::Tuple { elems: terms } | TypeArgSer::TupleConcat { tuples: terms } => terms.iter().all(term_is_serde_type_arg), TypeArgSer::Variable { v } => term_is_serde_type_param(&v.cached_decl), - TypeArgSer::Type { ty } => { - if let Some(cty) = ty.as_extension() { - cty.args().iter().all(term_is_serde_type_arg) - } else { - true - } - } // Do we need to inspect inside function types? sum types? + TypeArgSer::Type { ty } => Term::from(ty) + .as_extension() + .is_none_or(|cty| cty.args().iter().all(term_is_serde_type_arg)), // Do we need to inspect inside function types? sum types? TypeArgSer::BoundedNat { .. } | TypeArgSer::String { .. } | TypeArgSer::Bytes { .. } diff --git a/hugr-core/src/types/check.rs b/hugr-core/src/types/check.rs index 072da5884..8debe93ce 100644 --- a/hugr-core/src/types/check.rs +++ b/hugr-core/src/types/check.rs @@ -3,7 +3,7 @@ use thiserror::Error; use super::{Type, TypeRow}; -use crate::{extension::SignatureError, ops::Value}; +use crate::{extension::SignatureError, ops::Value, types::type_param::TermTypeError}; /// Errors that arise from typechecking constants #[derive(Clone, Debug, PartialEq, Error)] @@ -69,10 +69,17 @@ impl super::SumType { num_variants: self.num_variants(), })?; let variant: TypeRow = variant.clone().try_into().map_err(|e| { - let SignatureError::RowVarWhereTypeExpected { var } = e else { - panic!("Unexpected error") + let SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { term, .. }) = e + else { + panic!("Unexpected error {e}") }; - SumTypeError::VariantNotConcrete { tag, varidx: var.0 } + let Type::Variable(tv) = &*term else { + panic!("Unexpected term {term}"); + }; + SumTypeError::VariantNotConcrete { + tag, + varidx: tv.index(), + } })?; if variant.len() != val.len() { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index ea16ab958..12c22a9c8 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -5,71 +5,154 @@ use std::borrow::Cow; use itertools::Itertools; use crate::extension::SignatureError; -#[cfg(test)] -use { - super::proptest_utils::any_serde_type_param, - crate::proptest::RecursionDepth, - ::proptest::{collection::vec, prelude::*}, - proptest_derive::Arbitrary, -}; +use crate::types::{FuncValueType, Signature}; use super::Substitution; use super::type_param::{TypeArg, TypeParam, check_term_types}; -use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; -/// A polymorphic type scheme, i.e. of a [`FuncDecl`], [`FuncDefn`] or [`OpDef`]. -/// (Nodes/operations in the Hugr are not polymorphic.) +/// A polymorphic type scheme, for a function ([`FuncDecl`] or [`FuncDefn`]). +/// Number of inputs and outputs fixed (no row variables) so that [`Input`] +/// and [`Output`] nodes can be wired up. /// -/// [`FuncDecl`]: crate::ops::module::FuncDecl -/// [`FuncDefn`]: crate::ops::module::FuncDefn -/// [`OpDef`]: crate::extension::OpDef +/// [`FuncDefn`]: crate::ops::FuncDefn +/// [`FuncDecl`]: crate::ops::FuncDecl +/// [`Input`]: crate::ops::Input +/// [`Output`]: crate::ops::Output + #[derive( - Clone, PartialEq, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, + PartialEq, + Debug, + Default, + Eq, + Hash, + derive_more::Display, + serde::Serialize, + serde::Deserialize, )] -#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] #[display("{}{body}", self.display_params())] -pub struct PolyFuncTypeBase { +pub struct PolyFuncType { /// The declared type parameters, i.e., these must be instantiated with /// the same number of [`TypeArg`]s before the function can be called. This /// defines the indices used by variables inside the body. - #[cfg_attr(test, proptest(strategy = "vec(any_serde_type_param(params), 0..3)"))] params: Vec, /// Template for the function. May contain variables up to length of [`Self::params`] - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - body: FuncTypeBase, + body: Signature, } -/// The polymorphic type of a [`Call`]-able function ([`FuncDecl`] or [`FuncDefn`]). -/// Number of inputs and outputs fixed. -/// -/// [`Call`]: crate::ops::Call -/// [`FuncDefn`]: crate::ops::FuncDefn -/// [`FuncDecl`]: crate::ops::FuncDecl -pub type PolyFuncType = PolyFuncTypeBase; +macro_rules! poly_func_type_general { + ($pf: ty, $ft: ty) => { + impl From<$ft> for $pf { + fn from(body: $ft) -> Self { + Self { + params: vec![], + body, + } + } + } -/// The polymorphic type of an [`OpDef`], whose number of input and outputs -/// may vary according to how [`RowVariable`]s therein are instantiated. -/// -/// [`OpDef`]: crate::extension::OpDef -pub type PolyFuncTypeRV = PolyFuncTypeBase; + impl TryFrom<$pf> for $ft { + /// If this PolyfuncType(RV) is not monomorphic, fail with its binders + type Error = Vec; -// deriving Default leads to an impl that only applies for RV: Default -impl Default for PolyFuncTypeBase { - fn default() -> Self { - Self { - params: Default::default(), - body: Default::default(), + fn try_from(value: $pf) -> Result { + if value.params.is_empty() { + Ok(value.body) + } else { + Err(value.params) + } + } } - } -} -impl From> for PolyFuncTypeBase { - fn from(body: FuncTypeBase) -> Self { - Self { - params: vec![], - body, + impl $pf { + /// The type parameters, aka binders, over which this type is polymorphic + pub fn params(&self) -> &[TypeParam] { + &self.params + } + + /// The body of the type, a function type. + pub fn body(&self) -> &$ft { + &self.body + } + + /// Create a new `PolyFuncType`(`RV``) given the kinds of the variables it declares + /// and the underlying [$ft] + pub fn new(params: impl Into>, body: impl Into<$ft>) -> Self { + Self { + params: params.into(), + body: body.into(), + } + } + + /// Helper function for the Display implementation + fn display_params(&self) -> Cow<'static, str> { + if self.params.is_empty() { + return Cow::Borrowed(""); + } + let params_list = self + .params + .iter() + .enumerate() + .map(|(i, param)| format!("(#{i} : {param})")) + .join(" "); + Cow::Owned(format!("∀ {params_list}. ",)) + } + + /// Returns a mutable reference to the body of the function type. + pub fn body_mut(&mut self) -> &mut $ft { + &mut self.body + } + + /// Instantiates a PolyFuncType(RV) (with no free variables, + /// as ensured by [`Self::validate`]), into a monomorphic type. + /// + /// # Errors + /// If there is not exactly one [`TypeArg`] for each binder ([`Self::params`]), + /// or an arg does not fit into its corresponding [`TypeParam`] + pub fn instantiate(&self, args: &[TypeArg]) -> Result<$ft, SignatureError> { + // Check that args are applicable, and that we have a value for each binder, + // i.e. each possible free variable within the body. + check_term_types(args, &self.params)?; + Ok(self.body.substitute(&Substitution(args))) + } + + /// Validates this instance, checking that the types in the body are + /// wellformed with respect to the registry, and the type variables declared. + pub fn validate(&self) -> Result<(), SignatureError> { + self.body.validate(&self.params) + } } - } + }; +} + +poly_func_type_general!(PolyFuncType, Signature); + +/// The polymorphic type of an [`OpDef`], with variable number of inputs and outputs. +/// +/// The inputs and outputs may splice in variables ranging over lists of types, +/// which may be instantiated with different numbers of types. These will be fixed +/// for any given node. +/// +/// [`OpDef`]: crate::extension::OpDef +#[derive( + Clone, + PartialEq, + Debug, + Default, // This covers only the case (PolyFuncType) + Eq, + Hash, + derive_more::Display, + serde::Serialize, + serde::Deserialize, +)] +#[display("{}{body}", self.display_params())] +pub struct PolyFuncTypeRV { + /// The declared type parameters, i.e., these must be instantiated with + /// the same number of [`TypeArg`]s before the function can be called. This + /// defines the indices used by variables inside the body. + params: Vec, + /// Template for the function. May contain variables up to length of [`Self::params`] + body: FuncValueType, } impl From for PolyFuncTypeRV { @@ -81,77 +164,7 @@ impl From for PolyFuncTypeRV { } } -impl TryFrom> for FuncTypeBase { - /// If the `PolyFuncTypeBase` is not monomorphic, fail with its binders - type Error = Vec; - - fn try_from(value: PolyFuncTypeBase) -> Result { - if value.params.is_empty() { - Ok(value.body) - } else { - Err(value.params) - } - } -} - -impl PolyFuncTypeBase { - /// The type parameters, aka binders, over which this type is polymorphic - pub fn params(&self) -> &[TypeParam] { - &self.params - } - - /// The body of the type, a function type. - pub fn body(&self) -> &FuncTypeBase { - &self.body - } - - /// Create a new `PolyFuncTypeBase` given the kinds of the variables it declares - /// and the underlying [`FuncTypeBase`]. - pub fn new(params: impl Into>, body: impl Into>) -> Self { - Self { - params: params.into(), - body: body.into(), - } - } - - /// Instantiates an outer [`PolyFuncTypeBase`], i.e. with no free variables - /// (as ensured by [`Self::validate`]), into a monomorphic type. - /// - /// # Errors - /// If there is not exactly one [`TypeArg`] for each binder ([`Self::params`]), - /// or an arg does not fit into its corresponding [`TypeParam`] - pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { - // Check that args are applicable, and that we have a value for each binder, - // i.e. each possible free variable within the body. - check_term_types(args, &self.params)?; - Ok(self.body.substitute(&Substitution(args))) - } - - /// Validates this instance, checking that the types in the body are - /// wellformed with respect to the registry, and the type variables declared. - pub fn validate(&self) -> Result<(), SignatureError> { - self.body.validate(&self.params) - } - - /// Helper function for the Display implementation - fn display_params(&self) -> Cow<'static, str> { - if self.params.is_empty() { - return Cow::Borrowed(""); - } - let params_list = self - .params - .iter() - .enumerate() - .map(|(i, param)| format!("(#{i} : {param})")) - .join(" "); - Cow::Owned(format!("∀ {params_list}. ",)) - } - - /// Returns a mutable reference to the body of the function type. - pub fn body_mut(&mut self) -> &mut FuncTypeBase { - &mut self.body - } -} +poly_func_type_general!(PolyFuncTypeRV, FuncValueType); #[cfg(test)] pub(crate) mod test { @@ -159,24 +172,61 @@ pub(crate) mod test { use std::sync::Arc; use cool_asserts::assert_matches; + use proptest::collection::vec; + use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any_with}; use crate::Extension; use crate::extension::prelude::{bool_t, usize_t}; use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound}; + use crate::proptest::RecursionDepth; use crate::std_extensions::collections::array::{self, array_type_parametric}; use crate::std_extensions::collections::list; - use crate::types::signature::FuncTypeBase; - use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; + use crate::types::proptest_utils::any_serde_type_param; + use crate::types::type_param::{Term, TermTypeError, TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, MaybeRV, Signature, Term, Type, TypeBound, TypeName, TypeRV, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, + TypeName, }; - use super::PolyFuncTypeBase; + impl Arbitrary for PolyFuncType { + type Parameters = RecursionDepth; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + let params_strategy = vec(any_serde_type_param(depth), 0..3); + let body_strategy = any_with::(depth); + (params_strategy, body_strategy) + .prop_map(|(params, body)| PolyFuncType::new(params, body)) + .boxed() + } + type Strategy = BoxedStrategy; + } + + impl Arbitrary for PolyFuncTypeRV { + type Parameters = RecursionDepth; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + let params_strategy = vec(any_serde_type_param(depth), 0..3); + let body_strategy = any_with::(depth); + (params_strategy, body_strategy) + .prop_map(|(params, body)| PolyFuncTypeRV::new(params, body)) + .boxed() + } + type Strategy = BoxedStrategy; + } - impl PolyFuncTypeBase { + impl PolyFuncType { fn new_validated( params: impl Into>, - body: FuncTypeBase, + body: Signature, + ) -> Result { + let res = Self::new(params, body); + res.validate()?; + Ok(res) + } + } + + impl PolyFuncTypeRV { + fn new_validated( + params: impl Into>, + body: FuncValueType, ) -> Result { let res = Self::new(params, body); res.validate()?; @@ -187,19 +237,19 @@ pub(crate) mod test { #[test] fn test_opaque() -> Result<(), SignatureError> { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); - let tyvar = TypeArg::new_var_use(0, TypeBound::Linear.into()); + let tyvar = TypeArg::new_var_use(0, TypeBound::Linear); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); - let list_len = PolyFuncTypeBase::new_validated( + let list_len = PolyFuncType::new_validated( [TypeBound::Linear.into()], Signature::new(vec![list_of_var], vec![usize_t()]), )?; - let t = list_len.instantiate(&[usize_t().into()])?; + let t = list_len.instantiate(&[usize_t()])?; assert_eq!( t, Signature::new( vec![Type::new_extension( - list_def.instantiate([usize_t().into()]).unwrap() + list_def.instantiate([usize_t()]).unwrap() )], vec![usize_t()] ) @@ -211,26 +261,24 @@ pub(crate) mod test { #[test] fn test_mismatched_args() -> Result<(), SignatureError> { let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let ty_var = TypeArg::new_var_use(1, TypeBound::Linear.into()); + let ty_var = TypeArg::new_var_use(1, TypeBound::Linear); let type_params = [TypeParam::max_nat_type(), TypeBound::Linear.into()]; // Valid schema... let good_array = array_type_parametric(size_var.clone(), ty_var.clone())?; - let good_ts = PolyFuncTypeBase::new_validated( - type_params.clone(), - Signature::new_endo([good_array]), - )?; + let good_ts = + PolyFuncType::new_validated(type_params.clone(), Signature::new_endo([good_array]))?; // Sanity check (good args) - good_ts.instantiate(&[5u64.into(), usize_t().into()])?; + good_ts.instantiate(&[5u64.into(), usize_t()])?; - let wrong_args = good_ts.instantiate(&[usize_t().into(), 5u64.into()]); + let wrong_args = good_ts.instantiate(&[usize_t(), 5u64.into()]); assert_eq!( wrong_args, Err(SignatureError::TypeArgMismatch( TermTypeError::TypeMismatch { type_: Box::new(type_params[0].clone()), - term: Box::new(usize_t().into()), + term: Box::new(usize_t()), } )) ); @@ -253,7 +301,7 @@ pub(crate) mod test { &Arc::downgrade(&array::EXTENSION), )); let bad_ts = - PolyFuncTypeBase::new_validated(type_params.clone(), Signature::new_endo([bad_array])); + PolyFuncType::new_validated(type_params.clone(), Signature::new_endo([bad_array])); assert_eq!(bad_ts.err(), Some(arg_err)); Ok(()) @@ -262,7 +310,7 @@ pub(crate) mod test { #[test] fn test_misused_variables() -> Result<(), SignatureError> { // Variables in args have different bounds from variable declaration - let tv = TypeArg::new_var_use(0, TypeBound::Copyable.into()); + let tv = TypeArg::new_var_use(0, TypeBound::Copyable); let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo([Type::new_extension(list_def.instantiate([tv])?)]); for decl in [ @@ -270,7 +318,7 @@ pub(crate) mod test { Term::StringType, Term::new_tuple_type([TypeBound::Linear.into(), Term::max_nat_type()]), ] { - let invalid_ts = PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone()); + let invalid_ts = PolyFuncType::new_validated([decl.clone()], body_type.clone()); assert_eq!( invalid_ts.err(), Some(SignatureError::TypeVarDoesNotMatchDeclaration { @@ -280,7 +328,7 @@ pub(crate) mod test { ); } // Variable not declared at all - let invalid_ts = PolyFuncTypeBase::new_validated([], body_type); + let invalid_ts = PolyFuncType::new_validated([], body_type); assert_eq!( invalid_ts.err(), Some(SignatureError::FreeTypeVar { @@ -315,7 +363,7 @@ pub(crate) mod test { reg.validate().unwrap(); let make_scheme = |tp: TypeParam| { - PolyFuncTypeBase::new_validated( + PolyFuncType::new_validated( [tp.clone()], Signature::new_endo([Type::new_extension(CustomType::new( TYPE_NAME, @@ -375,12 +423,9 @@ pub(crate) mod test { fn row_variables_bad_schema() { // Mismatched TypeBound (Copyable vs Any) let decl = Term::new_list_type(TP_ANY); - let e = PolyFuncTypeBase::new_validated( + let e = PolyFuncTypeRV::new_validated( [decl.clone()], - FuncValueType::new( - vec![usize_t()], - vec![TypeRV::new_row_var_use(0, TypeBound::Copyable)], - ), + FuncValueType::new([usize_t()], Term::new_row_var_use(0, TypeBound::Copyable)), ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { @@ -388,7 +433,7 @@ pub(crate) mod test { assert_eq!(*cached, TypeParam::new_list_type(TypeBound::Copyable)); }); // Declared as row variable, used as type variable - let e = PolyFuncTypeBase::new_validated( + let e = PolyFuncType::new_validated( [decl.clone()], Signature::new_endo([Type::new_var_use(0, TypeBound::Linear)]), ) @@ -401,18 +446,21 @@ pub(crate) mod test { #[test] fn row_variables() { - let rty = TypeRV::new_row_var_use(0, TypeBound::Linear); - let pf = PolyFuncTypeBase::new_validated( + let rty = Term::new_row_var_use(0, TypeBound::Linear); + let pf = PolyFuncTypeRV::new_validated( [TypeParam::new_list_type(TP_ANY)], - FuncValueType::new([usize_t().into(), rty.clone()], [TypeRV::new_tuple([rty])]), + FuncValueType::new( + Term::concat_lists([Term::new_list([usize_t()]), rty.clone()]), + [Term::new_runtime_tuple(rty)], + ), ) .unwrap(); fn seq2() -> Vec { - vec![usize_t().into(), bool_t().into()] + vec![usize_t(), bool_t()] } - pf.instantiate(&[usize_t().into()]).unwrap_err(); - pf.instantiate(&[Term::new_list([usize_t().into(), Term::new_list(seq2())])]) + pf.instantiate(&[usize_t()]).unwrap_err(); + pf.instantiate(&[Term::new_list([usize_t(), Term::new_list(seq2())])]) .unwrap_err(); let t2 = pf.instantiate(&[Term::new_list(seq2())]).unwrap(); @@ -420,18 +468,18 @@ pub(crate) mod test { t2, Signature::new( vec![usize_t(), usize_t(), bool_t()], - vec![Type::new_tuple(vec![usize_t(), bool_t()])] + vec![Type::new_runtime_tuple(vec![usize_t(), bool_t()])] ) ); } #[test] fn row_variables_inner() { - let inner_fty = Type::new_function(FuncValueType::new_endo([TypeRV::new_row_var_use( + let inner_fty = Type::new_function(FuncValueType::new_endo(Term::new_row_var_use( 0, TypeBound::Copyable, - )])); - let pf = PolyFuncTypeBase::new_validated( + ))); + let pf = PolyFuncType::new_validated( [Term::new_list_type(TypeBound::Copyable)], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), ) @@ -439,11 +487,7 @@ pub(crate) mod test { let inner3 = Type::new_function(Signature::new_endo([usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate(&[Term::new_list([ - usize_t().into(), - bool_t().into(), - usize_t().into(), - ])]) + .instantiate(&[Term::new_list([usize_t(), bool_t(), usize_t()])]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/row_var.rs b/hugr-core/src/types/row_var.rs deleted file mode 100644 index 086ab7b07..000000000 --- a/hugr-core/src/types/row_var.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! Classes for row variables (i.e. Type variables that can stand for multiple types) - -use super::type_param::TypeParam; -use super::{Substitution, TypeBase, TypeBound, check_typevar_decl}; -use crate::extension::SignatureError; - -#[cfg(test)] -use proptest::prelude::{BoxedStrategy, Strategy, any}; -/// Describes a row variable - a type variable bound with a list of runtime types -/// of the specified bound (checked in validation) -// The serde derives here are not used except as markers -// so that other types containing this can also #derive-serde the same way. -#[derive( - Clone, Debug, Eq, Hash, PartialEq, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -#[display("{_0}")] -pub struct RowVariable(pub usize, pub TypeBound); - -// Note that whilst 'pub' this is not re-exported outside private module `row_var` -// so is effectively sealed. -pub trait MaybeRV: - Clone - + std::fmt::Debug - + std::fmt::Display - + From - + Into - + Eq - + PartialEq - + 'static -{ - fn as_rv(&self) -> &RowVariable; - fn try_from_rv(rv: RowVariable) -> Result; - fn bound(&self) -> TypeBound; - fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError>; - #[allow(private_interfaces)] - fn substitute(&self, s: &Substitution) -> Vec>; - #[cfg(test)] - fn weight() -> u32 { - 1 - } - #[cfg(test)] - fn arb() -> BoxedStrategy; -} - -/// Has no instances - used as parameter to [`Type`] to rule out the possibility -/// of there being any [`TypeEnum::RowVar`]s -/// -/// [`TypeEnum::RowVar`]: super::TypeEnum::RowVar -/// [`Type`]: super::Type -// The serde derives here are not used except as markers -// so that other types containing this can also #derive-serde the same way. -#[derive( - Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -pub enum NoRV {} - -impl From for RowVariable { - fn from(value: NoRV) -> Self { - match value {} - } -} - -impl MaybeRV for RowVariable { - fn as_rv(&self) -> &RowVariable { - self - } - - fn bound(&self) -> TypeBound { - self.1 - } - - fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - check_typevar_decl(var_decls, self.0, &TypeParam::new_list_type(self.1)) - } - - #[allow(private_interfaces)] - fn substitute(&self, s: &Substitution) -> Vec> { - s.apply_rowvar(self.0, self.1) - } - - fn try_from_rv(rv: RowVariable) -> Result { - Ok(rv) - } - - #[cfg(test)] - fn arb() -> BoxedStrategy { - (any::(), any::()) - .prop_map(|(i, b)| Self(i, b)) - .boxed() - } -} - -impl MaybeRV for NoRV { - fn as_rv(&self) -> &RowVariable { - match *self {} - } - - fn bound(&self) -> TypeBound { - match *self {} - } - - fn validate(&self, _var_decls: &[TypeParam]) -> Result<(), SignatureError> { - match *self {} - } - - #[allow(private_interfaces)] - fn substitute(&self, _s: &Substitution) -> Vec> { - match *self {} - } - - fn try_from_rv(rv: RowVariable) -> Result { - Err(rv) - } - - #[cfg(test)] - fn weight() -> u32 { - 0 - } - - #[cfg(test)] - fn arb() -> BoxedStrategy { - any::() - .prop_map(|_| panic!("Should be ruled out by weight==0")) - .boxed() - } -} diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index eeff6f2e1..962c114d8 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -1,16 +1,18 @@ use std::sync::Arc; use ordered_float::OrderedFloat; +use serde::Serialize; +use serde_with::{DeserializeAs, SerializeAs, serde_as}; -use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeBase, TypeBound, TypeEnum}; +use super::{FuncValueType, SumType, TypeBound}; use super::custom::CustomType; use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::AliasDecl; -use crate::types::type_param::{TermVar, UpperBound}; -use crate::types::{Term, Type}; +use crate::types::type_param::{SeqPart, TermTypeError, TermVar, UpperBound}; +use crate::types::{GeneralSum, Term, Type, sum_bound}; #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] @@ -25,45 +27,53 @@ pub(crate) enum SerSimpleType { R { i: usize, b: TypeBound }, } -impl From> for SerSimpleType { - fn from(value: TypeBase) -> Self { +/// For the things that used to be supported as Types +impl TryFrom for SerSimpleType { + type Error = SignatureError; + fn try_from(value: Type) -> Result { if value == qb_t() { - return SerSimpleType::Q; + return Ok(SerSimpleType::Q); } if value == usize_t() { - return SerSimpleType::I; + return Ok(SerSimpleType::I); } - match value.0 { - TypeEnum::Extension(o) => SerSimpleType::Opaque(o), - TypeEnum::Alias(a) => SerSimpleType::Alias(a), - TypeEnum::Function(sig) => SerSimpleType::G(sig), - TypeEnum::Variable(i, b) => SerSimpleType::V { i, b }, - TypeEnum::RowVar(rv) => { - let RowVariable(idx, bound) = rv.as_rv(); - SerSimpleType::R { i: *idx, b: *bound } + match value { + Term::RuntimeExtension(o) => Ok(SerSimpleType::Opaque(o)), + //TypeEnum::Alias(a) => SerSimpleType::Alias(a), + Term::RuntimeFunction(sig) => Ok(SerSimpleType::G(sig)), + Term::Variable(tv) => { + let i = tv.index(); + match &*tv.cached_decl { + Term::RuntimeType(b) => return Ok(SerSimpleType::V { i, b: *b }), + Term::ListType(b) => { + if let Term::RuntimeType(b) = &**b { + return Ok(SerSimpleType::R { i, b: *b }); + } + } + _ => (), + }; + Err(SignatureError::TypeArgMismatch( + TermTypeError::InvalidValue(tv.cached_decl), + )) } - TypeEnum::Sum(st) => SerSimpleType::Sum(st), + Term::RuntimeSum(st) => Ok(SerSimpleType::Sum(st)), + _ => Err(SignatureError::InvalidTypeArgs), } } } -impl TryFrom for TypeBase { - type Error = SignatureError; - fn try_from(value: SerSimpleType) -> Result { - Ok(match value { - SerSimpleType::Q => qb_t().into_(), - SerSimpleType::I => usize_t().into_(), - SerSimpleType::G(sig) => TypeBase::new_function(*sig), +impl From for Term { + fn from(value: SerSimpleType) -> Self { + match value { + SerSimpleType::Q => qb_t(), + SerSimpleType::I => usize_t(), + SerSimpleType::G(sig) => Type::new_function(*sig), SerSimpleType::Sum(st) => st.into(), - SerSimpleType::Opaque(o) => TypeBase::new_extension(o), - SerSimpleType::Alias(a) => TypeBase::new_alias(a), - SerSimpleType::V { i, b } => TypeBase::new_var_use(i, b), - // We can't use new_row_var because that returns TypeRV not TypeBase. - SerSimpleType::R { i, b } => TypeBase::new(TypeEnum::RowVar( - RV::try_from_rv(RowVariable(i, b)) - .map_err(|var| SignatureError::RowVarWhereTypeExpected { var })?, - )), - }) + SerSimpleType::Opaque(o) => Type::new_extension(o), + SerSimpleType::Alias(_) => todo!("alias?"), + SerSimpleType::V { i, b } => Type::new_var_use(i, b), + SerSimpleType::R { i, b } => Type::new_row_var_use(i, b), + } } } @@ -87,7 +97,7 @@ pub(super) enum TypeParamSer { #[serde(tag = "tya")] pub(super) enum TypeArgSer { Type { - ty: Type, + ty: SerSimpleType, }, BoundedNat { n: u64, @@ -138,7 +148,11 @@ impl From for TermSer { Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), Term::ConstType(ty) => TermSer::TypeParam(TypeParamSer::ConstType { ty: *ty }), - Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), + Term::RuntimeFunction(_) | Term::RuntimeExtension(_) | Term::RuntimeSum(_) => { + TermSer::TypeArg(TypeArgSer::Type { + ty: value.try_into().unwrap(), + }) + } Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { params: (*params).into(), }), @@ -170,7 +184,7 @@ impl From for Term { TypeParamSer::ConstType { ty } => Term::ConstType(Box::new(ty)), }, TermSer::TypeArg(arg) => match arg { - TypeArgSer::Type { ty } => Term::Runtime(ty), + TypeArgSer::Type { ty } => Term::from(ty), TypeArgSer::BoundedNat { n } => Term::BoundedNat(n), TypeArgSer::String { arg } => Term::String(arg), TypeArgSer::Bytes { value } => Term::Bytes(value), @@ -185,6 +199,55 @@ impl From for Term { } } +/// Helper for use with [serde_with::serde_as] to serialize a [Term] +/// that is an instance of [`Term::ListType`]([`Term::RuntimeType`](...)) +/// as a list of types + row variables +pub(crate) enum SerTypeRowRV {} + +/// Helper for use with [serde_with::serde_as] to serialize a [Term] +/// that is an instance of [`Term::RuntimeType`](...) +/// as a json [SerSimpleType] +pub(crate) enum SerType {} + +impl SerializeAs for SerType { + fn serialize_as(ty: &Term, s: S) -> Result { + SerSimpleType::try_from(ty.clone()).unwrap().serialize(s) + } +} + +impl<'de> DeserializeAs<'de, Term> for SerType { + fn deserialize_as>(deser: D) -> Result { + let sertype: SerSimpleType = serde::Deserialize::deserialize(deser)?; + Ok(sertype.into()) + } +} + +/// Helper to (de)serialize GeneralSums without storing the (cached) bound +#[serde_as] +#[derive(serde::Serialize, serde::Deserialize)] +pub(super) struct SerGenSum { + #[serde_as(as = "Vec")] + rows: Vec, +} + +impl From for SerGenSum { + fn from(value: GeneralSum) -> Self { + Self { + rows: value.rows.into_owned(), + } + } +} + +impl From for GeneralSum { + fn from(value: SerGenSum) -> Self { + let bound = sum_bound(value.rows.iter()); + Self { + rows: value.rows.into(), + bound, + } + } +} + /// Helper type that serialises lists as JSON arrays for compatibility. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(untagged)] @@ -233,3 +296,36 @@ mod base64 { .map_err(serde::de::Error::custom) } } + +impl serde_with::SerializeAs for SerTypeRowRV { + fn serialize_as(source: &Term, serializer: S) -> Result { + let items: Vec = source + .clone() + .into_list_parts() + .map(|part| match part { + SeqPart::Item(t) => { + let s = SerSimpleType::try_from(t).unwrap(); + assert!(!matches!(s, SerSimpleType::R { .. })); + s + } + SeqPart::Splice(t) => { + let s = SerSimpleType::try_from(t).unwrap(); + assert!(matches!(s, SerSimpleType::R { .. })); + s + } + }) + .collect(); + items.serialize(serializer) + } +} + +impl<'de> serde_with::DeserializeAs<'de, Term> for SerTypeRowRV { + fn deserialize_as>(deser: D) -> Result { + let items: Vec = serde::Deserialize::deserialize(deser)?; + let list_parts = items.into_iter().map(|s| match s { + SerSimpleType::R { i, b } => SeqPart::Splice(Term::new_row_var_use(i, b)), + s => SeqPart::Item(Term::from(s)), + }); + Ok(Term::new_list_from_parts(list_parts)) + } +} diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 4fc16a693..7e5ac4848 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -1,70 +1,162 @@ //! Abstract and concrete Signature types. use itertools::Either; +use serde_with::serde_as; -use std::borrow::Cow; use std::fmt::{self, Display}; use super::type_param::TypeParam; -use super::type_row::TypeRowBase; -use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeRow, TypeTransformer, -}; +use super::{Substitution, Transformable, Type, TypeRow, TypeTransformer}; use crate::core::PortIndex; use crate::extension::resolution::{ ExtensionCollectionError, WeakExtensionRegistry, collect_signature_exts, }; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::types::type_param::{TermTypeError, check_term_type}; +use crate::types::{Term, TypeBound}; use crate::{Direction, IncomingPort, OutgoingPort, Port}; -#[cfg(test)] -use {crate::proptest::RecursionDepth, proptest::prelude::*, proptest_derive::Arbitrary}; - -#[derive(Clone, Debug, Eq, Hash, serde::Serialize, serde::Deserialize)] -#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] -/// Base type for listing inputs and output types. -/// -/// The exact semantics depend on the use case: -/// - If `ROWVARS=`[`NoRV`], describes the edges required to/from a node or inside a [`FuncDefn`]. -/// - If `ROWVARS=`[`RowVariable`], describes the type of a higher-order [`function value`] or the inputs/outputs from an `OpDef`. +/// The concept of "signature" in the spec - a list of inputs and outputs being +/// the edges required to/from a node or within a [`FuncDefn`]. /// -/// `ROWVARS` specifies whether the type lists may contain [`RowVariable`]s or not. -/// -/// [`function value`]: crate::ops::constant::Value::Function /// [`FuncDefn`]: crate::ops::FuncDefn -pub struct FuncTypeBase { +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +pub struct Signature { /// Value inputs of the function. - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - pub input: TypeRowBase, + /// + /// Each *element* must [check_term_type] against [Term::RuntimeType] of + /// [TypeBound::Linear], hence the arity is fixed as the length of the row. + pub input: TypeRow, /// Value outputs of the function. - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - pub output: TypeRowBase, + /// + /// /// Each *element* must [check_term_type] against [Term::RuntimeType] of + /// [TypeBound::Linear], hence the arity is fixed as the length of the row. + pub output: TypeRow, } -/// The concept of "signature" in the spec - the edges required to/from a node -/// or within a [`FuncDefn`], also the target (value) of a call (static). +/// A function value whose number of inputs and outputs may be unknown. /// -/// [`FuncDefn`]: crate::ops::FuncDefn -pub type Signature = FuncTypeBase; - -/// A function that may contain [`RowVariable`]s and thus has potentially-unknown arity; -/// used for [`OpDef`]'s and passable as a value round a Hugr (see [`Type::new_function`]) -/// but not a valid node type. +/// ([FuncValueType::input] and [FuncValueType::output] are arbitrary [Term]s.) +/// +/// Each must type-check against [Term::ListType]`(`Term::RuntimeType`(`[TypeBound::Linear]`))` +/// so can include variables containing unknown numbers of types. +/// +/// Used for [`OpDef`]'s and may be used as a type (of function-pointer values) +/// on wires of a Hugr (see [`Type::new_function`]) but not a valid node type. /// /// [`OpDef`]: crate::extension::OpDef -pub type FuncValueType = FuncTypeBase; +#[serde_as] +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +pub struct FuncValueType { + /// Value inputs of the function. + /// + /// Must [check_term_type] against [Term::ListType] of [Term::RuntimeType], + /// hence there may be variables ranging over lists of types, and so the + /// arity may vary according to the length of list with whose those variables + /// are instantiated. + #[serde_as(as = "crate::types::serialize::SerTypeRowRV")] + pub input: Term, + /// Value outputs of the function. + /// + /// Must [check_term_type] against [Term::ListType] of [Term::RuntimeType], + /// hence there may be variables ranging over lists of types, and so the + /// arity may vary according to the length of list with whose those variables + /// are instantiated. + #[serde_as(as = "crate::types::serialize::SerTypeRowRV")] + pub output: Term, +} -impl FuncTypeBase { - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { +impl Default for FuncValueType { + fn default() -> Self { Self { - input: self.input.substitute(tr), - output: self.output.substitute(tr), + input: Term::new_list(Vec::new()), + output: Term::new_list(Vec::new()), } } +} - /// Create a new signature with specified inputs and outputs. - pub fn new(input: impl Into>, output: impl Into>) -> Self { +macro_rules! func_type_general { + ($ft: ty, $io: ty) => { + impl Transformable for $ft { + fn transform(&mut self, tr: &T) -> Result { + // TODO handle extension sets? + Ok(self.input.transform(tr)? | self.output.transform(tr)?) + } + } + + impl Display for $ft { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.input.fmt(f)?; + f.write_str(" -> ")?; + self.output.fmt(f) + } + } + + impl $ft { + #[inline] + /// Returns a row of the value inputs of the function. + #[must_use] + pub fn input(&self) -> &$io { + &self.input + } + + #[inline] + /// Returns a row of the value outputs of the function. + #[must_use] + pub fn output(&self) -> &$io { + &self.output + } + + #[inline] + /// Returns a tuple with the input and output rows of the function. + #[must_use] + pub fn io(&self) -> (&$io, &$io) { + (&self.input, &self.output) + } + + pub(crate) fn substitute(&self, tr: &Substitution) -> Self { + Self { + input: self.input.substitute(tr), + output: self.output.substitute(tr), + } + } + } + }; +} + +func_type_general!(Signature, TypeRow); +func_type_general!(FuncValueType, Term); + +impl FuncValueType { + /// Create a new FuncValueType with specified inputs and outputs. + /// + /// # Panics + /// + /// If the inputs, or outputs, are not each lists of runtime types. + /// See [Self::try_new] and [Self::new_unchecked] for alternatives. + pub fn new(input: impl Into, output: impl Into) -> Self { + Self::try_new(input, output).unwrap() + } + + /// Create a new FuncValueType with specified inputs and outputs. + /// + /// # Errors + /// + /// If the inputs, or outputs, are not each lists of runtime types. + /// See [Self::new_unchecked]. + pub fn try_new(input: impl Into, output: impl Into) -> Result { + let input = input.into(); + let output = output.into(); + check_term_type(&input, &Term::new_list_type(TypeBound::Linear))?; + check_term_type(&output, &Term::new_list_type(TypeBound::Linear))?; + Ok(Self::new_unchecked(input, output)) + } + + /// Create a new FuncValueType with specified inputs and outputs. + /// No checks are performed as to whether the inputs and outputs are appropriate + /// (i.e. lists of runtime types). + pub fn new_unchecked(input: impl Into, output: impl Into) -> Self { Self { input: input.into(), output: output.into(), @@ -73,43 +165,147 @@ impl FuncTypeBase { /// Create a new signature with the same input and output types (signature of an endomorphic /// function). - pub fn new_endo(row: impl Into>) -> Self { + /// + /// # Panics + /// + /// If the row is not a list of runtime types. + /// See [Self::try_new_endo] and [Self::new_endo_unchecked] for alternatives. + pub fn new_endo(row: impl Into) -> Self { + Self::try_new_endo(row).unwrap() + } + + /// Create a new signature with the same input and output types (signature of an endomorphic + /// function). + /// + /// # Errors + /// + /// If the row is not a list of runtime types. + pub fn try_new_endo(row: impl Into) -> Result { let row = row.into(); - Self::new(row.clone(), row) + check_term_type(&row, &Term::new_list_type(TypeBound::Linear))?; + Ok(Self::new_endo_unchecked(row)) } - /// True if both inputs and outputs are necessarily empty. - /// (For [`FuncValueType`], even after any possible substitution of row variables) + /// Create a new signature with the same input and output types (signature of an endomorphic + /// function). + /// No checks are performed as to whether the row is appropriate + /// (i.e. a list of runtime types). + pub fn new_endo_unchecked(row: impl Into) -> Self { + let row = row.into(); + Self::new_unchecked(row.clone(), row) + } + + // ALAN definitely opportunities to deduplicate between Signature/FuncValueType here... + pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { + self.input.validate(var_decls)?; + self.output.validate(var_decls)?; + // check_term_type does not look at inputs/outputs, so do that here + for t in [&self.input, &self.output] { + check_term_type(t, &Term::new_list_type(TypeBound::Linear))?; + } + Ok(()) + } + + /// True if both inputs and outputs are necessarily empty + /// (even after any possible substitution of row variables) #[inline(always)] #[must_use] pub fn is_empty(&self) -> bool { - self.input.is_empty() && self.output.is_empty() + self.input.is_empty_list() && self.output.is_empty_list() } +} - #[inline] - /// Returns a row of the value inputs of the function. - #[must_use] - pub fn input(&self) -> &TypeRowBase { - &self.input +impl Signature { + /// Create a new signature with specified inputs and outputs. + /// + /// # Panics + /// + /// If any of the input or output types are not runtime types. + /// See [Self::try_new] or [Self::new_unchecked] for alternatives. + pub fn new(input: impl Into, output: impl Into) -> Self { + Self::try_new(input, output).unwrap() } - #[inline] - /// Returns a row of the value outputs of the function. - #[must_use] - pub fn output(&self) -> &TypeRowBase { - &self.output + /// Create a new signature with specified inputs and outputs. + /// + /// # Errors + /// + /// If any of the input or output types are not runtime types. See [Self::new_unchecked] for an alternative. + pub fn try_new( + input: impl Into, + output: impl Into, + ) -> Result { + let input = input.into(); + let output = output.into(); + for t in input.iter().chain(output.iter()) { + check_term_type(t, &TypeBound::Linear.into())?; + } + Ok(Self::new_unchecked(input, output)) } - #[inline] - /// Returns a tuple with the input and output rows of the function. - #[must_use] - pub fn io(&self) -> (&TypeRowBase, &TypeRowBase) { - (&self.input, &self.output) + /// Create a new signature with specified inputs and outputs. + /// No checks are performed as to whether the input and output types are appropriate + /// (i.e. runtime types). + pub fn new_unchecked(input: impl Into, output: impl Into) -> Self { + Self { + input: input.into(), + output: output.into(), + } + } + + /// Create a new signature with the same input and output types (signature of an endomorphic + /// function). + /// + /// # Panics + /// + /// If any element of the row is not a runtime type. + /// See [Self::try_new_endo] or [Self::new_endo_unchecked] for alternatives. + pub fn new_endo(row: impl Into) -> Self { + let row = row.into(); + Self::new(row.clone(), row) + } + + /// Create a new signature with the same input and output types (signature of an endomorphic + /// function). + /// + /// # Errors + /// + /// If any element of the row is not a runtime type. + /// See [Self::new_endo_unchecked] for an alternative. + pub fn try_new_endo(row: impl Into) -> Result { + let row = row.into(); + for t in row.iter() { + check_term_type(t, &TypeBound::Linear.into())?; + } + Ok(Self::new_endo_unchecked(row)) + } + + /// Create a new signature with the same input and output types (signature of an endomorphic + /// function). + /// No checks are performed as to whether the elements of the row are appropriate + /// (i.e. runtime types). + pub fn new_endo_unchecked(row: impl Into) -> Self { + let row = row.into(); + Self::new_unchecked(row.clone(), row) } pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { self.input.validate(var_decls)?; - self.output.validate(var_decls) + self.output.validate(var_decls)?; + // check_term_type never gets here (and would not look at inputs/outputs if it did), + // so do that here + for t in self.input.iter().chain(self.output.iter()) { + check_term_type(t, &TypeBound::Linear.into())?; + } + Ok(()) + } + + /// True if both inputs and outputs are necessarily empty. + /// (For [`FuncValueType`], even after any possible substitution of row variables) + #[inline(always)] + #[must_use] + pub fn is_empty(&self) -> bool { + self.input.is_empty() && self.output.is_empty() } /// Returns a registry with the concrete extensions used by this signature. @@ -127,34 +323,6 @@ impl FuncTypeBase { } } -impl Transformable for FuncTypeBase { - fn transform(&mut self, tr: &T) -> Result { - // TODO handle extension sets? - Ok(self.input.transform(tr)? | self.output.transform(tr)?) - } -} - -impl FuncValueType { - /// If this `FuncValueType` contains any row variables, return one. - #[must_use] - pub fn find_rowvar(&self) -> Option { - self.input - .iter() - .chain(self.output.iter()) - .find_map(|t| Type::try_from(t.clone()).err()) - } -} - -// deriving Default leads to an impl that only applies for RV: Default -impl Default for FuncTypeBase { - fn default() -> Self { - Self { - input: Default::default(), - output: Default::default(), - } - } -} - impl Signature { /// Returns the type of a value [`Port`]. Returns `None` if the port is out /// of bounds. @@ -275,14 +443,6 @@ impl Signature { } } -impl Display for FuncTypeBase { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.input.fmt(f)?; - f.write_str(" -> ")?; - self.output.fmt(f) - } -} - impl TryFrom for Signature { type Error = SignatureError; @@ -302,31 +462,63 @@ impl From for FuncValueType { } } -impl PartialEq> for FuncTypeBase { - fn eq(&self, other: &FuncTypeBase) -> bool { - self.input == other.input && self.output == other.output - } -} - -impl PartialEq>> for FuncTypeBase { - fn eq(&self, other: &Cow<'_, FuncTypeBase>) -> bool { - self.eq(other.as_ref()) - } -} - -impl PartialEq> for Cow<'_, FuncTypeBase> { - fn eq(&self, other: &FuncTypeBase) -> bool { - self.as_ref().eq(other) +impl PartialEq for FuncValueType { + fn eq(&self, other: &Signature) -> bool { + // Ideally we should normalize input/output first, but assume e.g. substitute has done so already + if let Term::List(input) = &self.input + && let Term::List(output) = &self.output + { + return *input == *other.input && *output == *other.output; + } + false } } #[cfg(test)] mod test { + use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any, any_with}; + use proptest::{collection::vec, strategy::Union}; + use crate::extension::prelude::{bool_t, qb_t, usize_t}; + use crate::proptest::RecursionDepth; use crate::type_row; - use crate::types::{CustomType, TypeEnum, test::FnTransformer}; + use crate::types::test::FnTransformer; + use crate::types::{CustomType, TypeRow, proptest_utils::any_type, type_param::SeqPart}; use super::*; + + impl Arbitrary for Signature { + type Parameters = RecursionDepth; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + let input_strategy = any_with::(depth); + let output_strategy = any_with::(depth); + (input_strategy, output_strategy) + .prop_map(|(input, output)| Signature::new(input, output)) + .boxed() + } + type Strategy = BoxedStrategy; + } + + impl Arbitrary for FuncValueType { + type Parameters = RecursionDepth; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + let io_strategy = vec( + Union::new([ + any_type(depth).prop_map(SeqPart::Item).boxed(), + (any::(), any::()) + .prop_map(|(idx, bound)| SeqPart::Splice(Term::new_row_var_use(idx, bound))) + .boxed(), + ]), + 0..3, + ) + .prop_map(Term::new_list_from_parts); + (io_strategy.clone(), io_strategy) + .prop_map(|(input, output)| FuncValueType::new_unchecked(input, output)) + .boxed() + } + type Strategy = BoxedStrategy; + } + #[test] fn test_function_type() { let mut f_type = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]); @@ -355,7 +547,7 @@ mod test { #[test] fn test_transform() { - let TypeEnum::Extension(usz_t) = usize_t().as_type_enum().clone() else { + let Term::RuntimeExtension(usz_t) = usize_t() else { panic!() }; let tr = FnTransformer(|ct: &CustomType| (ct == &usz_t).then_some(bool_t())); diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index c9ddcc81b..3edf973db 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -4,6 +4,7 @@ //! //! [`TypeDef`]: crate::extension::TypeDef +use itertools::Itertools as _; use ordered_float::OrderedFloat; #[cfg(test)] use proptest_derive::Arbitrary; @@ -14,12 +15,9 @@ use std::sync::Arc; use thiserror::Error; use tracing::warn; -use super::row_var::MaybeRV; -use super::{ - NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer, - check_typevar_decl, -}; +use super::{Substitution, Transformable, Type, TypeBound, TypeTransformer}; use crate::extension::SignatureError; +use crate::types::{CustomType, FuncValueType, GeneralSum, SumType}; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) @@ -95,9 +93,22 @@ pub enum Term { /// The type of static tuples. #[display("TupleType[{_0}]")] TupleType(Box), - /// A runtime type as a term. Instance of [`Term::RuntimeType`]. + /// The type of runtime values defined by an extension type. + /// Instance of [Self::RuntimeType] for some bound. + // + // TODO optimise with `Box`? + // or some static version of this? + RuntimeExtension(CustomType), + /// The type of runtime values that are function pointers. + /// Instance of [Self::RuntimeType]`(`[TypeBound::Copyable]`)`. + /// Function values may be passed around without knowing their arity + /// (i.e. with row vars) as long as they are not called. #[display("{_0}")] - Runtime(Type), + RuntimeFunction(Box), + /// The type of runtime values that are sums of products (ADTs) + /// Instance of [Self::RuntimeType]`(bound)` for `bound` calculated from each variant's elements. + #[display("{_0}")] + RuntimeSum(SumType), /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`]. #[display("{_0}")] BoundedNat(u64), @@ -111,9 +122,14 @@ pub enum Term { #[display("{}", _0.into_inner())] Float(OrderedFloat), /// A list of static terms. Instance of [`Term::ListType`]. + /// Note, not a [TypeRow] because `impl Arbitrary for TypeRow` generates only types. + /// TODO ALAN....so should we serialize *all* TypeRows as `Vec` ? + /// + /// [TypeRow]: super::TypeRow #[display("[{}]", { use itertools::Itertools as _; - _0.iter().map(|t|t.to_string()).join(",") + // extra space matching old Display for Type(Row) - TODO, change Vec to TypeRow? + _0.iter().map(|t|t.to_string()).join(", ") })] List(Vec), /// Instance of [`TypeParam::List`] defined by a sequence of concatenated lists of the same type. @@ -197,24 +213,41 @@ impl Term { (Term::StringType, Term::StringType) => true, (Term::StaticType, Term::StaticType) => true, (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), + // The term inside a TupleType is a list of types, so this is ok as long as + // supertype holds element-wise (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2), (Term::BytesType, Term::BytesType) => true, (Term::FloatType, Term::FloatType) => true, - (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2, + // Needed for TupleType, does not make a great deal of sense otherwise: + (Term::List(es1), Term::List(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) + } + // The following are not types (they have no instances), so these are just to + // maintain reflexivity of the relation: + (Term::RuntimeSum(t1), Term::RuntimeSum(t2)) => t1 == t2, + (Term::RuntimeFunction(f1), Term::RuntimeFunction(f2)) => f1 == f2, + (Term::RuntimeExtension(c1), Term::RuntimeExtension(c2)) => c1 == c2, (Term::BoundedNat(n1), Term::BoundedNat(n2)) => n1 == n2, (Term::String(s1), Term::String(s2)) => s1 == s2, (Term::Bytes(v1), Term::Bytes(v2)) => v1 == v2, (Term::Float(f1), Term::Float(f2)) => f1 == f2, (Term::Variable(v1), Term::Variable(v2)) => v1 == v2, - (Term::List(es1), Term::List(es2)) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) - } (Term::Tuple(es1), Term::Tuple(es2)) => { es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) } _ => false, } } + + /// Returns true if this term is an empty list (contains no elements) + pub fn is_empty_list(&self) -> bool { + match self { + Term::List(v) => v.is_empty(), + // We probably don't need to be this thorough in dealing with unnormalized forms but it's easy enough + Term::ListConcat(v) => v.iter().all(Term::is_empty_list), + _ => false, + } + } } impl From for Term { @@ -229,15 +262,6 @@ impl From for Term { } } -impl From> for Term { - fn from(value: TypeBase) -> Self { - match value.try_into_type() { - Ok(ty) => Term::Runtime(ty), - Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)), - } - } -} - impl From for Term { fn from(n: u64) -> Self { Self::BoundedNat(n) @@ -280,24 +304,15 @@ pub struct TermVar { } impl Term { - /// [`Type::UNIT`] as a [`Term::Runtime`] - pub const UNIT: Self = Self::Runtime(Type::UNIT); - /// Makes a `TypeArg` representing a use (occurrence) of the type variable /// with the specified index. /// `decl` must be exactly that with which the variable was declared. #[must_use] - pub fn new_var_use(idx: usize, decl: Term) -> Self { - match decl { - // Note a TypeParam::List of TypeParam::Type *cannot* be represented - // as a TypeArg::Type because the latter stores a Type i.e. only a single type, - // not a RowVariable. - Term::RuntimeType(b) => Type::new_var_use(idx, b).into(), - _ => Term::Variable(TermVar { - idx, - cached_decl: Box::new(decl), - }), - } + pub fn new_var_use(idx: usize, decl: impl Into) -> Self { + Term::Variable(TermVar { + idx, + cached_decl: Box::new(decl.into()), + }) } /// Creates a new string literal. @@ -308,8 +323,11 @@ impl Term { /// Creates a new concatenated list. #[inline] - pub fn new_list_concat(lists: impl IntoIterator) -> Self { - Self::ListConcat(lists.into_iter().collect()) + pub fn concat_lists(lists: impl IntoIterator) -> Self { + match lists.into_iter().exactly_one() { + Ok(list) => list, + Err(e) => Self::ListConcat(e.collect()), + } } /// Creates a new tuple from its items. @@ -333,15 +351,57 @@ impl Term { } } - /// Returns a [`Type`] if the [`Term`] is a runtime type. - #[must_use] - pub fn as_runtime(&self) -> Option> { + /// Returns whether this `Term` is a type of runtime values + pub fn is_runtime(&self) -> bool { + matches!( + self, + Term::RuntimeExtension(_) | Term::RuntimeFunction(_) | Term::RuntimeSum(_) + ) + } + + /// Returns the inner [`CustomType`] if the type is from an extension. + pub fn as_extension(&self) -> Option<&CustomType> { + match self { + Self::RuntimeExtension(ct) => Some(ct), + _ => None, + } + } + + /// Returns the inner [`SumType`] if the type is a [Self::RuntimeSum]. + pub fn as_runtime_sum(&self) -> Option<&SumType> { + match self { + Self::RuntimeSum(st) => Some(st), + _ => None, + } + } + + #[allow(rustdoc::private_intra_doc_links)] + /// Returns the [TypeBound] if this `Term` is a runtime type. + /// (Does not check sub-[Term]s inside [Self::RuntimeSum] or [Self::RuntimeFunction]; + /// call [Self::validate] for that.) + pub const fn least_upper_bound(&self) -> Option { match self { - TypeArg::Runtime(ty) => Some(ty.clone()), + Self::RuntimeExtension(ct) => Some(ct.bound()), + Self::RuntimeSum(st) => Some(st.bound()), + Self::RuntimeFunction(_) => Some(TypeBound::Copyable), + Self::Variable(v) => match &*v.cached_decl { + TypeParam::RuntimeType(b) => Some(*b), + _ => None, + }, _ => None, } } + /// Report if this is a copyable runtime type, i.e. an instance + /// of [Self::RuntimeType]`(`[TypeBound::Copyable]`)` + // - i.e.the least upper bound of the type is contained by the copyable bound. + pub const fn copyable(&self) -> bool { + match self.least_upper_bound() { + Some(b) => TypeBound::Copyable.contains(b), + None => false, + } + } + /// Returns a string if the [`Term`] is a string literal. #[must_use] pub fn as_string(&self) -> Option { @@ -351,29 +411,50 @@ impl Term { } } - /// Much as [`Type::validate`], also checks that the type of any [`TypeArg::Opaque`] - /// is valid and closed. + /// Checks all variables used in the type are in the provided list + /// of bound variables, rejecting any [`RowVariable`]s if `allow_row_vars` is False; + /// and that for each [`CustomType`] the corresponding + /// [`TypeDef`] is in the [`ExtensionRegistry`] and the type arguments + /// [validate] and fit into the def's declared parameters. + /// + /// [RowVariable]: TypeEnum::RowVariable + /// [validate]: crate::types::type_param::TypeArg::validate + /// [TypeDef]: crate::extension::TypeDef pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { match self { - Term::Runtime(ty) => ty.validate(var_decls), + Term::RuntimeSum(SumType::General(GeneralSum { rows, bound })) => { + rows.iter().try_for_each(|row| row.validate(var_decls))?; + // check_term_type does not look beyond the cached bound, so do that here. + rows.iter() + .try_for_each(|row| check_term_type(row, &Term::new_list_type(*bound)))?; + debug_assert!( + *bound == TypeBound::Copyable + || !rows.iter().all(|r| { + check_term_type(r, &Term::new_list_type(TypeBound::Copyable)).is_ok() + }), + "Incorrect bound, should have been Copyable" + ); + Ok(()) + } + Term::RuntimeSum(SumType::Unit { .. }) => Ok(()), // No leaves there + Term::RuntimeExtension(custy) => custy.validate(var_decls), + Term::RuntimeFunction(ft) => ft.validate(var_decls), Term::List(elems) => { - // TODO: Full validation would check that the type of the elements agrees + // Full validation might check that the type of the elements agrees. + // However we will leave this to a separate check_term_type which knows + // the required element type. elems.iter().try_for_each(|a| a.validate(var_decls)) } Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), TypeArg::ListConcat(lists) => { - // TODO: Full validation would check that each of the lists is indeed a - // list or list variable of the correct types. + // Full validation might check that each of the lists is indeed a list or + // list variable of the correct types. However we will leave this to a + // separate check_term_type which knows the required element type. lists.iter().try_for_each(|a| a.validate(var_decls)) } TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), Term::Variable(TermVar { idx, cached_decl }) => { - assert!( - !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), - "Malformed TypeArg::Variable {cached_decl} - should be inconstructible" - ); - check_typevar_decl(var_decls, *idx, cached_decl) } Term::RuntimeType { .. } => Ok(()), @@ -388,70 +469,6 @@ impl Term { } } - pub(crate) fn substitute(&self, t: &Substitution) -> Self { - match self { - Term::Runtime(ty) => { - // RowVariables are represented as Term::Variable - ty.substitute1(t).into() - } - TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { - self.clone() - } // We do not allow variables as bounds on BoundedNat's - TypeArg::List(elems) => { - // NOTE: This implements a hack allowing substitutions to - // replace `TypeArg::Variable`s representing "row variables" - // with a list that is to be spliced into the containing list. - // We won't need this code anymore once we stop conflating types - // with lists of types. - - fn is_type(type_arg: &TypeArg) -> bool { - match type_arg { - TypeArg::Runtime(_) => true, - TypeArg::Variable(v) => v.bound_if_row_var().is_some(), - _ => false, - } - } - - let are_types = elems.first().map(is_type).unwrap_or(false); - - Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) { - list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list), - list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list), - elem => SeqPart::Item(elem), - })) - } - TypeArg::ListConcat(lists) => { - // When a substitution instantiates spliced list variables, we - // may be able to merge the concatenated lists. - Self::new_list_from_parts( - lists.iter().map(|list| SeqPart::Splice(list.substitute(t))), - ) - } - Term::Tuple(elems) => { - Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) - } - TypeArg::TupleConcat(tuples) => { - // When a substitution instantiates spliced tuple variables, - // we may be able to merge the concatenated tuples. - Self::new_tuple_from_parts( - tuples - .iter() - .map(|tuple| SeqPart::Splice(tuple.substitute(t))), - ) - } - TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), - Term::RuntimeType(_) => self.clone(), - Term::BoundedNatType(_) => self.clone(), - Term::StringType => self.clone(), - Term::BytesType => self.clone(), - Term::FloatType => self.clone(), - Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), - Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)), - Term::StaticType => self.clone(), - Term::ConstType(ty) => Term::new_const(ty.substitute1(t)), - } - } - /// Helper method for [`TypeArg::new_list_from_parts`] and [`TypeArg::new_tuple_from_parts`]. fn new_seq_from_parts( parts: impl IntoIterator>, @@ -488,7 +505,7 @@ impl Term { Self::new_seq_from_parts( parts.into_iter().flat_map(ListPartIter::new), TypeArg::List, - TypeArg::ListConcat, + TypeArg::concat_lists, ) } @@ -518,7 +535,7 @@ impl Term { /// # let b = Term::new_string("b"); /// # let c = Term::new_string("c"); /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); - /// let term = Term::new_list_concat([ + /// let term = Term::concat_lists([ /// Term::new_list([a.clone(), b.clone()]), /// var.clone(), /// Term::new_list([c.clone()]) @@ -537,8 +554,8 @@ impl Term { /// # let a = Term::new_string("a"); /// # let b = Term::new_string("b"); /// # let c = Term::new_string("c"); - /// let term = Term::new_list_concat([ - /// Term::new_list_concat([ + /// let term = Term::concat_lists([ + /// Term::concat_lists([ /// Term::new_list([a.clone()]), /// Term::new_list([b.clone()]) /// ]), @@ -565,7 +582,7 @@ impl Term { /// ); /// ``` #[inline] - pub fn into_list_parts(self) -> ListPartIter { + pub fn into_list_parts(self) -> impl Iterator> { ListPartIter::new(SeqPart::Splice(self)) } @@ -584,15 +601,117 @@ impl Term { /// /// Analogous to [`TypeArg::into_list_parts`]. #[inline] - pub fn into_tuple_parts(self) -> TuplePartIter { + pub fn into_tuple_parts(self) -> impl Iterator> { TuplePartIter::new(SeqPart::Splice(self)) } + + /// Applies a substitution to this instance. Infallible (assuming the `subst` covers all + /// variables) and will not invalidate the instance (assuming all values substituted in, + /// are valid instances of the variables they replace). + /// + /// May change the structure of `self` significantly, e.g. if variables that stand for + /// rows of types are replaced by fixed-length lists of types. + /// + /// May change the [TypeBound] of the resulting type, e.g. if a variable whose bound + /// is [TypeBound::Linear] is replaced by a concrete type that is [TypeBound::Copyable]. + /// + /// # Panics + /// + /// If the substitution does not cover all type variables in `self`. + pub(crate) fn substitute(&self, s: &Substitution) -> Self { + match self { + TypeArg::RuntimeSum(SumType::Unit { .. }) => self.clone(), + TypeArg::RuntimeSum(SumType::General(GeneralSum { rows, .. })) => { + // A substitution of a row variable for an empty list, could make this from + // a GeneralSum into a unary SumType. Even new_unchecked recomputes the bound. + SumType::new_unchecked(rows.substitute(s).into_owned()).into() + } + TypeArg::RuntimeExtension(cty) => Term::new_extension(cty.substitute(s)), + TypeArg::RuntimeFunction(bf) => Term::new_function(bf.substitute(s)), + + TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { + self.clone() + } // We do not allow variables as bounds on BoundedNat's + TypeArg::List(elems) => Self::List(elems.iter().map(|t| t.substitute(s)).collect()), + TypeArg::ListConcat(lists) => { + // When a substitution instantiates spliced list variables, we + // may be able to merge the concatenated lists. + Self::new_list_from_parts( + lists.iter().map(|list| SeqPart::Splice(list.substitute(s))), + ) + } + Term::Tuple(elems) => { + Term::Tuple(elems.iter().map(|elem| elem.substitute(s)).collect()) + } + TypeArg::TupleConcat(tuples) => { + // When a substitution instantiates spliced tuple variables, + // we may be able to merge the concatenated tuples. + Self::new_tuple_from_parts( + tuples + .iter() + .map(|tuple| SeqPart::Splice(tuple.substitute(s))), + ) + } + TypeArg::Variable(TermVar { idx, cached_decl }) => s.apply_var(*idx, cached_decl), + Term::RuntimeType(_) => self.clone(), + Term::BoundedNatType(_) => self.clone(), + Term::StringType => self.clone(), + Term::BytesType => self.clone(), + Term::FloatType => self.clone(), + Term::ListType(item_type) => Term::new_list_type(item_type.substitute(s)), + Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(s)), + Term::StaticType => self.clone(), + Term::ConstType(ty) => Term::new_const(ty.substitute(s)), + } + } +} + +fn check_typevar_decl( + decls: &[TypeParam], + idx: usize, + cached_decl: &TypeParam, +) -> Result<(), SignatureError> { + match decls.get(idx) { + None => Err(SignatureError::FreeTypeVar { + idx, + num_decls: decls.len(), + }), + Some(actual) => { + // The cache here just mirrors the declaration. The typevar can be used + // anywhere expecting a kind *containing* the decl - see `check_type_arg`. + if actual == cached_decl { + Ok(()) + } else { + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + cached: Box::new(cached_decl.clone()), + actual: Box::new(actual.clone()), + }) + } + } + } } impl Transformable for Term { fn transform(&mut self, tr: &T) -> Result { match self { - Term::Runtime(ty) => ty.transform(tr), + Term::RuntimeExtension(custom_type) => { + if let Some(nt) = tr.apply_custom(custom_type)? { + *self = nt; + Ok(true) + } else { + let args_changed = custom_type.args_mut().transform(tr)?; + if args_changed { + *self = Self::new_extension( + custom_type + .get_type_def(&custom_type.get_extension()?)? + .instantiate(custom_type.args())?, + ); + } + Ok(args_changed) + } + } + Term::RuntimeFunction(fty) => fty.transform(tr), + Term::RuntimeSum(sum_type) => sum_type.transform(tr), Term::List(elems) => elems.transform(tr), Term::Tuple(elems) => elems.transform(tr), Term::BoundedNat(_) @@ -641,24 +760,17 @@ pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> { (Term::Variable(TermVar { cached_decl, .. }), _) if type_.is_supertype(cached_decl) => { Ok(()) } - (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => { + (Term::RuntimeSum(st), Term::RuntimeType(bound)) if bound.contains(st.bound()) => Ok(()), + (Term::RuntimeFunction(_), Term::RuntimeType(_)) => Ok(()), // Function pointers are always Copyable so fit any bound + (Term::RuntimeExtension(cty), Term::RuntimeType(bound)) if bound.contains(cty.bound()) => { Ok(()) } - (Term::List(elems), Term::ListType(item_type)) => { - elems.iter().try_for_each(|term| { - // Also allow elements that are RowVars if fitting into a List of Types - if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) - && v.bound_if_row_var() - .is_some_and(|arg_bound| param_bound.contains(arg_bound)) - { - return Ok(()); - } - check_term_type(term, item_type) - }) - } - (Term::ListConcat(lists), Term::ListType(item_type)) => lists + (Term::List(elems), Term::ListType(item_type)) => elems .iter() - .try_for_each(|list| check_term_type(list, item_type)), + .try_for_each(|elem| check_term_type(elem, item_type)), + (Term::ListConcat(lists), Term::ListType(_)) => lists + .iter() + .try_for_each(|list| check_term_type(list, type_)), // ALAN this used the element type, which seems very wrong (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => { let term_parts: Vec<_> = term.clone().into_tuple_parts().collect(); let type_parts: Vec<_> = item_types.clone().into_list_parts().collect(); @@ -762,7 +874,7 @@ pub enum SeqPart { /// Iterator created by [`TypeArg::into_list_parts`]. #[derive(Debug, Clone)] -pub struct ListPartIter { +pub(crate) struct ListPartIter { parts: SmallVec<[SeqPart; 1]>, } @@ -797,7 +909,7 @@ impl FusedIterator for ListPartIter {} /// Iterator created by [`TypeArg::into_tuple_parts`]. #[derive(Debug, Clone)] -pub struct TuplePartIter { +pub(crate) struct TuplePartIter { parts: SmallVec<[SeqPart; 1]>, } @@ -836,8 +948,8 @@ mod test { use super::{Substitution, TypeArg, TypeParam, check_term_type}; use crate::extension::prelude::{bool_t, usize_t}; - use crate::types::Term; use crate::types::type_param::SeqPart; + use crate::types::{Term, TypeRow}; use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; #[test] @@ -868,13 +980,13 @@ mod test { let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); let parts = [ SeqPart::Splice(Term::new_list([a.clone(), b.clone()])), - SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])), + SeqPart::Splice(Term::concat_lists([Term::new_list([c.clone()])])), SeqPart::Item(d.clone()), SeqPart::Splice(var.clone()), ]; assert_eq!( Term::new_list_from_parts(parts), - Term::new_list_concat([Term::new_list([a, b, c, d]), var]) + Term::concat_lists([Term::new_list([a, b, c, d]), var]) ); } @@ -918,34 +1030,36 @@ mod test { // Into a list of type, we can fit a single row var check(rowvar(0, TypeBound::Copyable), &seq_param).unwrap(); - // or a list of (types or row vars) + // or a list of types, or a "concat" of row vars check(vec![], &seq_param).unwrap(); - check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap(); - check_seq( - &[ + check( + Term::ListConcat(vec![rowvar(0, TypeBound::Copyable); 2]), + &seq_param, + ) + .unwrap(); + // but a *list* of the rowvar is a list of list of types, which is wrong + check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap_err(); + check( + Term::concat_lists([ rowvar(1, TypeBound::Linear), - usize_t().into(), + vec![usize_t()].into(), rowvar(0, TypeBound::Copyable), - ], + ]), &TypeParam::new_list_type(TypeBound::Linear), ) .unwrap(); - // Next one fails because a list of Eq is required - check_seq( - &[ + // Next one fails because a list of Copyable is required + check( + Term::concat_lists([ rowvar(1, TypeBound::Linear), - usize_t().into(), + vec![usize_t()].into(), rowvar(0, TypeBound::Copyable), - ], + ]), &seq_param, ) .unwrap_err(); // seq of seq of types is not allowed - check( - vec![usize_t().into(), vec![usize_t().into()].into()], - &seq_param, - ) - .unwrap_err(); + check(vec![usize_t(), vec![usize_t()].into()], &seq_param).unwrap_err(); // Similar for nats (but no equivalent of fancy row vars) check(5, &TypeParam::max_nat_type()).unwrap(); @@ -964,16 +1078,8 @@ mod test { // `Term::TupleType` requires a `Term::Tuple` of the same number of elems let usize_and_ty = TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]); - check( - TypeArg::Tuple(vec![5.into(), usize_t().into()]), - &usize_and_ty, - ) - .unwrap(); - check( - TypeArg::Tuple(vec![usize_t().into(), 5.into()]), - &usize_and_ty, - ) - .unwrap_err(); // Wrong way around + check(TypeArg::Tuple(vec![5.into(), usize_t()]), &usize_and_ty).unwrap(); + check(TypeArg::Tuple(vec![usize_t(), 5.into()]), &usize_and_ty).unwrap_err(); // Wrong way around let two_types = TypeParam::new_tuple_type(Term::new_list([ TypeBound::Linear.into(), TypeBound::Linear.into(), @@ -986,23 +1092,20 @@ mod test { #[test] fn type_arg_subst_row() { let row_param = Term::new_list_type(TypeBound::Copyable); - let row_arg: Term = vec![bool_t().into(), Term::UNIT].into(); + let row_arg: Term = vec![bool_t(), Term::UNIT].into(); check_term_type(&row_arg, &row_param).unwrap(); // Now say a row variable referring to *that* row was used // to instantiate an outer "row parameter" (list of type). let outer_param = Term::new_list_type(TypeBound::Linear); - let outer_arg = Term::new_list([ - TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), - usize_t().into(), + let outer_arg = Term::concat_lists([ + TypeRV::new_row_var_use(0, TypeBound::Copyable), + Term::new_list([usize_t()]), ]); check_term_type(&outer_arg, &outer_param).unwrap(); let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); - assert_eq!( - outer_arg2, - vec![bool_t().into(), Term::UNIT, usize_t().into()].into() - ); + assert_eq!(outer_arg2, vec![bool_t(), Term::UNIT, usize_t()].into()); // Of course this is still valid (as substitution is guaranteed to preserve validity) check_term_type(&outer_arg2, &outer_param).unwrap(); @@ -1015,9 +1118,9 @@ mod test { let row_var_use = Term::new_var_use(0, row_var_decl.clone()); let good_arg = Term::new_list([ // The row variables here refer to `row_var_decl` above - vec![usize_t().into()].into(), + vec![usize_t()].into(), row_var_use.clone(), - vec![row_var_use, usize_t().into()].into(), + Term::concat_lists([row_var_use, Term::new_list([usize_t()])]), ]); check_term_type(&good_arg, &outer_param).unwrap(); @@ -1025,31 +1128,44 @@ mod test { let Term::List(mut elems) = good_arg.clone() else { panic!() }; - elems.push(usize_t().into()); + elems.push(usize_t()); assert_eq!( check_term_type(&Term::new_list(elems), &outer_param), Err(TermTypeError::TypeMismatch { - term: Box::new(usize_t().into()), + term: Box::new(usize_t()), // The error reports the type expected for each element of the list: type_: Box::new(TypeParam::new_list_type(TypeBound::Linear)) }) ); // Now substitute a list of two types for that row-variable - let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); + let row_var_arg = vec![usize_t(), bool_t()].into(); check_term_type(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = good_arg.substitute(&Substitution(std::slice::from_ref(&row_var_arg))); check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, Term::new_list([ - Term::new_list([usize_t().into()]), + Term::new_list([usize_t()]), row_var_arg, - Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()]) + Term::new_list([usize_t(), bool_t(), usize_t()]) ]) ); } + #[test] + fn test_try_into_list_elements() { + // Test successful conversion with List + let types = vec![Term::new_unit_sum(1), bool_t()]; + let term = TypeArg::List(types.clone()); + let result = term.try_into(); + assert_eq!(result, Ok(TypeRow::from(types))); + + // Test failure with non-list + let result = TypeRow::try_from(Term::UNIT); + assert!(result.is_err()); + } + #[test] fn bytes_json_roundtrip() { let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into()); @@ -1058,13 +1174,41 @@ mod test { assert_eq!(deserialized, bytes_arg); } - mod proptest { + #[test] + fn list_from_single_part_item() { + // arbitrary, not but worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!( + Term::List(vec![term.clone()]), + Term::new_list_from_parts(std::iter::once(SeqPart::Item(term))) + ); + } + + #[test] + fn list_from_single_part_splice() { + // arbitrary, not but worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!( + term.clone(), + Term::new_list_from_parts(std::iter::once(SeqPart::Splice(term))) + ); + } + + #[test] + fn list_concat_single_item() { + // arbitrary, not but worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!(term.clone(), Term::concat_lists([term])); + } + mod proptest { + use prop::{collection::vec, strategy::Union}; use proptest::prelude::*; use super::super::{TermVar, UpperBound}; use crate::proptest::RecursionDepth; - use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param}; + use crate::types::proptest_utils::{any_serde_type_param, any_type}; + use crate::types::{Term, TypeBound}; impl Arbitrary for TermVar { type Parameters = RecursionDepth; @@ -1083,9 +1227,7 @@ mod test { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - use prop::collection::vec; - use prop::strategy::Union; - let mut strat = Union::new([ + let strat = Union::new([ Just(Self::StringType).boxed(), Just(Self::BytesType).boxed(), Just(Self::FloatType).boxed(), @@ -1100,32 +1242,31 @@ mod test { any::() .prop_map(|value| Self::Float(value.into())) .boxed(), - any_with::(depth).prop_map(Self::from).boxed(), + any_type(depth), ]); - if !depth.leaf() { - // we descend here because we these constructors contain Terms - strat = strat - .or( - // TODO this is a bit dodgy, TypeArgVariables are supposed - // to be constructed from TypeArg::new_var_use. We are only - // using this instance for serialization now, but if we want - // to generate valid TypeArgs this will need to change. - any_with::(depth.descend()) - .prop_map(Self::Variable) - .boxed(), - ) - .or(any_with::(depth.descend()) - .prop_map(Self::new_list_type) - .boxed()) - .or(any_with::(depth.descend()) - .prop_map(Self::new_tuple_type) - .boxed()) - .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(Self::new_list) - .boxed()); + if depth.leaf() { + return strat.boxed(); } - - strat.boxed() + // we descend here because we these constructors contain Terms + let depth = depth.descend(); + strat + .or( + // TODO this is a bit dodgy, TypeArgVariables are supposed + // to be constructed from TypeArg::new_var_use. We are only + // using this instance for serialization now, but if we want + // to generate valid TypeArgs this will need to change. + any_with::(depth).prop_map(Self::Variable).boxed(), + ) + .or(any_with::(depth) + .prop_map(Self::new_list_type) + .boxed()) + .or(any_with::(depth) + .prop_map(Self::new_tuple_type) + .boxed()) + .or(vec(any_with::(depth), 0..3) + .prop_map(Self::new_list) + .boxed()) + .boxed() } } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index db9314ff6..e2576bc06 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -7,43 +7,36 @@ use std::{ ops::{Deref, DerefMut}, }; -use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Term, Transformable, Type, TypeArg, TypeBase, TypeRV, - TypeTransformer, type_param::TypeParam, -}; +use super::{Substitution, Term, Transformable, Type, TypeTransformer, type_param::TypeParam}; use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; use itertools::Itertools; -/// List of types, used for function signatures. -/// The `ROWVARS` parameter controls whether this may contain [`RowVariable`]s -#[derive(Clone, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] +/// List of types. Like a `Vec<`[Term]`>` but serializes into legacy +/// JSON format for types only (serialization will panic if elements +/// are not [Term::RuntimeType]s or row variables thereof). +/// +/// Also allows sharing via `Cow` and static allocation via [type_row!]. +/// +/// [type_row!]: crate::type_row +#[derive(Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] -#[serde(transparent)] -pub struct TypeRowBase { +pub struct TypeRow { /// The datatypes in the row. - types: Cow<'static, [TypeBase]>, + types: Cow<'static, [Term]>, } -/// Row of single types i.e. of known length, for node inputs/outputs -pub type TypeRow = TypeRowBase; - -/// Row of types and/or row variables, the number of actual types is thus -/// unknown -pub type TypeRowRV = TypeRowBase; +/// Legacy alias. Used to indicate a [Term] that `check_term_type`s against +/// [Term::ListType] of [Term::RuntimeType] (of a [TypeBound]), i.e. one of +/// * A [Term::Variable] of type [Term::ListType] (of [Term::RuntimeType]...) +/// * A [Term::List], each of whose elements is of type some [Term::RuntimeType] +/// * A [Term::ListConcat], each of whose sublists is one of these three +/// +/// [TypeBound]: crate::types::TypeBound +// ALAN TODO remove this? or make a wrapper struct? +pub type TypeRowRV = Term; -impl PartialEq> for TypeRowBase { - fn eq(&self, other: &TypeRowBase) -> bool { - self.types.len() == other.types.len() - && self - .types - .iter() - .zip(other.types.iter()) - .all(|(s, o)| s == o) - } -} - -impl Display for TypeRowBase { +impl Display for TypeRow { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_char('[')?; display_list(self.types.as_ref(), f)?; @@ -51,7 +44,7 @@ impl Display for TypeRowBase { } } -impl TypeRowBase { +impl TypeRow { /// Create a new empty row. #[must_use] pub const fn new() -> Self { @@ -61,22 +54,20 @@ impl TypeRowBase { } /// Returns a new `TypeRow` with `xs` concatenated onto `self`. - pub fn extend<'a>(&'a self, rest: impl IntoIterator>) -> Self { + pub fn extend<'a>(&'a self, rest: impl IntoIterator) -> Self { self.iter().chain(rest).cloned().collect_vec().into() } /// Returns a reference to the types in the row. #[must_use] - pub fn as_slice(&self) -> &[TypeBase] { + pub fn as_slice(&self) -> &[Term] { &self.types } /// Applies a substitution to the row. - /// For `TypeRowRV`, note this may change the length of the row. - /// For `TypeRow`, guaranteed not to change the length of the row. pub(crate) fn substitute(&self, s: &Substitution) -> Self { self.iter() - .flat_map(|ty| ty.substitute(s)) + .map(|ty| ty.substitute(s)) .collect::>() .into() } @@ -84,16 +75,16 @@ impl TypeRowBase { delegate! { to self.types { /// Iterator over the types in the row. - pub fn iter(&self) -> impl Iterator>; + pub fn iter(&self) -> impl Iterator; /// Mutable vector of the types in the row. - pub fn to_mut(&mut self) -> &mut Vec>; + pub fn to_mut(&mut self) -> &mut Vec; /// Allow access (consumption) of the contained elements - #[must_use] pub fn into_owned(self) -> Vec>; + #[must_use] pub fn into_owned(self) -> Vec; /// Returns `true` if the row contains no types. - #[must_use] pub fn is_empty(&self) -> bool ; + #[must_use] pub fn is_empty(&self) -> bool; } } @@ -102,12 +93,13 @@ impl TypeRowBase { } } -impl Transformable for TypeRowBase { +impl Transformable for TypeRow { fn transform(&mut self, tr: &T) -> Result { self.to_mut().transform(tr) } } +// ALAN these were considered only good to make available for non-RV TypeRows... impl TypeRow { delegate! { to self.types { @@ -127,128 +119,43 @@ impl TypeRow { } } -impl TryFrom for TypeRow { - type Error = SignatureError; - - fn try_from(value: TypeRowRV) -> Result { - Ok(Self::from( - value - .into_owned() - .into_iter() - .map(std::convert::TryInto::try_into) - .collect::, _>>() - .map_err(|var| SignatureError::RowVarWhereTypeExpected { var })?, - )) - } -} - -impl Default for TypeRowBase { +impl Default for TypeRow { fn default() -> Self { Self::new() } } -impl From>> for TypeRowBase { - fn from(types: Vec>) -> Self { +impl From> for TypeRow { + fn from(types: Vec) -> Self { Self { types: types.into(), } } } -impl From> for TypeRowRV { - fn from(types: Vec) -> Self { - Self { - types: types.into_iter().map(Type::into_).collect(), - } - } -} - -impl From for TypeRowRV { - fn from(value: TypeRow) -> Self { - Self { - types: value.into_owned().into_iter().map(Type::into_).collect(), - } - } -} - -impl From<[TypeBase; N]> for TypeRowBase { - fn from(types: [TypeBase; N]) -> Self { - Self::from(Vec::from(types)) - } -} - -impl From<[Type; N]> for TypeRowRV { +impl From<[Type; N]> for TypeRow { fn from(types: [Type; N]) -> Self { Self::from(Vec::from(types)) } } -impl From<&'static [TypeBase]> for TypeRowBase { - fn from(types: &'static [TypeBase]) -> Self { +impl From<&'static [Type]> for TypeRow { + fn from(types: &'static [Type]) -> Self { Self { types: types.into(), } } } -// Fallibly convert a [Term] to a [TypeRV]. -// -// This will fail if `arg` is of non-type kind (e.g. String). -impl TryFrom for TypeRV { - type Error = SignatureError; - - fn try_from(value: Term) -> Result { - match value { - TypeArg::Runtime(ty) => Ok(ty.into()), - TypeArg::Variable(v) => Ok(TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )), - _ => Err(SignatureError::InvalidTypeArgs), - } - } -} - -// Fallibly convert a [Term] to a [TypeRow]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type) -// or if the sequence contains row variables. +/// Fallibly convert a [Term] to a [TypeRow]. +/// +/// This will fail if `arg` is not a [Term::List]. impl TryFrom for TypeRow { type Error = SignatureError; - fn try_from(value: TypeArg) -> Result { - match value { - TypeArg::List(elems) => elems - .into_iter() - .map(|ta| ta.as_runtime().ok_or(SignatureError::InvalidTypeArgs)) - .collect::, _>>() - .map(TypeRow::from), - _ => Err(SignatureError::InvalidTypeArgs), - } - } -} - -// Fallibly convert a [TypeArg] to a [TypeRowRV]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type). -impl TryFrom for TypeRowRV { - type Error = SignatureError; - fn try_from(value: Term) -> Result { match value { - TypeArg::List(elems) => elems - .into_iter() - .map(TypeRV::try_from) - .collect::, _>>() - .map(|vec| vec.into()), - TypeArg::Variable(v) => Ok(vec![TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )] - .into()), + Term::List(elems) => Ok(Self::from(elems)), _ => Err(SignatureError::InvalidTypeArgs), } } @@ -256,52 +163,71 @@ impl TryFrom for TypeRowRV { impl From for Term { fn from(value: TypeRow) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) - } -} - -impl From for Term { - fn from(value: TypeRowRV) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) + Term::List(value.into_owned()) } } -impl Deref for TypeRowBase { - type Target = [TypeBase]; +impl Deref for TypeRow { + type Target = [Term]; fn deref(&self) -> &Self::Target { self.as_slice() } } -impl DerefMut for TypeRowBase { +impl DerefMut for TypeRow { fn deref_mut(&mut self) -> &mut Self::Target { self.types.to_mut() } } +mod serialize { + use super::TypeRow; + use crate::types::Term; + use crate::types::serialize::SerSimpleType; + use itertools::Itertools as _; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + impl Serialize for TypeRow { + fn serialize(&self, s: S) -> Result { + let elems: Vec = self + .iter() + .map(|ty| ty.clone().try_into().unwrap()) + .collect(); + elems.serialize(s) + } + } + + impl<'de> Deserialize<'de> for TypeRow { + fn deserialize>(deser: D) -> Result { + let sertypes: Vec = Deserialize::deserialize(deser)?; + Ok(Self::from( + sertypes.into_iter().map_into().collect::>(), + )) + } + } +} + #[cfg(test)] mod test { use super::*; - use crate::{ - extension::prelude::bool_t, - types::{Type, TypeArg, TypeRV}, - }; + use crate::{extension::prelude::bool_t, types::Type}; mod proptest { + use super::super::TypeRow; use crate::proptest::RecursionDepth; - use crate::types::{MaybeRV, TypeBase, TypeRowBase}; + use crate::types::proptest_utils::any_type; use ::proptest::prelude::*; - impl Arbitrary for super::super::TypeRowBase { + impl Arbitrary for TypeRow { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { use proptest::collection::vec; if depth.leaf() { - Just(TypeRowBase::new()).boxed() + Just(TypeRow::new()).boxed() } else { - vec(any_with::>(depth), 0..4) + vec(any_type(depth.descend()), 0..4) .prop_map(|ts| ts.clone().into()) .boxed() } @@ -309,77 +235,19 @@ mod test { } } - #[test] - fn test_try_from_term_to_typerv() { - // Test successful conversion with Runtime type - let runtime_type = Type::UNIT; - let term = TypeArg::Runtime(runtime_type.clone()); - let result = TypeRV::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRV::from(runtime_type)); - - // Test failure with non-type kind - let term = Term::String("test".to_string()); - let result = TypeRV::try_from(term); - assert!(result.is_err()); - } - - #[test] - fn test_try_from_term_to_typerow() { - // Test successful conversion with List - let types = vec![Type::new_unit_sum(1), bool_t()]; - let type_args = types.iter().map(|t| TypeArg::Runtime(t.clone())).collect(); - let term = TypeArg::List(type_args); - let result = TypeRow::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRow::from(types)); - - // Test failure with non-list - let term = TypeArg::Runtime(Type::UNIT); - let result = TypeRow::try_from(term); - assert!(result.is_err()); - } - - #[test] - fn test_try_from_term_to_typerowrv() { - // Test successful conversion with List - let types = [TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_args = types.iter().map(|t| t.clone().into()).collect(); - let term = TypeArg::List(type_args); - let result = TypeRowRV::try_from(term); - assert!(result.is_ok()); - - // Test failure with non-sequence kind - let term = Term::String("test".to_string()); - let result = TypeRowRV::try_from(term); - assert!(result.is_err()); - } - #[test] fn test_from_typerow_to_term() { let types = vec![Type::UNIT, bool_t()]; let type_row = TypeRow::from(types); - let term = Term::from(type_row); + let term = Term::from(type_row.clone()); - match term { + match &term { Term::List(elems) => { assert_eq!(elems.len(), 2); } _ => panic!("Expected Term::List"), } - } - - #[test] - fn test_from_typerowrv_to_term() { - let types = vec![TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_row_rv = TypeRowRV::from(types); - let term = Term::from(type_row_rv); - match term { - TypeArg::List(elems) => { - assert_eq!(elems.len(), 2); - } - _ => panic!("Expected Term::List"), - } + assert_eq!(term.try_into(), Ok(type_row)); } } diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 3f8efef20..8a9759d58 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -5,10 +5,8 @@ use hugr_core::ops::{ CFG, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant, LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, constant::Sum, }; -use hugr_core::{ - HugrView, NodeIndex, - types::{SumType, Type, TypeEnum}, -}; +use hugr_core::types::{SumType, Term, Type, TypeBound, type_param::check_term_type}; +use hugr_core::{HugrView, NodeIndex}; use inkwell::types::BasicTypeEnum; use inkwell::values::{BasicValueEnum, CallableValue}; use itertools::{Itertools, zip_eq}; @@ -101,15 +99,16 @@ where } fn get_exactly_one_sum_type(ts: impl IntoIterator) -> Result { - let Some(TypeEnum::Sum(sum_type)) = ts + match ts .into_iter() - .map(|t| t.as_type_enum().clone()) + // ALAN Do we need to error on multiple (non-sum)types? + // if not, we can just take as_runtime_sum? + .filter(|t| check_term_type(t, &TypeBound::Linear.into()).is_ok()) .exactly_one() - .ok() - else { - Err(anyhow!("Not exactly one SumType"))? - }; - Ok(sum_type) + { + Ok(Term::RuntimeSum(st)) => Ok(st), + _ => Err(anyhow!("Not exactly one SumType")), + } } pub fn emit_value<'c, H: HugrView>( diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index e304800b9..d85306d6b 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -24,7 +24,7 @@ use hugr_core::ops::DataflowOpTrait; use hugr_core::std_extensions::collections::array::{ self, ArrayClone, ArrayDiscard, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::{Term, TypeArg}; use hugr_core::{HugrView, Node}; use inkwell::builder::Builder; use inkwell::intrinsics::Intrinsic; @@ -214,7 +214,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), ty] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; @@ -485,7 +485,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("ArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -546,7 +546,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("ArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -608,7 +608,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("ArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/borrow_array.rs b/hugr-llvm/src/extension/collections/borrow_array.rs index c8869907a..083883faf 100644 --- a/hugr-llvm/src/extension/collections/borrow_array.rs +++ b/hugr-llvm/src/extension/collections/borrow_array.rs @@ -31,7 +31,7 @@ use hugr_core::std_extensions::collections::borrow_array::{ BArrayRepeat, BArrayScan, BArrayToArray, BArrayToArrayDef, BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::{Term, TypeArg}; use hugr_core::{HugrView, Node}; use inkwell::builder::Builder; use inkwell::intrinsics::Intrinsic; @@ -296,8 +296,7 @@ impl CodegenExtension for BorrowArrayCodegenExtension>( .ok_or(anyhow!("BArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("BArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -1151,7 +1150,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("BArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -1218,7 +1217,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("BArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index f2689f838..8d1082058 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -4,7 +4,7 @@ use hugr_core::{ extension::simple_op::MakeExtensionOp as _, ops::ExtensionOp, std_extensions::collections::list::{self, ListOp, ListValue}, - types::{SumType, Type, TypeArg}, + types::{SumType, Type}, }; use inkwell::values::FunctionValue; use inkwell::{ @@ -203,7 +203,7 @@ fn emit_list_op<'c, H: HugrView>( op: ListOp, ) -> Result<()> { let hugr_elem_ty = match args.node().args() { - [TypeArg::Runtime(ty)] => ty.clone(), + [ty] => ty.clone(), _ => { bail!("Collections: invalid type args for list op"); } @@ -394,7 +394,7 @@ mod test { use hugr_core::extension::simple_op::MakeExtensionOp as _; let ext_op = list::EXTENSION - .instantiate_extension_op(op.op_id().as_ref(), [qb_t().into()]) + .instantiate_extension_op(op.op_id().as_ref(), [qb_t()]) .unwrap(); let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]); es.validate().unwrap(); diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap index 88e720d4f..508099701 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap @@ -5,24 +5,24 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.inner.85e364de.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.f2fe62a1.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.6f214c99.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.f9784340.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.399ad802.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.ab883312.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.ba073e80.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.206c0fa7.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.fcc3ee9.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.79e68bc9.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.4ea91316.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.85e364de.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.f2fe62a1.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.6f214c99.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.f9784340.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.399ad802.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.ab883312.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.ba073e80.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.206c0fa7.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.fcc3ee9.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.79e68bc9.0 to { i64, [0 x i64] }*)] } define private i64 @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 + %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.4ea91316.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 %1 = load i64, i64* %0, align 4 ret i64 %1 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap index 4f009047f..6710b5a79 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap @@ -5,17 +5,17 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.inner.85e364de.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.f2fe62a1.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.6f214c99.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.f9784340.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.399ad802.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.ab883312.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.ba073e80.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.206c0fa7.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.fcc3ee9.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.79e68bc9.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.4ea91316.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.85e364de.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.f2fe62a1.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.6f214c99.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.f9784340.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.399ad802.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.ab883312.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.ba073e80.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.206c0fa7.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.fcc3ee9.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.79e68bc9.0 to { i64, [0 x i64] }*)] } define private i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 + store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.4ea91316.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i64, [0 x i64] }*] }*, { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* %"5_01", i32 0, i32 0 %1 = load i64, i64* %0, align 4 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_0.snap index 1a834b4ac..0042cf0b8 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_0.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.35f6713a.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define private { i64, [0 x i64] }* @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.97cb22bf.0 to { i64, [0 x i64] }*) + ret { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.35f6713a.0 to { i64, [0 x i64] }*) } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_2.snap index f04fec2d6..868633dfd 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_2.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.f37f5956.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define private { i64, [0 x i1] }* @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d2dddd66.0 to { i64, [0 x i1] }*) + ret { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.f37f5956.0 to { i64, [0 x i1] }*) } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_3.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_3.snap index 0bd3db500..fd1eaba63 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_3.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_3.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } +@sa.d.6e9d4a5d.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } define private { i64, [0 x { i1, i64 }] }* @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.eee08a59.0 to { i64, [0 x { i1, i64 }] }*) + ret { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.6e9d4a5d.0 to { i64, [0 x { i1, i64 }] }*) } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap index 738e34eae..13bb313f9 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.35f6713a.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define private { i64, [0 x i64] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.97cb22bf.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }** %"5_0", align 8 + store { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.35f6713a.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }** %"5_0", align 8 %"5_01" = load { i64, [0 x i64] }*, { i64, [0 x i64] }** %"5_0", align 8 store { i64, [0 x i64] }* %"5_01", { i64, [0 x i64] }** %"0", align 8 %"02" = load { i64, [0 x i64] }*, { i64, [0 x i64] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap index 524dae1e4..dc2256978 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.f37f5956.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define private { i64, [0 x i1] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d2dddd66.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8 + store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.f37f5956.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8 %"5_01" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"5_0", align 8 store { i64, [0 x i1] }* %"5_01", { i64, [0 x i1] }** %"0", align 8 %"02" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_3.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_3.snap index 193e5376b..b11563968 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_3.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_3.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } +@sa.d.6e9d4a5d.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } define private { i64, [0 x { i1, i64 }] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.eee08a59.0 to { i64, [0 x { i1, i64 }] }*), { i64, [0 x { i1, i64 }] }** %"5_0", align 8 + store { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.6e9d4a5d.0 to { i64, [0 x { i1, i64 }] }*), { i64, [0 x { i1, i64 }] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i1, i64 }] }*, { i64, [0 x { i1, i64 }] }** %"5_0", align 8 store { i64, [0 x { i1, i64 }] }* %"5_01", { i64, [0 x { i1, i64 }] }** %"0", align 8 %"02" = load { i64, [0 x { i1, i64 }] }*, { i64, [0 x { i1, i64 }] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index 285a1ba3e..b8d6235b8 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -14,7 +14,7 @@ use hugr_core::ops::DataflowOpTrait; use hugr_core::std_extensions::collections::array::{ self, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::{Term, TypeArg}; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::{Builder, BuilderError}; @@ -135,7 +135,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), ty] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; @@ -357,7 +357,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("ArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -420,7 +420,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("ArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -494,7 +494,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Term::RuntimeSum(st) = res_hugr_ty else { Err(anyhow!("ArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 6747898ab..8026cc273 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -10,6 +10,8 @@ use hugr_core::{ std_extensions::collections::static_array::{ self, StaticArrayOp, StaticArrayOpDef, StaticArrayValue, }, + types::TypeBound, + types::type_param::check_term_type, }; use inkwell::{ AddressSpace, IntPredicate, @@ -369,10 +371,10 @@ impl CodegenExtension for StaticArrayCodegenE { let sac = self.0.clone(); move |ts, custom_type| { - let element_type = custom_type.args()[0] - .as_runtime() + let element_type = &custom_type.args()[0]; + check_term_type(element_type, &TypeBound::Copyable.into()) .expect("Type argument for static array must be a type"); - sac.static_array_type(ts, &element_type) + sac.static_array_type(ts, element_type) } }, ) @@ -426,7 +428,7 @@ mod test { #[case] op: StaticArrayOpDef, #[case] ty: HugrType, ) { - let op = op.instantiate(&[ty.clone().into()]).unwrap(); + let op = op.instantiate(std::slice::from_ref(&ty)).unwrap(); let op = OpType::from(op.to_extension_op().unwrap()); llvm_ctx.add_extensions(|ceb| { ceb.add_default_static_array_extensions() diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index 65d9b0227..d963c5fe4 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -8,7 +8,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, constant::Value, custom::ExtensionOp}, std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, - types::{TypeEnum, TypeRow}, + types::{Term, TypeRow}, }; use inkwell::{FloatPredicate, IntPredicate, types::IntType, values::BasicValue}; @@ -189,12 +189,10 @@ fn emit_conversion_op<'c, H: HugrView>( .typing_session() .llvm_type(&INT_TYPES[0])? .into_int_type(); - let sum_ty = context - .typing_session() - .llvm_sum_type(match bool_t().as_type_enum() { - TypeEnum::Sum(st) => st.clone(), - _ => panic!("Hugr prelude bool_t() not a Sum"), - })?; + let sum_ty = context.typing_session().llvm_sum_type(match bool_t() { + Term::RuntimeSum(st) => st, + _ => panic!("Hugr prelude bool_t() not a Sum"), + })?; emit_custom_unary_op(context, args, |ctx, arg, _| { let res = if conversion_op == ConvertOpDef::itobool { diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index 57d2f9ecc..e4b02ffa3 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -678,7 +678,7 @@ fn emit_int_op<'c, H: HugrView>( outs, out_log_width, true, - out_ty.as_sum().unwrap().clone(), + out_ty.as_runtime_sum().unwrap().clone(), )?; Ok(vec![result]) }) @@ -696,7 +696,7 @@ fn emit_int_op<'c, H: HugrView>( outs, out_log_width, false, - out_ty.as_sum().unwrap().clone(), + out_ty.as_runtime_sum().unwrap().clone(), )?; Ok(vec![result]) }) @@ -1187,8 +1187,8 @@ mod test { } fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr { - let ty = &INT_TYPES[log_width as usize]; - test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone()) + let ty = INT_TYPES[log_width as usize].to_owned(); + test_int_op_with_results::<2>(ext_op, log_width, None, ty) } fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr { @@ -1212,11 +1212,11 @@ mod test { output_type: Type, process: impl Fn(&mut DFGW, Outputs) -> Result, ) -> Hugr { - let ty = &INT_TYPES[log_width as usize]; + let ty = INT_TYPES[log_width as usize].to_owned(); let input_tys = if inputs.is_some() { vec![] } else { - let input_tys = itertools::repeat_n(ty.clone(), N).collect(); + let input_tys = itertools::repeat_n(ty, N).collect(); assert_eq!(input_tys, ext_op.signature().input.to_vec()); input_tys }; diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index db2e0d20c..5ae3af004 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -575,7 +575,7 @@ mod test { fn prelude_make_tuple(prelude_llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() .with_ins(vec![bool_t(), bool_t()]) - .with_outs([Type::new_tuple(vec![bool_t(); 2])]) + .with_outs([Type::new_runtime_tuple(vec![bool_t(); 2])]) .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let in_wires = builder.input_wires(); @@ -588,7 +588,7 @@ mod test { #[rstest] fn prelude_unpack_tuple(prelude_llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_ins([Type::new_tuple(vec![bool_t(); 2])]) + .with_ins([Type::new_runtime_tuple(vec![bool_t(); 2])]) .with_outs(vec![bool_t(), bool_t()]) .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { @@ -606,7 +606,7 @@ mod test { #[rstest] fn prelude_panic(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: Term = qb_t().into(); + let type_arg_q: Term = qb_t(); let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) @@ -632,7 +632,7 @@ mod test { #[rstest] fn prelude_exit(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "EXIT"); - let type_arg_q: Term = qb_t().into(); + let type_arg_q: Term = qb_t(); let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); let exit_op = PRELUDE .instantiate_extension_op(&EXIT_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) diff --git a/hugr-llvm/src/sum.rs b/hugr-llvm/src/sum.rs index 9f66d477d..b4359a595 100644 --- a/hugr-llvm/src/sum.rs +++ b/hugr-llvm/src/sum.rs @@ -735,7 +735,7 @@ mod test { { // one-variant-elidable-fields -> empty_struct - let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, HugrType::UNIT]); + let hugr_type = HugrType::new_runtime_tuple(vec![HugrType::UNIT, HugrType::UNIT]); assert_eq!(ts.llvm_type(&hugr_type).unwrap(), empty_struct.clone()); } @@ -753,19 +753,19 @@ mod test { { // one-variant-one-field -> bare field - let hugr_type = HugrType::new_tuple(vec![usize_t()]); + let hugr_type = HugrType::new_runtime_tuple(vec![usize_t()]); assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64); } { // one-variant-one-non-elidable-field -> bare field - let hugr_type = HugrType::new_tuple(vec![HugrType::UNIT, usize_t()]); + let hugr_type = HugrType::new_runtime_tuple(vec![HugrType::UNIT, usize_t()]); assert_eq!(ts.llvm_type(&hugr_type).unwrap(), i64); } { // one-variant-multi-field -> struct-of-fields - let hugr_type = HugrType::new_tuple(vec![usize_t(), bool_t(), HugrType::UNIT]); + let hugr_type = HugrType::new_runtime_tuple(vec![usize_t(), bool_t(), HugrType::UNIT]); let llvm_type = iwc.struct_type(&[i64, i1], false).into(); assert_eq!(ts.llvm_type(&hugr_type).unwrap(), llvm_type); } diff --git a/hugr-llvm/src/utils/type_map.rs b/hugr-llvm/src/utils/type_map.rs index c4e64d90e..9aa4525c3 100644 --- a/hugr-llvm/src/utils/type_map.rs +++ b/hugr-llvm/src/utils/type_map.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use hugr_core::{ extension::ExtensionId, - types::{CustomType, TypeEnum, TypeName, TypeRow}, + types::{CustomType, Term, TypeName, TypeRow}, }; use anyhow::{Result, bail}; @@ -115,18 +115,18 @@ impl<'a, TM: TypeMapping + 'a> TypeMap<'a, TM> { /// Map `hugr_type` using the [`TypeMapping`] `TM`, the registered callbacks, /// and the auxiliary data `inv`. pub fn map_type<'c>(&self, hugr_type: &HugrType, inv: TM::InV<'c>) -> Result> { - match hugr_type.as_type_enum() { - TypeEnum::Extension(custom_type) => { + match hugr_type { + Term::RuntimeExtension(custom_type) => { let key = (custom_type.extension().clone(), custom_type.name().clone()); let Some(handler) = self.custom_hooks.get(&key) else { return self.type_map.default_out(inv, &custom_type.clone().into()); }; handler.map_type(inv, custom_type) } - TypeEnum::Sum(sum_type) => self + Term::RuntimeSum(sum_type) => self .map_sum_type(sum_type, inv) .map(|x| self.type_map.sum_into_out(x)), - TypeEnum::Function(function_type) => self + Term::RuntimeFunction(function_type) => self .map_function_type(&function_type.as_ref().clone().try_into()?, inv) .map(|x| self.type_map.func_into_out(x)), _ => self.type_map.default_out(inv, hugr_type), diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 0e148b276..f444f7e84 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -27,7 +27,7 @@ use hugr_core::std_extensions::arithmetic::{ int_types::{ConstInt, INT_TYPES}, }; use hugr_core::std_extensions::logic::LogicOp; -use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; +use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row}; use crate::ComposablePass as _; @@ -800,7 +800,7 @@ fn test_fold_idivmod_checked_u() { // x2 := idivmod_checked_u(x0, x1) // output x2 == error let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); - let elem_type = Type::new_tuple(intpair); + let elem_type = Type::new_runtime_tuple(intpair); let sum_type = sum_with_error([elem_type.clone()]); let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); @@ -848,7 +848,7 @@ fn test_fold_idivmod_checked_s() { // x2 := idivmod_checked_s(x0, x1) // output x2 == error let intpair: TypeRowRV = vec![INT_TYPES[5].clone(), INT_TYPES[5].clone()].into(); - let elem_type = Type::new_tuple(intpair); + let elem_type = Type::new_runtime_tuple(intpair); let sum_type = sum_with_error([elem_type.clone()]); let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); @@ -1592,8 +1592,8 @@ fn test_module() -> Result<(), Box> { // Define a top-level constant, (only) the second of which can be removed let c7 = mb.add_constant(Value::from(ConstInt::new_u(5, 7)?)); let c17 = mb.add_constant(Value::from(ConstInt::new_u(5, 17)?)); - let ad1 = mb.add_alias_declare("unused", TypeBound::Linear)?; - let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; + //let ad1 = mb.add_alias_declare("unused", TypeBound::Linear)?; + //let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; let mut main = mb.define_function( "main", Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]), @@ -1609,7 +1609,7 @@ fn test_module() -> Result<(), Box> { assert!(hugr.get_optype(hugr.entrypoint()).is_module()); assert_eq!( hugr.children(hugr.entrypoint()).collect_vec(), - [c7.node(), ad1.node(), ad2.node(), main.node()] + [c7.node(), main.node()] //ad1.node(), ad2.node(), ); let tags = hugr .children(main.node()) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index ddcfc15b1..e6d9f9187 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::Lattice; use ascent::lattice::BoundedLattice; use hugr_core::Node; -use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeRow}; use itertools::{Itertools, zip_eq}; use std::cmp::Ordering; use std::collections::HashMap; @@ -199,7 +199,7 @@ impl PartialSum { /// /// # Errors /// - /// If this `PartialSum` had multiple possible tags; or if `typ` was not a [`TypeEnum::Sum`] + /// If this `PartialSum` had multiple possible tags; or if `typ` was not a [`Type::RuntimeSum`] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [`PartialValue::try_into_concrete`]. #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases @@ -211,9 +211,8 @@ impl PartialSum { return Err(ExtractValueError::MultipleVariants(self)); } let (tag, v) = self.0.into_iter().exactly_one().unwrap(); - if let TypeEnum::Sum(st) = typ.as_type_enum() - && let Some(r) = st.get_variant(tag) - && let Ok(r) = TypeRow::try_from(r.clone()) + if let Some(st) = typ.as_runtime_sum() + && let Some(Ok(r)) = st.get_variant(tag).cloned().map(TypeRow::try_from) && v.len() == r.len() { return Ok(Sum { diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index f7dd470c1..4853f8d15 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -232,7 +232,15 @@ fn escape_dollar(str: impl AsRef) -> String { fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match arg { - TypeArg::Runtime(ty) => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), + TypeArg::RuntimeExtension(cty) => { + f.write_fmt(format_args!("e({})", escape_dollar(cty.to_string()))) + } + TypeArg::RuntimeSum(sty) => { + f.write_fmt(format_args!("t({})", escape_dollar(sty.to_string()))) + } + TypeArg::RuntimeFunction(fty) => { + f.write_fmt(format_args!("f({})", escape_dollar(fty.to_string()))) + } TypeArg::BoundedNat(n) => f.write_fmt(format_args!("n({n})")), TypeArg::String(arg) => f.write_fmt(format_args!("s({})", escape_dollar(arg))), TypeArg::List(elems) => f.write_fmt(format_args!("list({})", TypeArgsSeq(elems))), @@ -282,8 +290,8 @@ mod test { }; use hugr_core::extension::prelude::{ConstUsize, UnpackTuple, UnwrapBuilder, usize_t}; use hugr_core::ops::handle::{FuncID, NodeHandle}; - use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; - use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeEnum}; + use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, ExtensionOp, FuncDefn, Tag}; + use hugr_core::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; @@ -292,11 +300,11 @@ mod test { use super::{is_polymorphic, mangle_name}; fn pair_type(ty: Type) -> Type { - Type::new_tuple(vec![ty.clone(), ty]) + Type::new_runtime_tuple(vec![ty.clone(), ty]) } fn triple_type(ty: Type) -> Type { - Type::new_tuple(vec![ty.clone(), ty.clone(), ty]) + Type::new_runtime_tuple(vec![ty.clone(), ty.clone(), ty]) } #[test] @@ -325,22 +333,18 @@ mod test { // let tag = Tag::new(0, vec![vec![elem_ty; 2].into()]); // let tag = fb.add_dataflow_op(tag, [elem, elem]).unwrap(); // ...but since this will never execute, we can test recursion here - let tag = fb.call( - &FuncID::::from(fb.container_node()), - &[tv0().into()], - [elem], - )?; + let tag = fb.call(&FuncID::::from(fb.container_node()), &[tv0()], [elem])?; fb.finish_with_outputs(tag.outputs())? }; let tr = { - let sig = Signature::new([tv0()], [Type::new_tuple(vec![tv0(); 3])]); + let sig = Signature::new([tv0()], [Type::new_runtime_tuple(vec![tv0(); 3])]); let mut fb = mb.define_function( "triple", PolyFuncType::new([TypeBound::Copyable.into()], sig), )?; let [elem] = fb.input_wires_arr(); - let pair = fb.call(db.handle(), &[tv0().into()], [elem])?; + let pair = fb.call(db.handle(), &[tv0()], [elem])?; let [elem1, elem2] = fb .add_dataflow_op(UnpackTuple::new(vec![tv0(); 2].into()), pair.outputs())? @@ -353,11 +357,9 @@ mod test { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; let mut fb = mb.define_function("main", Signature::new([usize_t()], outs))?; let [elem] = fb.input_wires_arr(); - let [res1] = fb - .call(tr.handle(), &[usize_t().into()], [elem])? - .outputs_arr(); - let pair = fb.call(db.handle(), &[usize_t().into()], [elem])?; - let pty = pair_type(usize_t()).into(); + let [res1] = fb.call(tr.handle(), &[usize_t()], [elem])?.outputs_arr(); + let pair = fb.call(db.handle(), &[usize_t()], [elem])?; + let pty = pair_type(usize_t()); let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr(); fb.finish_with_outputs([res1, res2])? }; @@ -374,10 +376,10 @@ mod test { let mut funcs = list_funcs(&mono); let expected_mangled_names = [ - mangle_name("double", &[usize_t().into()]), - mangle_name("triple", &[usize_t().into()]), - mangle_name("double", &[pair_type(usize_t()).into()]), - mangle_name("triple", &[pair_type(usize_t()).into()]), + mangle_name("double", &[usize_t()]), + mangle_name("triple", &[usize_t()]), + mangle_name("double", &[pair_type(usize_t())]), + mangle_name("triple", &[pair_type(usize_t())]), ]; for n in &expected_mangled_names { @@ -455,8 +457,7 @@ mod test { let op_def = collections::borrow_array::EXTENSION .get_op("borrow") .unwrap(); - let op = hugr_core::ops::ExtensionOp::new(op_def.clone(), vec![sv(0), tv(1).into()]) - .unwrap(); + let op = hugr_core::ops::ExtensionOp::new(op_def.clone(), vec![sv(0), tv(1)]).unwrap(); // borrow the element at that index and return it along with the array let [arr, get] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr(); pf2.finish_with_outputs([get, arr]).unwrap() @@ -475,7 +476,7 @@ mod test { // pf1: two calls to pf2, one depending on pf1's TypeArg, the other not // first call stays generic in size but specifies the type as an array of 2 usizes let inner = pf1 - .call(pf2.handle(), &[sv(0), arr2u().into()], pf1.input_wires()) + .call(pf2.handle(), &[sv(0), arr2u()], pf1.input_wires()) .unwrap(); let [inner_arr, outer_arr] = inner.outputs_arr(); // discard the outer array output even though it is not all borrowed to get around linearity (would panic if you actually ran this) @@ -483,8 +484,7 @@ mod test { .get_op("discard_all_borrowed") .unwrap(); let discard_op = - hugr_core::ops::ExtensionOp::new(discard_op_def.clone(), vec![sv(0), arr2u().into()]) - .unwrap(); + hugr_core::ops::ExtensionOp::new(discard_op_def.clone(), vec![sv(0), arr2u()]).unwrap(); let [] = pf1 .add_dataflow_op(discard_op, [outer_arr]) .unwrap() @@ -493,7 +493,7 @@ mod test { let elem = pf1 .call( pf2.handle(), - &[TypeArg::BoundedNat(2), usize_t().into()], + &[TypeArg::BoundedNat(2), usize_t()], [inner_arr], ) .unwrap(); @@ -501,7 +501,7 @@ mod test { let [result, inner_arr] = elem.outputs_arr(); let discard_op = hugr_core::ops::ExtensionOp::new( discard_op_def.clone(), - vec![TypeArg::BoundedNat(2), usize_t().into()], + vec![TypeArg::BoundedNat(2), usize_t()], ) .unwrap(); let [] = pf1 @@ -521,15 +521,11 @@ mod test { let popleft = BArrayOpDef::pop_left.to_concrete(arr2u(), n); let ar2 = outer.add_dataflow_op(popleft.clone(), [arr2]).unwrap(); let sig = popleft.to_extension_op().unwrap().signature().into_owned(); - let TypeEnum::Sum(st) = sig.output().get(0).unwrap().as_type_enum() else { - panic!() - }; + let st = sig.output().get(0).unwrap().as_runtime_sum().unwrap(); let [left_arr, ar2_unwrapped] = outer .build_unwrap_sum(1, st.clone(), ar2.out_wire(0)) .unwrap(); - let discard_op = - hugr_core::ops::ExtensionOp::new(discard_op_def.clone(), vec![sa(2), usize_t().into()]) - .unwrap(); + let discard_op = ExtensionOp::new(discard_op_def.clone(), vec![sa(2), usize_t()]).unwrap(); let [] = outer .add_dataflow_op(discard_op, [left_arr]) .unwrap() @@ -551,9 +547,9 @@ mod test { vec![ &mangle_name("pf1", &[TypeArg::BoundedNat(5)]), &mangle_name("pf1", &[TypeArg::BoundedNat(4)]), - &mangle_name("pf2", &[TypeArg::BoundedNat(5), arr2u().into()]), // from pf1<5> - &mangle_name("pf2", &[TypeArg::BoundedNat(4), arr2u().into()]), // from pf1<4> - &mangle_name("pf2", &[TypeArg::BoundedNat(2), usize_t().into()]), // from both pf1<4> and <5> + &mangle_name("pf2", &[TypeArg::BoundedNat(5), arr2u()]), // from pf1<5> + &mangle_name("pf2", &[TypeArg::BoundedNat(4), arr2u()]), // from pf1<4> + &mangle_name("pf2", &[TypeArg::BoundedNat(2), usize_t()]), // from both pf1<4> and <5> "get_usz", "pf2", "mainish", @@ -601,9 +597,7 @@ mod test { let mut builder = module_builder .define_function("main", Signature::new_endo([Type::UNIT])) .unwrap(); - let func_ptr = builder - .load_func(foo.handle(), &[Type::UNIT.into()]) - .unwrap(); + let func_ptr = builder.load_func(foo.handle(), &[Type::UNIT]).unwrap(); let [r] = { let signature = Signature::new_endo([Type::UNIT]); builder @@ -629,12 +623,12 @@ mod test { #[rstest] #[case::bounded_nat(vec![0.into()], "$foo$$n(0)")] - #[case::type_unit(vec![Type::UNIT.into()], "$foo$$t(Unit)")] - #[case::type_int(vec![INT_TYPES[2].clone().into()], "$foo$$t(int(2))")] + #[case::type_unit(vec![Type::UNIT], "$foo$$t(Unit)")] + #[case::type_int(vec![INT_TYPES[2].clone()], "$foo$$e(int(2))")] #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] - #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$list($n(0)$t(Unit))")] - #[case::sequence(vec![TypeArg::Tuple(vec![0.into(),Type::UNIT.into()])], "$foo$$tuple($n(0)$t(Unit))")] + #[case::sequence(vec![vec![0.into(), Type::UNIT].into()], "$foo$$list($n(0)$t(Unit))")] + #[case::sequence(vec![TypeArg::Tuple(vec![0.into(),Type::UNIT])], "$foo$$tuple($n(0)$t(Unit))")] #[should_panic] #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::StringType)], "$foo$$v(1)")] diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs index 24fb98490..fbb07fa95 100644 --- a/hugr-passes/src/non_local/localize.rs +++ b/hugr-passes/src/non_local/localize.rs @@ -293,9 +293,7 @@ fn add_control_prefixes( else { panic!("impossible") }; - let Some(sum_type) = control_type.as_sum() else { - panic!("impossible") - }; + let sum_type = control_type.as_runtime_sum().unwrap(); let mut type_for_source = |source: &(Wire, Type)| { let (w, t) = source; diff --git a/hugr-passes/src/normalize_cfgs.rs b/hugr-passes/src/normalize_cfgs.rs index e88854ad4..0f3392e42 100644 --- a/hugr-passes/src/normalize_cfgs.rs +++ b/hugr-passes/src/normalize_cfgs.rs @@ -471,7 +471,7 @@ fn take_inputs(h: &mut H, n: H::Node) -> (NodePorts, NodePo fn tuple_elems(h: &H, n: H::Node, p: OutgoingPort) -> TypeRow { match h.get_optype(n).port_kind(p) { - Some(EdgeKind::Value(ty)) => ty.as_sum().unwrap().as_tuple().unwrap().clone(), + Some(EdgeKind::Value(ty)) => ty.as_runtime_sum().unwrap().as_tuple().unwrap().clone(), p => panic!("Expected Value port not {:?}", p), } .try_into() diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 96584b709..33018313c 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -21,8 +21,7 @@ use hugr_core::ops::{ ExtensionOp, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, - TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeRow, TypeTransformer, }; use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; @@ -816,8 +815,8 @@ impl ReplaceTypes { Ok(any_change) } Value::Extension { e } => Ok({ - let new_const = match e.get_type().as_type_enum() { - TypeEnum::Extension(exty) => match self.consts.get(exty) { + let new_const = match e.get_type().as_extension() { + Some(exty) => match self.consts.get(exty) { Some(const_fn) => Some(const_fn(e, self)), None => self .param_consts @@ -957,6 +956,7 @@ mod test { }; use hugr_core::types::{ EdgeKind, PolyFuncType, Signature, SumType, Term, Type, TypeArg, TypeBound, TypeRow, + type_param::check_term_type, }; use hugr_core::{Direction, Extension, HugrView, Port, Visibility, type_row}; use itertools::Itertools; @@ -974,14 +974,16 @@ mod test { } fn read_op(ext: &Arc, t: Type) -> ExtensionOp { - ExtensionOp::new(ext.get_op(READ).unwrap().clone(), [t.into()]).unwrap() + ExtensionOp::new(ext.get_op(READ).unwrap().clone(), [t]).unwrap() } fn just_elem_type(args: &[TypeArg]) -> &Type { - let [TypeArg::Runtime(ty)] = args else { - panic!("Expected just elem type") - }; - ty + if let [ty] = args + && check_term_type(ty, &TypeBound::Linear.into()).is_ok() + { + return ty; + } + panic!("Expected just elem type") } fn ext() -> Arc { @@ -998,7 +1000,7 @@ mod test { w, ) .unwrap() - .instantiate(vec![Type::new_var_use(0, TypeBound::Copyable).into()]) + .instantiate(vec![Type::new_var_use(0, TypeBound::Copyable)]) .unwrap(); ext.add_op( READ.into(), @@ -1058,7 +1060,7 @@ mod test { fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); - lw.set_replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); + lw.set_replace_type(pv.instantiate([bool_t()]).unwrap(), i64_t()); lw.set_replace_parametrized_type( pv, Box::new(|args: &[TypeArg]| Some(list_type(just_elem_type(args).clone()))), @@ -1085,8 +1087,8 @@ mod test { fn module_func_cfg_call() { let ext = ext(); let coln = ext.get_type(PACKED_VEC).unwrap(); - let c_int = Type::from(coln.instantiate([i64_t().into()]).unwrap()); - let c_bool = Type::from(coln.instantiate([bool_t().into()]).unwrap()); + let c_int = Type::from(coln.instantiate([i64_t()]).unwrap()); + let c_bool = Type::from(coln.instantiate([bool_t()]).unwrap()); let mut mb = ModuleBuilder::new(); let sig = Signature::new_endo([Type::new_var_use(0, TypeBound::Linear)]); let fb = mb @@ -1099,7 +1101,7 @@ mod test { let mut fb = mb.define_function("main", sig).unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let [indices] = fb - .call(id.handle(), &[c_int.into()], [indices]) + .call(id.handle(), &[c_int], [indices]) .unwrap() .outputs_arr(); let [idx2] = fb @@ -1115,7 +1117,7 @@ mod test { let mut entry = cfg.entry_builder([[bool_t()].into()], type_row![]).unwrap(); let [idx2, bools] = entry.input_wires_arr(); let [bools] = entry - .call(id.handle(), &[c_bool.into()], [bools]) + .call(id.handle(), &[c_bool], [bools]) .unwrap() .outputs_arr(); let bool_read_op = entry @@ -1152,7 +1154,7 @@ mod test { fn dfg_conditional_case() { let ext = ext(); let coln = ext.get_type(PACKED_VEC).unwrap(); - let pv = |t: Type| Type::new_extension(coln.instantiate([t.into()]).unwrap()); + let pv = |t: Type| Type::new_extension(coln.instantiate([t]).unwrap()); let sum_rows = [[pv(pv(bool_t())), i64_t()].into(), [pv(i64_t())].into()]; let mut dfb = DFGBuilder::new(inout_sig( vec![Type::new_sum(sum_rows.clone()), pv(bool_t()), pv(i64_t())], @@ -1218,7 +1220,7 @@ mod test { Value::sum( 0, [ListValue::new(usize_t(), [cu(1), cu(3), cu(3), cu(7)]).into()], - st, + st.clone(), ) .unwrap(), ); @@ -1325,9 +1327,13 @@ mod test { }, ); fn option_contents(ty: &Type) -> Option { - let row = ty.as_sum()?.get_variant(1).unwrap().clone(); - let elem = row.into_owned().into_iter().exactly_one().unwrap(); - Some(elem.try_into_type().unwrap()) + let row = ty.as_runtime_sum()?.get_variant(1).unwrap().clone(); + TypeRow::try_from(row) + .unwrap() + .iter() + .exactly_one() + .ok() + .cloned() } let i32_t = || INT_TYPES[5].clone(); let opt_i32 = Type::from(option_type([i32_t()])); @@ -1430,9 +1436,9 @@ mod test { // monomorphization to happen first so that ReplaceTypes can act upon the concrete types. let e = ext(); let pv = e.get_type(PACKED_VEC).unwrap(); - let inner = pv.instantiate([usize_t().into()]).unwrap(); + let inner = pv.instantiate([usize_t()]).unwrap(); let outer = pv - .instantiate([Type::new_extension(inner.clone()).into()]) + .instantiate([Type::new_extension(inner.clone())]) .unwrap(); let mut dfb = DFGBuilder::new(inout_sig([outer.into(), i64_t()], [usize_t()])).unwrap(); let read_func = dfb @@ -1499,9 +1505,9 @@ mod test { fn op_to_call_monomorphic(#[values(false, true)] i64_to_usize: bool) { let e = ext(); let pv = e.get_type(PACKED_VEC).unwrap(); - let inner = pv.instantiate([usize_t().into()]).unwrap(); + let inner = pv.instantiate([usize_t()]).unwrap(); let outer = pv - .instantiate([Type::new_extension(inner.clone()).into()]) + .instantiate([Type::new_extension(inner.clone())]) .unwrap(); let read_outer = read_op(&e, inner.clone().into()); let mut dfb = DFGBuilder::new(inout_sig( @@ -1528,9 +1534,11 @@ mod test { let mut lw = lowerer(&e); lw.set_replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| { Ok(Some({ - let [Term::Runtime(ty)] = args else { + let [ty] = args else { return Err(SignatureError::InvalidTypeArgs.into()); }; + check_term_type(ty, &TypeBound::Copyable.into()) + .map_err(SignatureError::TypeArgMismatch)?; let mut fb = FunctionBuilder::new("not inserted", endo_sig(vec![])).unwrap(); let read_func = fb .module_root_builder() @@ -1598,7 +1606,7 @@ mod test { fn regions() { let ext = ext(); let coln = ext.get_type(PACKED_VEC).unwrap(); - let c_u = Type::new_extension(coln.instantiate(&[usize_t().into()]).unwrap()); + let c_u = Type::new_extension(coln.instantiate(&[usize_t()]).unwrap()); let mut h = { let db = DFGBuilder::new(endo_sig([c_u.clone()])).unwrap(); let inps = db.input_wires(); @@ -1661,16 +1669,18 @@ mod test { .unwrap() .as_ref(), move |args, _| { - let [sz, Term::Runtime(ty)] = args else { + let [sz, ty] = args else { panic!("Expected two args to array-get") }; - if sz != &Term::BoundedNat(64) { + if sz != &Term::BoundedNat(64) + || check_term_type(ty, &TypeBound::Linear.into()).is_err() + { return Ok(None); } let pv = ext .get_type(PACKED_VEC) .unwrap() - .instantiate([ty.clone().into()]) + .instantiate([ty.clone()]) .unwrap(); let mut dfb = DFGBuilder::new(Signature::new( diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 21cd4541e..b25d898eb 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -20,7 +20,9 @@ use hugr_core::std_extensions::collections::borrow_array::{ BArrayClone, BArrayDiscard, BArrayOpBuilder, BorrowArray, borrow_array_type, }; use hugr_core::std_extensions::collections::list::ListValue; -use hugr_core::types::{SumType, Transformable, Type, TypeArg}; +use hugr_core::types::{ + SumType, Transformable, Type, TypeArg, TypeBound, type_param::check_term_type, +}; use hugr_core::{Visibility, type_row}; use itertools::Itertools; @@ -110,9 +112,10 @@ pub fn linearize_generic_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat(n), ty] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; + check_term_type(ty, &TypeBound::Linear.into()).unwrap(); if num_outports == 0 { // "Simple" discard let array_scan = GenericArrayScan::::new(ty.clone(), Type::UNIT, vec![], *n); @@ -125,7 +128,7 @@ pub fn linearize_generic_array( let mut mb = dfb.module_root_builder(); let mut fb = mb .define_function_vis( - mangle_name(DISCARD_TO_UNIT_PREFIX, &[ty.clone().into()]), + mangle_name(DISCARD_TO_UNIT_PREFIX, std::slice::from_ref(ty)), inout_sig([ty.clone()], [Type::UNIT]), Visibility::Public, ) @@ -169,7 +172,7 @@ pub fn linearize_generic_array( let mut mb = dfb.module_root_builder(); let mut fb = mb .define_function_vis( - mangle_name(MAKE_NONE_PREFIX, &[ty.clone().into()]), + mangle_name(MAKE_NONE_PREFIX, std::slice::from_ref(ty)), inout_sig(vec![], [option_ty.clone()]), Visibility::Public, ) @@ -202,7 +205,7 @@ pub fn linearize_generic_array( .define_function_vis( mangle_name( COPY_SCAN_PREFIX, - &[(*n).into(), ty.clone().into(), (num_new as u64).into()], + &[(*n).into(), ty.clone(), (num_new as u64).into()], ), endo_sig(io), Visibility::Public, @@ -226,7 +229,7 @@ pub fn linearize_generic_array( // Wrap each remaining copy into an option let set_op = OpType::from(GenericArrayOpDef::::set.to_concrete(option_ty.clone(), *n)); let either_st = set_op.dataflow_signature().unwrap().output[0] - .as_sum() + .as_runtime_sum() .unwrap() .clone(); let opt_arrays = opt_arrays @@ -292,7 +295,7 @@ pub fn linearize_generic_array( let mut mb = dfb.module_root_builder(); let mut fb = mb .define_function_vis( - mangle_name(UNWRAP_PREFIX, &[ty.clone().into()]), + mangle_name(UNWRAP_PREFIX, std::slice::from_ref(ty)), inout_sig([option_ty.clone()], [ty.clone()]), Visibility::Public, ) @@ -332,9 +335,10 @@ pub fn copy_discard_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat(n), ty] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; + check_term_type(ty, &TypeBound::Linear.into()).unwrap(); if ty.copyable() { // For arrays with copyable elements, we can just use the clone/discard ops if num_outports == 0 { @@ -379,9 +383,10 @@ pub fn copy_discard_borrow_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { + let [TypeArg::BoundedNat(n), ty] = args else { panic!("Illegal TypeArgs to borrow array: {args:?}") }; + check_term_type(ty, &TypeBound::Linear.into()).unwrap(); if ty.copyable() { // For arrays with copyable elements, we can just use the clone/discard ops if num_outports == 0 { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 454034277..ad25099ed 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -7,7 +7,7 @@ use hugr_core::builder::{ use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::borrow_array::borrow_array_type_def; -use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::types::{CustomType, Signature, Term, Type, TypeArg, TypeRow}; use hugr_core::{HugrView, IncomingPort, Node, Wire, hugr::hugrmut::HugrMut, ops::Tag}; use itertools::Itertools; @@ -106,7 +106,7 @@ pub trait Linearizer { /// A configuration for implementing [Linearizer] by delegating to /// type-specific callbacks, and by composing them in order to handle compound types -/// such as [`TypeEnum::Sum`]s. +/// such as [`Term::RuntimeSum`]s. #[derive(Clone)] pub struct DelegatingLinearizer { // Keyed by lowered type, as only needed when there is an op outputting such @@ -165,8 +165,8 @@ pub enum LinearizeError { #[error(transparent)] SignatureError(#[from] SignatureError), /// We cannot linearize (insert copy and discard functions) for - /// [Variable](TypeEnum::Variable)s, [Row variables](TypeEnum::RowVar), - /// or [Alias](TypeEnum::Alias)es. + /// [Variable](Term::Variable)s (including row variables). + // or Aliases, as there is no Term::Alias #[error("Cannot linearize type {_0}")] UnsupportedType(Box), /// Neither does linearization make sense for copyable types @@ -191,7 +191,7 @@ impl DelegatingLinearizer { /// Configures this instance that the specified monomorphic type can be copied and/or /// discarded via the provided [`NodeTemplate`]s - directly or as part of a compound type - /// e.g. [`TypeEnum::Sum`]. + /// e.g. [`Term::RuntimeSum`]. /// `copy` should have exactly one inport, of type `src`, and two outports, of same type; /// `discard` should have exactly one inport, of type 'src', and no outports. /// @@ -272,8 +272,8 @@ impl Linearizer for DelegatingLinearizer { } assert!(num_outports != 1); - match typ.as_type_enum() { - TypeEnum::Sum(sum_type) => { + match typ { + Term::RuntimeSum(sum_type) => { let variants = sum_type .variants() .map(|trv| trv.clone().try_into()) @@ -319,7 +319,7 @@ impl Linearizer for DelegatingLinearizer { cb.finish_hugr().unwrap(), ))) } - TypeEnum::Extension(cty) => { + Term::RuntimeExtension(cty) => { if let Some((copy, discard)) = self.copy_discard.get(cty) { Ok(if num_outports == 0 { discard.clone() @@ -352,7 +352,7 @@ impl Linearizer for DelegatingLinearizer { Ok(tmpl) } } - TypeEnum::Function(_) => panic!("Ruled out above as copyable"), + Term::RuntimeFunction(_) => panic!("Ruled out above as copyable"), _ => Err(LinearizeError::UnsupportedType(Box::new(typ.clone()))), } } @@ -389,7 +389,7 @@ mod test { use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections::array::array_type; use hugr_core::std_extensions::collections::borrow_array::{BArrayOpDef, borrow_array_type}; - use hugr_core::types::type_param::TypeParam; + use hugr_core::types::type_param::{TypeParam, check_term_type}; use hugr_core::types::{ FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRow, }; @@ -844,24 +844,23 @@ mod test { ); let drop_op = drop_ext.get_op("drop").unwrap(); lowerer.set_replace_parametrized_op(drop_op, |args, rt| { - let [TypeArg::Runtime(ty)] = args else { + let [ty] = args else { panic!("Expected just one type") }; + check_term_type(ty, &TypeBound::Linear.into()).unwrap(); Ok(Some(rt.get_linearizer().copy_discard_op(ty, 0)?)) }); let build_hugr = |ty: Type| { let mut dfb = DFGBuilder::new(Signature::new([ty.clone()], [])).unwrap(); let [inp] = dfb.input_wires_arr(); - let drop_op = drop_ext - .instantiate_extension_op("drop", [ty.into()]) - .unwrap(); + let drop_op = drop_ext.instantiate_extension_op("drop", [ty]).unwrap(); dfb.add_dataflow_op(drop_op, [inp]).unwrap(); dfb.finish_hugr().unwrap() }; // We can drop a tuple of 2* lin_t let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); - let mut h = build_hugr(Type::new_tuple(vec![lin_t.clone(); 2])); + let mut h = build_hugr(Type::new_runtime_tuple(vec![lin_t.clone(); 2])); lowerer.run(&mut h).unwrap(); h.validate().unwrap(); let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index 154148504..0eb16c8f6 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -365,7 +365,7 @@ mod test { vec![ bool_t(), bool_t(), - Type::new_tuple(vec![bool_t(), bool_t()]), + Type::new_runtime_tuple(vec![bool_t(), bool_t()]), ], )) .unwrap(); diff --git a/hugr/benches/benchmarks/types.rs b/hugr/benches/benchmarks/types.rs index d05896f01..dd229a400 100644 --- a/hugr/benches/benchmarks/types.rs +++ b/hugr/benches/benchmarks/types.rs @@ -1,8 +1,7 @@ // Required for black_box uses #![allow(clippy::unit_arg)] use hugr::extension::prelude::{qb_t, usize_t}; -use hugr::ops::AliasDecl; -use hugr::types::{Signature, Type, TypeBound}; +use hugr::types::{Signature, Type}; use criterion::{AxisScale, Criterion, PlotConfiguration, criterion_group}; use std::hint::black_box; @@ -11,10 +10,10 @@ use std::hint::black_box; fn make_complex_type() -> Type { let qb = qb_t(); let int = usize_t(); - let q_register = Type::new_tuple(vec![qb; 8]); - let b_register = Type::new_tuple(vec![int; 8]); - let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); - let sum = Type::new_sum([[q_register], [q_alias]]); + let q_register = Type::new_runtime_tuple(vec![qb; 8]); + let b_register = Type::new_runtime_tuple(vec![int; 8]); + //let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); + let sum = Type::new_sum([[q_register], [Type::UNIT]]); Type::new_function(Signature::new(vec![sum], vec![b_register])) }