diff --git a/cedar-policy-core/src/entities.rs b/cedar-policy-core/src/entities.rs index 2ed2b88637..a97889b9ab 100644 --- a/cedar-policy-core/src/entities.rs +++ b/cedar-policy-core/src/entities.rs @@ -18,8 +18,8 @@ use crate::ast::*; use crate::extensions::Extensions; -use crate::transitive_closure::{compute_tc, enforce_tc_and_dag}; -use std::collections::{hash_map, HashMap}; +use crate::transitive_closure::{compute_tc, enforce_tc_and_dag, repair_tc}; +use std::collections::{hash_map, HashMap, HashSet}; use std::sync::Arc; use serde::Serialize; @@ -143,16 +143,30 @@ impl Entities { extensions: &Extensions<'_>, ) -> Result { let checker = schema.map(|schema| EntitySchemaConformanceChecker::new(schema, extensions)); + let mut entities_touched: HashSet = HashSet::new(); for entity in collection.into_iter() { if let Some(checker) = checker.as_ref() { checker.validate_entity(&entity)?; } + entities_touched.insert(entity.uid().clone()); update_entity_map(&mut self.entities, entity)?; } match tc_computation { TCComputation::AssumeAlreadyComputed => (), TCComputation::EnforceAlreadyComputed => enforce_tc_and_dag(&self.entities)?, - TCComputation::ComputeNow => compute_tc(&mut self.entities, true)?, + TCComputation::ComputeNow => { + for entity in self.entities.values() { + if !entities_touched.is_disjoint( + &entity + .ancestors() + .map(EntityUID::clone) + .collect::>(), + ) { + entities_touched.insert(entity.uid().clone()); + } + } + repair_tc(entities_touched, &mut self.entities, true)? + } }; Ok(self) } @@ -167,12 +181,14 @@ impl Entities { collection: impl IntoIterator, tc_computation: TCComputation, ) -> Result { + let mut entities_touched: HashSet = HashSet::new(); for uid_to_remove in collection.into_iter() { match self.entities.remove(&uid_to_remove) { None => (), Some(entity_to_remove) => { for entity in self.entities.values_mut() { if entity.is_descendant_of(&uid_to_remove) { + entities_touched.insert(entity.uid().clone()); // remove any direct or indirect link between `entity` and `entity_to_remove` Arc::make_mut(entity).remove_indirect_ancestor(&uid_to_remove); Arc::make_mut(entity).remove_parent(&uid_to_remove); @@ -188,7 +204,7 @@ impl Entities { match tc_computation { TCComputation::AssumeAlreadyComputed => (), TCComputation::EnforceAlreadyComputed => enforce_tc_and_dag(&self.entities)?, - TCComputation::ComputeNow => compute_tc(&mut self.entities, true)?, + TCComputation::ComputeNow => repair_tc(entities_touched, &mut self.entities, true)?, } Ok(self) } diff --git a/cedar-policy-core/src/transitive_closure.rs b/cedar-policy-core/src/transitive_closure.rs index 6c1e3a0bc6..595b6a9361 100644 --- a/cedar-policy-core/src/transitive_closure.rs +++ b/cedar-policy-core/src/transitive_closure.rs @@ -55,40 +55,315 @@ where K: Clone + Eq + Hash + Debug + Display, V: TCNode, { - compute_tc_internal::(nodes); + let all_node_ids = nodes.keys().map(K::clone).collect::>(); + + // If the caller does not want to check that the graph is a DAG, + // we assume that the graph is acyclic during the below call. + // This allows the below call to do a single scan of each node + // rather than two scans of each node. + compute_tc_internal(all_node_ids.into_iter(), nodes, HashSet::new(), enforce_dag); + if enforce_dag { return enforce_dag_from_tc(nodes); } Ok(()) } -/// Given graph as a map from keys with type `K` to implementations of `TCNode` -/// with type `V`, compute the transitive closure of the hierarchy. In case of -/// error, the result contains an error structure `Err` which contains the -/// keys (with type `K`) for the nodes in the graph which caused the error. -fn compute_tc_internal(nodes: &mut HashMap) +/// Given Graph as a map from keys with type `K` to implementations of `TCNode` +/// with type `V`, repair the transitive closure of the hierarchy. The below code +/// will assume that for each `node` in `nodes` except the nodes appearing in +/// `nodes_to_fix`, the out-going edges of `node` will contain all ancestors of `node`. +/// That is we may assume the transitive closure for all such nodes is correct while +/// computing the transitive closure of each node appearing in `nodes_to_fix`. +/// In case of error, the result contains an error structure `Err` which contains +/// the keys (with type `K`) for the nodes in the graph which caused the error. +/// If `enforce_dag` then also check that the heirarchy is a DAG +pub fn repair_tc( + nodes_to_fix: HashSet, + nodes: &mut HashMap, + enforce_dag: bool, +) -> Result<(), K> where + K: Clone + Eq + Hash + Debug + Display, + V: TCNode, +{ + let seen: HashSet = nodes + .keys() + .filter_map(|node_id| { + if nodes_to_fix.contains(node_id) { + None + } else { + Some(node_id.clone()) + } + }) + .collect(); + + // If the caller does not want to check that the graph is a DAG, + // we assume that the graph is acyclic during the below call. + // This allows the below call to do a single scan of each node + // rather than two scans of each node. + compute_tc_internal::(nodes_to_fix.into_iter(), nodes, seen, enforce_dag); + + if enforce_dag { + return enforce_dag_from_tc(nodes); + } + Ok(()) +} + +/// Saturate the out-going edges of each node in `node_ids` to include +/// all reachable ancestors within the graph represnted by `nodes`. +/// Assume that all nodes appearing in `seen` already satisfy this property. +/// If `detect_cyles` is false, we assume the the graph represented by `nodes` +/// is a DAG so that we may perform a single scan over the graph. Otherwise, +/// we scan each node twice. This is sufficient for detecting cycles and for computing +/// the exact TC for graphs containing simple cycles. For more complex cyclic graphs, +/// the below code computes enough of the transtive closure to ensure that if one +/// calls `enforce_dag_from_tc` on `nodes` after this function returns then it will +/// correctly detect any cycles (simple or compelx). +fn compute_tc_internal( + node_ids: impl Iterator, + nodes: &mut HashMap, + mut seen: HashSet, + detect_cyles: bool, +) where K: Clone + Eq + Hash, V: TCNode, { - // To avoid needing both immutable and mutable borrows of `nodes`, - // we collect all the needed updates in this structure - // (maps keys to ancestor UIDs to add to it) - // and then do all the updates at once in a second loop - let mut ancestors: HashMap> = HashMap::new(); - for node in nodes.values() { - let this_node_ancestors: &mut HashSet = ancestors.entry(node.get_key()).or_default(); - add_ancestors_to_set(node, nodes, this_node_ancestors); + for node_id in node_ids { + if detect_cyles { + add_ancestors(&node_id, nodes, &mut seen); + } else if seen.insert(node_id.clone()) { + add_ancestors(&node_id, nodes, &mut seen); + } + } +} + +fn cyclic_tc(nodes: &mut HashMap) +where + K: Clone + Eq + Hash + Debug, + V: TCNode, +{ + let node_ids = nodes.keys().map(K::clone).collect::>(); + let mut order_visited = HashMap::new(); + let mut root = HashMap::new(); + let mut vstack = Vec::new(); + let mut cstack = Vec::new(); + let mut component = HashMap::new(); + let mut comp_succ = Vec::new(); + let mut comp_elts = Vec::new(); + for node_id in node_ids { + if !order_visited.contains_key(&node_id) { + cyclic_tc_internal( + &node_id, + &nodes, + &mut order_visited, + &mut root, + &mut vstack, + &mut cstack, + &mut component, + &mut comp_succ, + &mut comp_elts, + ); + } } - for node in nodes.values_mut() { - // PANIC SAFETY All nodes in `ancestors` came from `nodes` + // component_tc => nodes_tc + for comp_id in 0..comp_elts.len() { + let mut elt_succ = HashSet::new(); + // PANIC SAFETY! `comp_succ` and `comp_elts` must have the same length, thus `comp_id` is a valid index to `comp_succ`. + #[allow(clippy::indexing_slicing)] + for comp_parent_id in comp_succ[comp_id].iter() { + // PANIC SAFETY! `comp_parent_id` must be a valid component id to be inserted into `comp_succ` therefore must exist within `comp_elts`. + #[allow(clippy::indexing_slicing)] + for node_id in comp_elts[*comp_parent_id].iter() { + // not fine to consume here + elt_succ.insert(node_id.clone()); + } + } + // PANIC SAFETY! `comp_id` \in [0, |`comp_elts`|) is a valid index into `comp_elts`. + #[allow(clippy::indexing_slicing)] + for node_id in comp_elts[comp_id].iter() { + let node = match nodes.get_mut(node_id) { + Some(node) => node, + None => continue, + }; + for parent_id in elt_succ.iter() { + node.add_edge_to(parent_id.clone()); + } + } + } +} + +fn cyclic_tc_internal( + node_id: &K, + nodes: &HashMap, + order_visited: &mut HashMap, + root: &mut HashMap, + vstack: &mut Vec, + cstack: &mut Vec, + component: &mut HashMap, + comp_succ: &mut Vec>, + comp_elts: &mut Vec>, +) where + K: Clone + Eq + Hash + Debug, + V: TCNode, +{ + let node_order = order_visited.len(); + // when was the root of this node's component visited? + // initially the root of this node's component is this node itself + // keeping track in auxillary function to avoid re-fetching in a loop + let mut root_order = node_order; + order_visited.insert(node_id.clone(), node_order); + root.insert(node_id.clone(), node_id.clone()); + vstack.push(node_id.clone()); + let height = cstack.len(); + let mut self_loop = false; + let out_edges = match nodes.get(node_id) { + Some(node) => node.out_edges().collect(), + None => Vec::new(), + }; + for parent_id in out_edges { + if node_id == parent_id { + self_loop = true; + } else { + // The edge from node_id to parent_id is a forward edge iff + // node_id is visited before parent_id and we do not visit + // parent_id from node_id (i.e., we do not recursively call + // cyclic_tc_interanl on parent_id from this call). + let mut maybe_forward_edge = true; + if !order_visited.contains_key(parent_id) { + maybe_forward_edge = false; + cyclic_tc_internal( + parent_id, + nodes, + order_visited, + root, + vstack, + cstack, + component, + comp_succ, + comp_elts, + ); + } + match component.get(parent_id) { + None => { + // PANIC SAFETY! `parent_id` must have been visited either by a previous call or just above + #[allow(clippy::expect_used)] + let parent_root = root + .get(parent_id) + .expect("Parent has been visited so it must have a root."); + // PANIC SAFETY! in order for `parent_root` to be the parent of `parent_id` it must have been visited. + #[allow(clippy::expect_used)] + let parent_root_order = order_visited + .get(parent_root) + .expect("The parent's root must have been visited."); + if *parent_root_order < root_order { + root_order = *parent_root_order; + root.insert(node_id.clone(), parent_root.clone()); + } + } + Some(parent_component) => { + // PANIC SAFETY! `parent_id` must have been visited either by a previous call or just above + #[allow(clippy::expect_used)] + let parent_order = order_visited + .get(parent_id) + .expect("The parent must have been traversed by this point."); + // if not a forward edge + if !(maybe_forward_edge && &node_order < parent_order) { + cstack.push(*parent_component); + } + } + } + } + } // end for loop over parents + // PANIC SAFETY! `node_id` must have a root. It was inserted at the begining of this function + #[allow(clippy::expect_used)] + let node_root = root + .get(node_id) + .expect("Node must have a root by this point."); + // if this node is the root of its connected component + if node_id == node_root { + let component_id = comp_elts.len(); + let mut succ = HashSet::new(); + let mut elmts = HashSet::new(); + // PANIC SAFETY! The vertex stack must not be empty because at least node_id must be on the stack! #[allow(clippy::expect_used)] - for ancestor_uid in ancestors - .get(&node.get_key()) - .expect("shouldn't have added any new values to the `nodes` map") - { - node.add_edge_to(ancestor_uid.clone()); + if self_loop || vstack.last().expect("vertex stack must be non-empty") != node_id { + succ.insert(component_id); } + let mut cstack_tail = cstack.split_off(height); + // sort by topological order of the components, which should be equivalent to the reverse order of their ids + // cstack_tail are all of the components reachable (1 step) from any node within this component + cstack_tail.sort_by(|a, b| b.cmp(a)); + // iterate through components in topological order + for i in 0..cstack_tail.len() { + // update this component's successors with next component avoiding duplicate components + // PANIC SAFETY! both `i` and `i - 1` are gauranteed to be valid indices into `cstack_tail`. + #[allow(clippy::indexing_slicing)] + if i == 0 || cstack_tail[i - 1] == cstack_tail[i] { + // PANIC SAFETY! `i` is a valid index into `cstack_tail`. + #[allow(clippy::indexing_slicing)] + let X = cstack_tail[i]; + if succ.insert(X) { + // PANIC SAFETY! `X` is a component id created by a previous call to `cyclic_tc_internal` and thus must be a valid index to `comp_succ`. + #[allow(clippy::indexing_slicing)] + succ.extend(comp_succ[X].clone()); + } + } + } + loop { + // PANIC SAFETY! The vertex stack `vstack` must contain at least `node_id` + #[allow(clippy::expect_used)] + let ancestor_id = vstack.pop().expect("Vetex stack must be non-empty"); + component.insert(ancestor_id.clone(), component_id); + elmts.insert(ancestor_id.clone()); + if *node_id == ancestor_id { + break; + } + } + comp_succ.push(succ); + comp_elts.push(elmts); + } +} + +/// Saturate the out-going edges of the node identified by `node_id` within the graph +/// represented by `nodes` assuming that each node appearing in `seen` already satisfies +/// this property. The process works by performing a depth-first search over the ancestors +/// of `node_id` (and stopping if any ancestor is already in the `seen` set). +fn add_ancestors(node_id: &K, nodes: &mut HashMap, seen: &mut HashSet) +where + K: Clone + Eq + Hash, + V: TCNode, +{ + let mut ancestors: HashSet = HashSet::new(); + let out_edges: Vec = match nodes.get(node_id) { + Some(node) => node.out_edges().map(K::clone).collect(), + None => return, + }; + for ancestor_id in out_edges { + if seen.insert(ancestor_id.clone()) { + add_ancestors(&ancestor_id, nodes, seen); + } + // a slight optimization to avoid adding the ancestors of `ancestor_id` if + // `ancestor_id` was an ancestor of any parent already explored by this loop. + if !ancestors.contains(&ancestor_id) { + let ancestor = match nodes.get(&ancestor_id) { + Some(ancestor) => ancestor, + None => return, + }; + for grand_ancestor_id in ancestor.out_edges() { + ancestors.insert(grand_ancestor_id.clone()); + } + } + } + // PANIC SAFETY this node should always exist because of the check to get `out_edges` + #[allow(clippy::expect_used)] + let node = nodes + .get_mut(node_id) + .expect("This node should always exist."); + // Do the actual saturation of out-going edges of `node` here to avoid + // issues with rust's borrow checker. + for ancestor_id in ancestors { + node.add_edge_to(ancestor_id); } } @@ -136,25 +411,6 @@ where Ok(()) } -/// For the given `node` in the given `hierarchy`, add all of the `node`'s -/// transitive ancestors to the given set. Assume that any nodes already in -/// `ancestors` don't need to be searched -- they have been already handled. -fn add_ancestors_to_set(node: &V, hierarchy: &HashMap, ancestors: &mut HashSet) -where - K: Clone + Eq + Hash, - V: TCNode, -{ - for ancestor_uid in node.out_edges() { - if ancestors.insert(ancestor_uid.clone()) { - // discovered a new ancestor, so add the ancestors of `ancestor` as - // well - if let Some(ancestor) = hierarchy.get(ancestor_uid) { - add_ancestors_to_set(ancestor, hierarchy, ancestors); - } - } - } -} - /// Once the transitive closure (as defined above) is computed/enforced for the graph, we have: /// \forall u,v,w \in Vertices . (u,v) \in Edges /\ (v,w) \in Edges -> (u,w) \in Edges /// @@ -200,7 +456,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let c = &entities[&EntityUID::with_eid("C")]; @@ -232,7 +488,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let c = &entities[&EntityUID::with_eid("C")]; @@ -269,7 +525,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let c = &entities[&EntityUID::with_eid("C")]; @@ -312,7 +568,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let foo = &entities[&EntityUID::with_eid("foo")]; let bar = &entities[&EntityUID::with_eid("bar")]; let baz = &entities[&EntityUID::with_eid("baz")]; @@ -356,7 +612,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let d = &entities[&EntityUID::with_eid("D")]; @@ -411,7 +667,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let f = &entities[&EntityUID::with_eid("F")]; @@ -471,7 +727,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let c = &entities[&EntityUID::with_eid("C")]; @@ -523,7 +779,7 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); let a = &entities[&EntityUID::with_eid("A")]; let b = &entities[&EntityUID::with_eid("B")]; let d = &entities[&EntityUID::with_eid("D")]; @@ -562,7 +818,7 @@ mod tests { b.add_indirect_ancestor(EntityUID::with_eid("B")); let mut entities = HashMap::from([(a.uid().clone(), a), (b.uid().clone(), b)]); // computing TC should succeed without panicking, infinitely recursing, etc - compute_tc_internal(&mut entities); + compute_tc(&mut entities, false).expect("Failed to compute transitive closure"); // fails cycle check match enforce_dag_from_tc(&entities) { Ok(_) => panic!("enforce_dag_from_tc should have returned an error"), @@ -577,7 +833,7 @@ mod tests { assert!(a.is_descendant_of(&EntityUID::with_eid("B"))); // but it shouldn't have added a B -> A edge assert!(!b.is_descendant_of(&EntityUID::with_eid("A"))); - // we also check that, whatever compute_tc_internal did with this invalid input, the + // we also check that, whatever compute_tc did with this invalid input, the // final result still passes enforce_tc assert!(enforce_tc(&entities).is_ok()); // still fails cycle check @@ -613,7 +869,16 @@ mod tests { (d.uid().clone(), d), ]); // computing TC should succeed without panicking, infinitely recursing, etc - compute_tc_internal(&mut entities); + compute_tc_internal( + entities + .keys() + .map(EntityUID::clone) + .collect::>() + .into_iter(), + &mut entities, + HashSet::new(), + true, + ); // fails cycle check match enforce_dag_from_tc(&entities) { Ok(_) => panic!("enforce_dag_from_tc should have returned an error"), @@ -634,7 +899,7 @@ mod tests { assert!(a.is_descendant_of(&EntityUID::with_eid("D"))); // and we should also have added a B -> A edge assert!(b.is_descendant_of(&EntityUID::with_eid("A"))); - // we also check that, whatever compute_tc_internal did with this invalid input, the + // we also check that, whatever compute_tc did with this invalid input, the // final result still passes enforce_tc assert!(enforce_tc(&entities).is_ok()); // still fails cycle check @@ -687,7 +952,16 @@ mod tests { // currently doesn't pass TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); + compute_tc_internal( + entities + .keys() + .map(EntityUID::clone) + .collect::>() + .into_iter(), + &mut entities, + HashSet::new(), + true, + ); // now it should pass TC enforcement assert!(enforce_tc(&entities).is_ok()); // still fails cycle check @@ -743,10 +1017,19 @@ mod tests { // fails TC enforcement assert!(enforce_tc(&entities).is_err()); // compute TC - compute_tc_internal(&mut entities); - // now it should pass TC enforcement + cyclic_tc(&mut entities); + // compute_tc_internal( + // entities + // .keys() + // .map(EntityUID::clone) + // .collect::>() + // .into_iter(), + // &mut entities, + // HashSet::new(), + // true, + // ); assert!(enforce_tc(&entities).is_ok()); - // but still fail cycle check + // the graph may or may not pass the TC check but it will always fail cycle check match enforce_dag_from_tc(&entities) { Ok(_) => panic!("enforce_dag_from_tc should have returned an error"), Err(TcError::HasCycle(_)) => (), // Every vertex is in a cycle