-
Notifications
You must be signed in to change notification settings - Fork 15
feat!: Handle CallIndirect in Dataflow Analysis #2059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
692a759
48143f5
5809594
cccefe4
ca442c4
a90dc97
6e2ad7f
2cd8547
ba33062
5c9c6ec
ee128ce
2757a33
3c79739
1191150
4ba5603
a8e4553
f032848
3066b65
3eb0c7f
887a386
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,16 +6,18 @@ use ascent::lattice::BoundedLattice; | |
| use itertools::Itertools; | ||
|
|
||
| use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; | ||
| use hugr_core::ops::{OpTrait, OpType, TailLoop}; | ||
| use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; | ||
| use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; | ||
|
|
||
| use super::value_row::ValueRow; | ||
| use super::{ | ||
| partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, | ||
| PartialValue, | ||
| LoadedFunction, PartialValue, | ||
| }; | ||
|
|
||
| type PV<V> = PartialValue<V>; | ||
| type PV<V, N> = PartialValue<V, N>; | ||
|
|
||
| type NodeInputs<V, N> = Vec<(IncomingPort, PV<V, N>)>; | ||
|
|
||
| /// Basic structure for performing an analysis. Usage: | ||
| /// 1. Make a new instance via [Self::new()] | ||
|
|
@@ -25,10 +27,7 @@ type PV<V> = PartialValue<V>; | |
| /// [Self::prepopulate_inputs] can be used on each externally-callable | ||
| /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. | ||
| /// 3. Call [Self::run] to produce [AnalysisResults] | ||
| pub struct Machine<H: HugrView, V: AbstractValue>( | ||
| H, | ||
| HashMap<H::Node, Vec<(IncomingPort, PartialValue<V>)>>, | ||
| ); | ||
| pub struct Machine<H: HugrView, V: AbstractValue>(H, HashMap<H::Node, NodeInputs<V, H::Node>>); | ||
|
|
||
| impl<H: HugrView, V: AbstractValue> Machine<H, V> { | ||
| /// Create a new Machine to analyse the given Hugr(View) | ||
|
|
@@ -40,7 +39,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> { | |
| impl<H: HugrView, V: AbstractValue> Machine<H, V> { | ||
| /// Provide initial values for a wire - these will be `join`d with any computed | ||
| /// or any value previously prepopulated for the same Wire. | ||
| pub fn prepopulate_wire(&mut self, w: Wire<H::Node>, v: PartialValue<V>) { | ||
| pub fn prepopulate_wire(&mut self, w: Wire<H::Node>, v: PartialValue<V, H::Node>) { | ||
| for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { | ||
| self.1.entry(n).or_default().push((inp, v.clone())); | ||
| } | ||
|
|
@@ -54,7 +53,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> { | |
| pub fn prepopulate_inputs( | ||
| &mut self, | ||
| parent: H::Node, | ||
| in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V>)>, | ||
| in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>, | ||
| ) -> Result<(), OpType> { | ||
| match self.0.get_optype(parent) { | ||
| OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { | ||
|
|
@@ -102,7 +101,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> { | |
| pub fn run( | ||
| mut self, | ||
| context: impl DFContext<V, Node = H::Node>, | ||
| in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V>)>, | ||
| in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>, | ||
| ) -> AnalysisResults<V, H> { | ||
| let root = self.0.root(); | ||
| if self.0.get_optype(root).is_module() { | ||
|
|
@@ -135,10 +134,12 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> { | |
| } | ||
| } | ||
|
|
||
| pub(super) type InWire<V, N> = (N, IncomingPort, PartialValue<V, N>); | ||
|
|
||
| pub(super) fn run_datalog<V: AbstractValue, H: HugrView>( | ||
| mut ctx: impl DFContext<V, Node = H::Node>, | ||
| hugr: H, | ||
| in_wire_value_proto: Vec<(H::Node, IncomingPort, PV<V>)>, | ||
| in_wire_value_proto: Vec<InWire<V, H::Node>>, | ||
| ) -> AnalysisResults<V, H> { | ||
| // ascent-(macro-)generated code generates a bunch of warnings, | ||
| // keep code in here to a minimum. | ||
|
|
@@ -155,9 +156,9 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>( | |
| relation parent_of_node(H::Node, H::Node); // <Node> is parent of <Node> | ||
| relation input_child(H::Node, H::Node); // <Node> has 1st child <Node> that is its `Input` | ||
| relation output_child(H::Node, H::Node); // <Node> has 2nd child <Node> that is its `Output` | ||
| lattice out_wire_value(H::Node, OutgoingPort, PV<V>); // <Node> produces, on <OutgoingPort>, the value <PV> | ||
| lattice in_wire_value(H::Node, IncomingPort, PV<V>); // <Node> receives, on <IncomingPort>, the value <PV> | ||
| lattice node_in_value_row(H::Node, ValueRow<V>); // <Node>'s inputs are <ValueRow> | ||
| lattice out_wire_value(H::Node, OutgoingPort, PV<V, H::Node>); // <Node> produces, on <OutgoingPort>, the value <PV> | ||
| lattice in_wire_value(H::Node, IncomingPort, PV<V, H::Node>); // <Node> receives, on <IncomingPort>, the value <PV> | ||
| lattice node_in_value_row(H::Node, ValueRow<V, H::Node>); // <Node>'s inputs are <ValueRow> | ||
|
|
||
| node(n) <-- for n in hugr.nodes(); | ||
|
|
||
|
|
@@ -322,6 +323,35 @@ pub(super) fn run_datalog<V: AbstractValue, H: HugrView>( | |
| func_call(call, func), | ||
| output_child(func, outp), | ||
| in_wire_value(outp, p, v); | ||
|
|
||
| // CallIndirect -------------------- | ||
| relation indirect_call(H::Node, H::Node); // <Node> is an `IndirectCall` to `FuncDefn` <Node> | ||
|
||
| indirect_call(call, func_node) <-- | ||
| node(call), | ||
| if let OpType::CallIndirect(_) = hugr.get_optype(*call), | ||
| in_wire_value(call, IncomingPort::from(0), v), | ||
| if let PartialValue::LoadedFunction(LoadedFunction {func_node, ..}) = v; | ||
|
|
||
| out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- | ||
| indirect_call(call, func), | ||
| input_child(func, inp), | ||
| in_wire_value(call, p, v) | ||
| if p.index() > 0; | ||
|
|
||
| out_wire_value(call, OutgoingPort::from(p.index()), v) <-- | ||
| indirect_call(call, func), | ||
| output_child(func, outp), | ||
| in_wire_value(outp, p, v); | ||
|
|
||
| // Default out-value is Bottom, but if we can't determine the called function, | ||
| // assign everything to Top | ||
| out_wire_value(call, p, PV::Top) <-- | ||
| node(call), | ||
| if let OpType::CallIndirect(ci) = hugr.get_optype(*call), | ||
| in_wire_value(call, IncomingPort::from(0), v), | ||
| // Second alternative below addresses function::Value's: | ||
| if matches!(v, PartialValue::Top | PartialValue::Value(_)), | ||
| for p in ci.signature().output_ports(); | ||
| }; | ||
| let out_wire_values = all_results | ||
| .out_wire_value | ||
|
|
@@ -341,9 +371,9 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>( | |
| ctx: &mut impl DFContext<V, Node = H::Node>, | ||
| hugr: &H, | ||
| n: H::Node, | ||
| ins: &[PV<V>], | ||
| ins: &[PV<V, H::Node>], | ||
| num_outs: usize, | ||
| ) -> Option<ValueRow<V>> { | ||
| ) -> Option<ValueRow<V, H::Node>> { | ||
| match hugr.get_optype(n) { | ||
| // Handle basics here. We could instead leave these to DFContext, | ||
| // but at least we'd want these impls to be easily reusable. | ||
|
|
@@ -362,8 +392,7 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>( | |
| ins.iter().cloned(), | ||
| )])), | ||
| OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent | ||
| OpType::Call(_) => None, // handled via Input/Output of FuncDefn | ||
| OpType::Const(_) => None, // handled by LoadConstant: | ||
| OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dropped
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be more consistent to deal with callindirect here rather than in the datalog, is that possible?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only by a massive refactor of
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not worth a massive refactor. I think it's inconsistent only because I don't think "CallIndirect" is morally built in(I claim it could be in prelude) |
||
| OpType::LoadConstant(load_op) => { | ||
| assert!(ins.is_empty()); // static edge, so need to find constant | ||
| let const_node = hugr | ||
|
|
@@ -380,10 +409,12 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>( | |
| .unwrap() | ||
| .0; | ||
| // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself | ||
| Some(ValueRow::singleton( | ||
| ctx.value_from_function(func_node, &load_op.type_args) | ||
| .map_or(PV::Top, PV::Value), | ||
| )) | ||
| Some(ValueRow::singleton(PartialValue::LoadedFunction( | ||
|
||
| LoadedFunction { | ||
| func_node, | ||
| args: load_op.type_args.clone(), | ||
| }, | ||
| ))) | ||
| } | ||
| OpType::ExtensionOp(e) => { | ||
| Some(ValueRow::from_iter(if row_contains_bottom(ins) { | ||
|
|
@@ -401,6 +432,7 @@ fn propagate_leaf_op<V: AbstractValue, H: HugrView>( | |
| outs | ||
| })) | ||
| } | ||
| o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" | ||
| // We only call propagate_leaf_op for dataflow op non-containers, | ||
| o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.