Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,54 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ##

return False



def traversePotentiallyDirected(node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
if (edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE) and \
(edge.get_endpoint2() == Endpoint.ARROW or edge.get_endpoint2() == Endpoint.CIRCLE):
return edge.get_node2()
elif node == edge.get_node2():
if (edge.get_endpoint2() == Endpoint.TAIL or edge.get_endpoint2() == Endpoint.CIRCLE) and \
(edge.get_endpoint1() == Endpoint.ARROW or edge.get_endpoint1() == Endpoint.CIRCLE):
return edge.get_node1()
return None


def existsUncoveredPdPath(node_from: Node, node_next: Node, node_to: Node, G: Graph) -> bool:
Q = Queue()
V = set([node_from, node_next])

for node_u in G.get_adjacent_nodes(node_next):
edge = G.get_edge(node_next, node_u)
node_c = traversePotentiallyDirected(node_next, edge)

if node_c is None:
continue

if not V.__contains__(node_c):
V.add(node_c)
Q.put((node_c, [node_from, node_next, node_c]))

while not Q.empty():
node_t, path = Q.get_nowait()
if node_t == node_to and is_uncovered_path(path, G):
# print(f"Found uncovered pd path: {[node.get_name() for node in path]}")
return True

for node_u in G.get_adjacent_nodes(node_t):
edge = G.get_edge(node_t, node_u)
node_c = traversePotentiallyDirected(node_t, edge)

if node_c is None:
continue

if not V.__contains__(node_c):
V.add(node_c)
Q.put((node_c, path + [node_c]))

return False

def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph, exclude_node: List[Node]) -> Generator[Node] | None:
Q = Queue()
V = set()
Expand Down Expand Up @@ -802,7 +850,8 @@ def rule9(graph: Graph, nodes: List[Node], changeFlag):
for node_B in possible_children:
if graph.is_adjacent_to(node_B, node_C):
continue
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):

if existsUncoveredPdPath(node_from=node_A, node_next=node_B, node_to=node_C, G=graph):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
Expand Down
37 changes: 37 additions & 0 deletions tests/TestDAG2PAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,40 @@ def test_case_selection(self):
dag.add_directed_edge(nodes[0], nodes[4])
pag = dag2pag(dag, islatent=[], isselection=[nodes[4]])
print(pag)

def test_case_orient_rules(self):
nodes = []
X = {}
L = {}
for i in range(7):
node_name = f"X{i + 1}"
if i + 1 == 2:
node_name = f"L{i + 1}"
node = GraphNode(node_name)
nodes.append(node)
if i + 1 == 2:
L[2] = node
else:
X[i + 1] = node
dag = Dag(nodes)
dag.add_directed_edge(L[2], X[4])
dag.add_directed_edge(L[2], X[5])
dag.add_directed_edge(L[2], X[6])

dag.add_directed_edge(X[5], X[7])
dag.add_directed_edge(X[1], X[4])
dag.add_directed_edge(X[1], X[7])
dag.add_directed_edge(X[3], X[7])
pag = dag2pag(dag, [L[2]])
print(pag)
graphviz_pag = GraphUtils.to_pgv(pag)
graphviz_pag.draw("pag.png", prog='dot', format='png')


if __name__ == "__main__":
test_model = TestDAG2PAG()
test_model.test_case1()
test_model.test_case2()
test_model.test_case3()
test_model.test_case_selection()
test_model.test_case_orient_rules()