Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc};
use thiserror::Error;

use hugr_core::{
hugr::{
hugrmut::HugrMut,
views::{DescendantsGraph, ExtractHugr, HierarchyView},
},
hugr::hugrmut::HugrMut,
ops::{
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
OpType, Value,
constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value,
},
types::{EdgeKind, TypeArg},
types::EdgeKind,
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
};
use value_handle::ValueHandle;
Expand Down Expand Up @@ -205,14 +201,8 @@ pub fn constant_fold_pass<H: HugrMut>(h: &mut H) {
c.run(h).unwrap()
}

struct ConstFoldContext<'a, H>(&'a H);

impl<H: HugrView> std::ops::Deref for ConstFoldContext<'_, H> {
type Target = H;
fn deref(&self) -> &H {
self.0
}
}
// Probably intend to remove this in a future PR, but not certain, so leaving in for now
struct ConstFoldContext<'a, H>(#[allow(unused)] &'a H);

impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
type Node = H::Node;
Expand All @@ -232,24 +222,6 @@ impl<H: HugrView<Node = Node>> ConstLoader<ValueHandle<H::Node>> for ConstFoldCo
) -> Option<ValueHandle<H::Node>> {
Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone())))
}

fn value_from_function(
&self,
node: H::Node,
type_args: &[TypeArg],
) -> Option<ValueHandle<H::Node>> {
if !type_args.is_empty() {
// TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709)
return None;
};
// Returning the function body as a value, here, would be sufficient for inlining IndirectCall
// but not for transforming to a direct Call.
let func = DescendantsGraph::<FuncID<true>>::try_new(&**self, node).ok()?;
Some(ValueHandle::new_const_hugr(
ConstLocation::Node(node),
Box::new(func.extract_hugr()),
))
}
}

impl<H: HugrView<Node = Node>> DFContext<ValueHandle<H::Node>> for ConstFoldContext<'_, H> {
Expand Down
17 changes: 9 additions & 8 deletions hugr-passes/src/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod results;
pub use results::{AnalysisResults, TailLoopTermination};

mod partial_value;
pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum};
pub use partial_value::{AbstractValue, LoadedFunction, PartialSum, PartialValue, Sum};

use hugr_core::ops::constant::OpaqueValue;
use hugr_core::ops::{ExtensionOp, Value};
Expand All @@ -31,8 +31,8 @@ pub trait DFContext<V>: ConstLoader<V> {
&mut self,
_node: Self::Node,
_e: &ExtensionOp,
_ins: &[PartialValue<V>],
_outs: &mut [PartialValue<V>],
_ins: &[PartialValue<V, Self::Node>],
_outs: &mut [PartialValue<V, Self::Node>],
) {
}
}
Expand All @@ -55,8 +55,8 @@ impl<N> From<N> for ConstLocation<'_, N> {
}

/// Trait for loading [PartialValue]s from constant [Value]s in a Hugr.
/// Implementors will likely want to override some/all of [Self::value_from_opaque],
/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults
/// Implementors will likely want to override either/both of [Self::value_from_opaque]
/// and [Self::value_from_const_hugr]: the defaults
/// are "correct" but maximally conservative (minimally informative).
pub trait ConstLoader<V> {
/// The type of nodes in the Hugr.
Expand All @@ -81,6 +81,7 @@ pub trait ConstLoader<V> {
/// [FuncDefn]: hugr_core::ops::FuncDefn
/// [FuncDecl]: hugr_core::ops::FuncDecl
/// [LoadFunction]: hugr_core::ops::LoadFunction
#[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")]
fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option<V> {
None
}
Expand All @@ -94,7 +95,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader<V>>(
cl: &CL,
loc: impl Into<ConstLocation<'a, CL::Node>>,
cst: &Value,
) -> PartialValue<V>
) -> PartialValue<V, CL::Node>
where
CL::Node: 'a,
{
Expand All @@ -120,8 +121,8 @@ where

/// A row of inputs to a node contains bottom (can't happen, the node
/// can't execute) if any element [contains_bottom](PartialValue::contains_bottom).
pub fn row_contains_bottom<'a, V: AbstractValue + 'a>(
elements: impl IntoIterator<Item = &'a PartialValue<V>>,
pub fn row_contains_bottom<'a, V: 'a, N: 'a>(
elements: impl IntoIterator<Item = &'a PartialValue<V, N>>,
) -> bool {
elements.into_iter().any(PartialValue::contains_bottom)
}
Expand Down
78 changes: 55 additions & 23 deletions hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ use ascent::lattice::BoundedLattice;
use itertools::Itertools;

use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
use hugr_core::ops::{OpTrait, OpType, TailLoop};
use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop};
use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire};

use super::value_row::ValueRow;
use super::{
partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext,
PartialValue,
LoadedFunction, PartialValue,
};

type PV<V> = PartialValue<V>;
type PV<V, N> = PartialValue<V, N>;

type NodeInputs<V, N> = Vec<(IncomingPort, PV<V, N>)>;

/// Basic structure for performing an analysis. Usage:
/// 1. Make a new instance via [Self::new()]
Expand All @@ -25,10 +27,7 @@ type PV<V> = PartialValue<V>;
/// [Self::prepopulate_inputs] can be used on each externally-callable
/// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top].
/// 3. Call [Self::run] to produce [AnalysisResults]
pub struct Machine<H: HugrView, V: AbstractValue>(
H,
HashMap<H::Node, Vec<(IncomingPort, PartialValue<V>)>>,
);
pub struct Machine<H: HugrView, V: AbstractValue>(H, HashMap<H::Node, NodeInputs<V, H::Node>>);

impl<H: HugrView, V: AbstractValue> Machine<H, V> {
/// Create a new Machine to analyse the given Hugr(View)
Expand All @@ -40,7 +39,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
impl<H: HugrView, V: AbstractValue> Machine<H, V> {
/// Provide initial values for a wire - these will be `join`d with any computed
/// or any value previously prepopulated for the same Wire.
pub fn prepopulate_wire(&mut self, w: Wire<H::Node>, v: PartialValue<V>) {
pub fn prepopulate_wire(&mut self, w: Wire<H::Node>, v: PartialValue<V, H::Node>) {
for (n, inp) in self.0.linked_inputs(w.node(), w.source()) {
self.1.entry(n).or_default().push((inp, v.clone()));
}
Expand All @@ -54,7 +53,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
pub fn prepopulate_inputs(
&mut self,
parent: H::Node,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V>)>,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>,
) -> Result<(), OpType> {
match self.0.get_optype(parent) {
OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => {
Expand Down Expand Up @@ -102,7 +101,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
pub fn run(
mut self,
context: impl DFContext<V, Node = H::Node>,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V>)>,
in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>,
) -> AnalysisResults<V, H> {
let root = self.0.root();
if self.0.get_optype(root).is_module() {
Expand Down Expand Up @@ -135,10 +134,12 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
}
}

pub(super) type InWire<V, N> = (N, IncomingPort, PartialValue<V, N>);

pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
mut ctx: impl DFContext<V, Node = H::Node>,
hugr: H,
in_wire_value_proto: Vec<(H::Node, IncomingPort, PV<V>)>,
in_wire_value_proto: Vec<InWire<V, H::Node>>,
) -> AnalysisResults<V, H> {
// ascent-(macro-)generated code generates a bunch of warnings,
// keep code in here to a minimum.
Expand All @@ -155,9 +156,9 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
relation parent_of_node(H::Node, H::Node); // <Node> is parent of <Node>
relation input_child(H::Node, H::Node); // <Node> has 1st child <Node> that is its `Input`
relation output_child(H::Node, H::Node); // <Node> has 2nd child <Node> that is its `Output`
lattice out_wire_value(H::Node, OutgoingPort, PV<V>); // <Node> produces, on <OutgoingPort>, the value <PV>
lattice in_wire_value(H::Node, IncomingPort, PV<V>); // <Node> receives, on <IncomingPort>, the value <PV>
lattice node_in_value_row(H::Node, ValueRow<V>); // <Node>'s inputs are <ValueRow>
lattice out_wire_value(H::Node, OutgoingPort, PV<V, H::Node>); // <Node> produces, on <OutgoingPort>, the value <PV>
lattice in_wire_value(H::Node, IncomingPort, PV<V, H::Node>); // <Node> receives, on <IncomingPort>, the value <PV>
lattice node_in_value_row(H::Node, ValueRow<V, H::Node>); // <Node>'s inputs are <ValueRow>

node(n) <-- for n in hugr.nodes();

Expand Down Expand Up @@ -322,6 +323,35 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>(
func_call(call, func),
output_child(func, outp),
in_wire_value(outp, p, v);

// CallIndirect --------------------
relation indirect_call(H::Node, H::Node); // <Node> is an `IndirectCall` to `FuncDefn` <Node>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be a lattice, and to store the Callee.

As written, if you first get a LoadedFunction, then it goes to TOP, you'll not update the out_wire_values to top.

I think you should deal with polymorphism either by requring 0 type args

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point about it needing to be a lattice, thank you! :). Done.
(Annoying about not having a test to show the difference but I bet it'd have show up sooner or later.)

Polymorphism is not wrong atm. Values from different type-instantiations are just joined together so will likely produce Top. If there's a LoadNat - well we don't even handle that yet - but any DFContext::interpret_leaf_op will just see a LoadNat<Var0> and not get the type args, so there's no sensible implementation other than returning Top. It's not ideal, but it's not wrong.

indirect_call(call, func_node) <--
node(call),
if let OpType::CallIndirect(_) = hugr.get_optype(*call),
in_wire_value(call, IncomingPort::from(0), v),
if let PartialValue::LoadedFunction(LoadedFunction {func_node, ..}) = v;

out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <--
indirect_call(call, func),
input_child(func, inp),
in_wire_value(call, p, v)
if p.index() > 0;

out_wire_value(call, OutgoingPort::from(p.index()), v) <--
indirect_call(call, func),
output_child(func, outp),
in_wire_value(outp, p, v);

// Default out-value is Bottom, but if we can't determine the called function,
// assign everything to Top
out_wire_value(call, p, PV::Top) <--
node(call),
if let OpType::CallIndirect(ci) = hugr.get_optype(*call),
in_wire_value(call, IncomingPort::from(0), v),
// Second alternative below addresses function::Value's:
if matches!(v, PartialValue::Top | PartialValue::Value(_)),
for p in ci.signature().output_ports();
};
let out_wire_values = all_results
.out_wire_value
Expand All @@ -341,9 +371,9 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
ctx: &mut impl DFContext<V, Node = H::Node>,
hugr: &H,
n: H::Node,
ins: &[PV<V>],
ins: &[PV<V, H::Node>],
num_outs: usize,
) -> Option<ValueRow<V>> {
) -> Option<ValueRow<V, H::Node>> {
match hugr.get_optype(n) {
// Handle basics here. We could instead leave these to DFContext,
// but at least we'd want these impls to be easily reusable.
Expand All @@ -362,8 +392,7 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
ins.iter().cloned(),
)])),
OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent
OpType::Call(_) => None, // handled via Input/Output of FuncDefn
OpType::Const(_) => None, // handled by LoadConstant:
OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dropped Const: it's not a dataflow node (it's a module/static), so never gets here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be more consistent to deal with callindirect here rather than in the datalog, is that possible?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only by a massive refactor of propagate_leaf_op. Which is private, so we could, but I don't feel this is inconsistent - I think it's just like what we have been doing for Call

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not worth a massive refactor. I think it's inconsistent only because I don't think "CallIndirect" is morally built in(I claim it could be in prelude)

OpType::LoadConstant(load_op) => {
assert!(ins.is_empty()); // static edge, so need to find constant
let const_node = hugr
Expand All @@ -380,10 +409,12 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
.unwrap()
.0;
// Node could be a FuncDefn or a FuncDecl, so do not pass the node itself
Some(ValueRow::singleton(
ctx.value_from_function(func_node, &load_op.type_args)
.map_or(PV::Top, PV::Value),
))
Some(ValueRow::singleton(PartialValue::LoadedFunction(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend you use PartialValue::new_load to get coverage on that function

LoadedFunction {
func_node,
args: load_op.type_args.clone(),
},
)))
}
OpType::ExtensionOp(e) => {
Some(ValueRow::from_iter(if row_contains_bottom(ins) {
Expand All @@ -401,6 +432,7 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
outs
}))
}
o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive"
// We only call propagate_leaf_op for dataflow op non-containers,
o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive
}
}
Loading
Loading