Skip to content
Closed
21 changes: 15 additions & 6 deletions hugr-core/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@ use std::fmt::Debug;

use crate::ops::Value;
use crate::types::TypeArg;

use crate::IncomingPort;
use crate::OutgoingPort;

use crate::ops;
use crate::{Hugr, IncomingPort, OutgoingPort};

/// Output of constant folding an operation, None indicates folding was either
/// not possible or unsuccessful. An empty vector indicates folding was
/// successful and no values are output.
pub type ConstFoldResult = Option<Vec<(OutgoingPort, ops::Value)>>;
pub type ConstFoldResult = Option<Vec<(OutgoingPort, Value)>>;

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult {
Expand All @@ -27,6 +23,19 @@ pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult

/// Trait implemented by extension operations that can perform constant folding.
pub trait ConstFold: Send + Sync {
/// Given the containing Hugr, type arguments `type_args` and [`crate::ops::Const`]
/// values for inputs at [`crate::IncomingPort`]s, try to evaluate the operation.
///
/// Defaults to calling [Self::fold] (ignoring the Hugr)
fn fold_with_hugr(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Value)],
_hugr: &Hugr,
) -> ConstFoldResult {
self.fold(type_args, consts)
}

/// Given type arguments `type_args` and
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s,
/// try to evaluate the operation.
Expand Down
20 changes: 17 additions & 3 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use super::{
SignatureError,
};

use crate::ops::{OpName, OpNameRef};
use crate::ops::{OpName, OpNameRef, Value};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::Hugr;
use crate::{Hugr, IncomingPort};
mod serialize_signature_func;

/// Trait necessary for binary computations of OpDef signature
Expand Down Expand Up @@ -460,11 +460,25 @@ impl OpDef {
pub fn constant_fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Value)],
consts: &[(IncomingPort, Value)],
) -> ConstFoldResult {
(self.constant_folder.as_ref())?.fold(type_args, consts)
}

/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s and
/// access to the containing Hugr.
pub fn constant_fold_with_hugr(
&self,
type_args: &[TypeArg],
consts: &[(IncomingPort, Value)],
hugr: &Hugr,
) -> ConstFoldResult {
self.constant_folder
.as_ref()?
.fold_with_hugr(type_args, consts, hugr)
}

/// Returns a reference to the signature function of this [`OpDef`].
pub fn signature_func(&self) -> &SignatureFunc {
&self.signature_func
Expand Down
12 changes: 11 additions & 1 deletion hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use {

use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
use crate::types::{type_param::TypeArg, Signature};
use crate::{ops, IncomingPort, Node};
use crate::{ops, Hugr, IncomingPort, Node};

use super::dataflow::DataflowOpTrait;
use super::tag::OpTag;
Expand Down Expand Up @@ -96,6 +96,16 @@ impl ExtensionOp {
self.def().constant_fold(self.args(), consts)
}

/// Attempt to evaluate this operation, See ['OpDef::constant_fold_with_hugr`]
pub fn constant_fold_with_hugr(
&self,
consts: &[(IncomingPort, ops::Value)],
hugr: &Hugr,
) -> ConstFoldResult {
self.def()
.constant_fold_with_hugr(self.args(), consts, hugr)
}

/// Creates a new [`OpaqueOp`] as a downgraded version of this
/// [`ExtensionOp`].
///
Expand Down
73 changes: 47 additions & 26 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@
//! Constant-folding pass.
//! An (example) use of the [dataflow analysis framework](super::dataflow).

pub mod value_handle;
mod value_handle;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's quite annoying that we have to make this a breaking change.
It looks like even leaving this as public wouldn't help, since ValueHandle is not non_exhaustive and added new fields -.-

Is it OK to wait in merging this PR until we accumulate other breaking changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was hiding that because I thought that's the way that it should have been, would have avoided this being a breaking change now...

How urgent....I was gonna say, the bit we need first is the constant_fold_with_hugr and that's non-breaking so I can break that out and leave the rest as a demo (just tests that the ValueHandle/CallIndirect works would confirm we can do the same approach in/for BRAT).

However....now I realize we need to get the ValueHandle::NodeRef into the BRAT-specific constant-folding code. At present extension-provided constant-folders only get Hugr Values, so rather than adding constant_fold_with_hugr I think I need some kind of plugin interface to the constant-folding pass that lets us get at those ValueHandles after all. So I think I'll have to restore pubness and figure out what the plugin interface is.

use itertools::{Either, Itertools};
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;

use hugr_core::{
hugr::{
hugrmut::HugrMut,
views::{DescendantsGraph, ExtractHugr, HierarchyView},
},
hugr::hugrmut::HugrMut,
ops::{
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
OpType, Value,
constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value,
},
types::{EdgeKind, TypeArg},
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
Expand Down Expand Up @@ -207,13 +204,6 @@ pub fn constant_fold_pass<H: HugrMut>(h: &mut H) {

struct ConstFoldContext<'a, H>(&'a H);

impl<H: HugrView> std::ops::Deref for ConstFoldContext<'_, H> {
type Target = H;
fn deref(&self) -> &H {
self.0
}
}

impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
type Node = H::Node;

Expand All @@ -238,17 +228,7 @@ impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldCo
node: H::Node,
type_args: &[TypeArg],
) -> Option<ValueHandle<H::Node>> {
if !type_args.is_empty() {
// TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709)
return None;
};
// Returning the function body as a value, here, would be sufficient for inlining IndirectCall
// but not for transforming to a direct Call.
let func = DescendantsGraph::<FuncID<true>>::try_new(&**self, node).ok()?;
Some(ValueHandle::new_const_hugr(
ConstLocation::Node(node),
Box::new(func.extract_hugr()),
))
Some(ValueHandle::NodeRef(node, type_args.to_vec()))
}
}

Expand All @@ -273,11 +253,52 @@ impl<H: HugrView<Node = Node>> DFContext<ValueHandle<H::Node>> for ConstFoldCont
.map(|v| (IncomingPort::from(i), v))
})
.collect::<Vec<_>>();
for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() {
for (p, v) in op
.constant_fold_with_hugr(&known_ins, self.0.base_hugr())
.unwrap_or_default()
{
outs[p.index()] =
partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v);
}
}

fn interpret_call_indirect(
&mut self,
func: &PartialValue<ValueHandle<H::Node>>,
args: &[PartialValue<ValueHandle<H::Node>>],
outs: &mut [PartialValue<ValueHandle<H::Node>>],
) {
let PartialValue::Value(func) = func else {
return;
};
let inputs = args.iter().cloned().enumerate().map(|(i, v)| (i.into(), v));
let vals: Vec<_> = match func {
ValueHandle::NodeRef(node, _) => {
let mut m = Machine::new(self.0);
m.prepopulate_inputs(*node, inputs).unwrap();
let results = m.run(ConstFoldContext(self.0), []);
(0..outs.len())
.map(|p| results.read_out_wire(Wire::new(*node, p)))
.collect()
}
ValueHandle::Unhashable {
leaf: Either::Right(hugr),
..
} => {
let h = hugr.as_ref();
let results = Machine::new(h).run(ConstFoldContext(h), inputs);
(0..outs.len())
.map(|p| results.read_out_wire(Wire::new(h.root(), p)))
.collect()
}
_ => return,
};
for (val, out) in vals.into_iter().zip_eq(outs) {
if let Some(val) = val {
*out = val;
}
}
}
}

#[cfg(test)]
Expand Down
25 changes: 21 additions & 4 deletions hugr-passes/src/const_fold/value_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use hugr_core::core::HugrNode;
use hugr_core::ops::constant::OpaqueValue;
use hugr_core::ops::Value;
use hugr_core::types::TypeArg;
use hugr_core::{Hugr, Node};
use itertools::Either;

Expand Down Expand Up @@ -46,6 +47,15 @@ impl Hash for HashedConst {
/// An [Eq]-able and [Hash]-able leaf (non-[Sum](Value::Sum)) Value
#[derive(Clone, Debug)]
pub enum ValueHandle<N = Node> {
/// The result of [LoadFunction] on a [FuncDefn] (or [FuncDecl]), i.e. a "function
/// pointer" to a function in the Hugr. (Cannot be represented as a [Value::Function]
/// without lots of cloning, because it may have static edges from other
/// functions/constants/etc.)
///
/// [LoadFunction]: hugr_core::ops::LoadFunction
/// [FuncDefn]: hugr_core::ops::FuncDefn
/// [FuncDecl]: hugr_core::ops::FuncDefn
NodeRef(N, Vec<TypeArg>),
/// A [Value::Extension] that has been hashed
Hashable(HashedConst),
/// Either a [Value::Extension] that can't be hashed, or a [Value::Function].
Expand Down Expand Up @@ -108,6 +118,7 @@ impl<N: HugrNode> AbstractValue for ValueHandle<N> {}
impl<N: HugrNode> PartialEq for ValueHandle<N> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::NodeRef(n1, args1), Self::NodeRef(n2, args2)) => n1 == n2 && args1 == args2,
(Self::Hashable(h1), Self::Hashable(h2)) => h1 == h2,
(
Self::Unhashable {
Expand Down Expand Up @@ -138,6 +149,10 @@ impl<N: HugrNode> Eq for ValueHandle<N> {}
impl<N: HugrNode> Hash for ValueHandle<N> {
fn hash<I: Hasher>(&self, state: &mut I) {
match self {
ValueHandle::NodeRef(n, args) => {
n.hash(state);
args.hash(state);
}
ValueHandle::Hashable(hc) => hc.hash(state),
ValueHandle::Unhashable {
node,
Expand All @@ -153,9 +168,11 @@ impl<N: HugrNode> Hash for ValueHandle<N> {

// Unfortunately we need From<ValueHandle> for Value to be able to pass
// Value's into interpret_leaf_op. So that probably doesn't make sense...
impl<N: HugrNode> From<ValueHandle<N>> for Value {
fn from(value: ValueHandle<N>) -> Self {
match value {
impl<N: HugrNode> TryFrom<ValueHandle<N>> for Value {
type Error = N;
fn try_from(value: ValueHandle<N>) -> Result<Value, N> {
Ok(match value {
ValueHandle::NodeRef(n, _) => return Err(n),
ValueHandle::Hashable(HashedConst { val, .. })
| ValueHandle::Unhashable {
leaf: Either::Left(val),
Expand All @@ -169,7 +186,7 @@ impl<N: HugrNode> From<ValueHandle<N>> for Value {
} => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone()))
.map_err(|e| e.to_string())
.unwrap(),
}
})
}
}

Expand Down
12 changes: 12 additions & 0 deletions hugr-passes/src/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ pub trait DFContext<V>: ConstLoader<V> {
_outs: &mut [PartialValue<V>],
) {
}

/// Given lattice values for the called function, and arguments to pass to it, update
/// lattice values for the (dataflow) outputs
/// of a [CallIndirect](hugr_core::ops::CallIndirect).
/// (The default does nothing, i.e. leaves `Top` for all outputs.)
fn interpret_call_indirect(
&mut self,
_func: &PartialValue<V>,
_args: &[PartialValue<V>],
_outs: &mut [PartialValue<V>],
) {
}
}

/// A location where a [Value] could be find in a Hugr. That is,
Expand Down
9 changes: 8 additions & 1 deletion hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
outs
}))
}
o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive"
OpType::CallIndirect(_) => Some(ValueRow::from_iter(if row_contains_bottom(ins) {
vec![PartialValue::Bottom; num_outs]
} else {
let mut outs = vec![PartialValue::Top; num_outs];
ctx.interpret_call_indirect(&ins[0], &ins[1..], &mut outs[..]);
outs
})),
o => todo!("Unhandled: {:?}", o), // OpType is "non-exhaustive"
}
}
Loading