diff --git a/src/main/scala/viper/silver/ast/utility/Consistency.scala b/src/main/scala/viper/silver/ast/utility/Consistency.scala index c4d51f34d..d8c2c842b 100644 --- a/src/main/scala/viper/silver/ast/utility/Consistency.scala +++ b/src/main/scala/viper/silver/ast/utility/Consistency.scala @@ -348,14 +348,14 @@ object Consistency { def checkNoFunctionRecursesViaPreconditions(program: Program): Seq[ConsistencyError] = { var s = Seq.empty[ConsistencyError] - Functions.findFunctionCyclesViaPreconditions(program) foreach { case (func, cycleSet) => - var msg = s"Function ${func.name} recurses via its precondition" + Functions.findFunctionCyclesViaPreconditions(program) foreach { case (funcName, cycleSet) => + var msg = s"Function ${funcName} recurses via its precondition" if (cycleSet.nonEmpty) { - msg = s"$msg: the cycle contains the function(s) ${cycleSet.map(_.name).mkString(", ")}" + msg = s"$msg: the cycle contains the function(s) ${cycleSet.mkString(", ")}" } - s :+= ConsistencyError(msg, func.pos) + s :+= ConsistencyError(msg, program.findFunction(funcName).pos) } s } diff --git a/src/main/scala/viper/silver/ast/utility/Functions.scala b/src/main/scala/viper/silver/ast/utility/Functions.scala index adbbd3c5d..cd271c181 100644 --- a/src/main/scala/viper/silver/ast/utility/Functions.scala +++ b/src/main/scala/viper/silver/ast/utility/Functions.scala @@ -12,6 +12,7 @@ import org.jgrapht.alg.connectivity.GabowStrongConnectivityInspector import org.jgrapht.graph.{DefaultDirectedGraph, DefaultEdge} import org.jgrapht.traverse.TopologicalOrderIterator +import scala.collection.immutable.ListMap import scala.collection.mutable.{ListBuffer, Set => MSet} import scala.jdk.CollectionConverters._ @@ -19,6 +20,9 @@ import scala.jdk.CollectionConverters._ * Utility methods for functions. */ object Functions { + + type FuncName = String + case class Edge[T](source: T, target: T) def allSubexpressions(func: Function): Seq[Exp] = func.pres ++ func.posts ++ func.body @@ -50,17 +54,17 @@ object Functions { * TODO: Memoize invocations of `getFunctionCallgraph`. Note that it's unclear how to derive a useful key from `subs` */ def getFunctionCallgraph(program: Program, subs: Function => Seq[Exp] = allSubexpressions) - : DefaultDirectedGraph[Function, DefaultEdge] = { - val graph = new DefaultDirectedGraph[Function, DefaultEdge](classOf[DefaultEdge]) + : DefaultDirectedGraph[FuncName, DefaultEdge] = { + val graph = new DefaultDirectedGraph[FuncName, DefaultEdge](classOf[DefaultEdge]) for (f <- program.functions) { - graph.addVertex(f) + graph.addVertex(f.name) } def process(f: Function, e: Exp): Unit = { e visit { case FuncApp(f2name, _) => - graph.addEdge(f, program.findFunction(f2name)) + graph.addEdge(f.name, f2name) } } @@ -80,8 +84,8 @@ object Functions { * calls f1). If the flag considerUnfoldings is set, calls to f2 in the body of * a predicate that is unfolded by f1 are also taken into account. */ - def heights(program: Program, considerUnfoldings: Boolean = false): Map[Function, Int] = { - val result = collection.mutable.Map[Function, Int]() + def heights(program: Program, considerUnfoldings: Boolean = false): Map[FuncName, Int] = { + val result = collection.mutable.Map[FuncName, Int]() /* Compute the call-graph over all functions in the given program. * An edge from f1 to f2 denotes that f1 calls f2, either in the function @@ -95,18 +99,6 @@ object Functions { getFunctionCallgraph(program, allSubexpressions) } -///* debugging */ -// val functionVNP = new org.jgrapht.ext.VertexNameProvider[Function] { -// def getVertexName(vertex: Function) = vertex.name -// } -// -// val dotExporter = new org.jgrapht.ext.DOTExporter(functionVNP, //new org.jgrapht.ext.IntegerNameProvider[Function](), -// functionVNP, -// null) -// -// dotExporter.export(new java.io.FileWriter("callgraph.dot"), callGraph.asInstanceOf[Graph[Function, Nothing]]) -///* /debugging */ - /* Get all strongly connected components (SCCs) of the call-graph, represented as * sets of functions. */ @@ -116,12 +108,12 @@ object Functions { * but where each strongly connected component has been condensed into a * single node. */ - val condensedCallGraph = new DefaultDirectedGraph[MSet[Function], DefaultEdge](classOf[DefaultEdge]) + val condensedCallGraph = new DefaultDirectedGraph[MSet[FuncName], DefaultEdge](classOf[DefaultEdge]) /* Add each SCC as a vertex to the condensed call-graph */ stronglyConnectedSets.foreach(v => condensedCallGraph.addVertex(v.asScala)) - def condensationOf(func: Function): MSet[Function] = + def condensationOf(func: FuncName): MSet[FuncName] = stronglyConnectedSets.find(_ contains func).get.asScala /* Add edges from the call-graph (between individual functions) as edges @@ -136,21 +128,6 @@ object Functions { condensedCallGraph.addEdge(sourceSet, targetSet) } -///* debugging */ -// val functionSetVNP = new org.jgrapht.ext.VertexNameProvider[java.util.Set[Function]] { -// def getVertexName(vertex: java.util.Set[Function]) = vertex.map(_.name).mkString(", ") -// } -// -// val functionSetIDP = new org.jgrapht.ext.VertexNameProvider[java.util.Set[Function]] { -// def getVertexName(vertex: java.util.Set[Function]) = s""""${vertex.map(_.name).mkString(", ")}"""" -// } -// -// val dotExporter2 = new org.jgrapht.ext.DOTExporter(functionSetIDP, // new org.jgrapht.ext.IntegerNameProvider[java.util.Set[Function]](), -// functionSetVNP, -// null) -// dotExporter2.export(new java.io.FileWriter("sccg.dot"), sccg.asInstanceOf[Graph[java.util.Set[Function], Nothing]]) -///* /debugging */ - /* The behaviour of TopologicalOrderIterator is undefined if it is applied * to a cyclic graph, hence this check. */ @@ -210,7 +187,7 @@ object Functions { * functions `fs`, then `f` (transitively) recurses via its precondition, and the * formed cycles involves the set of functions `fs`. */ - def findFunctionCyclesViaPreconditions(program: Program): Map[Function, Set[Function]] = { + def findFunctionCyclesViaPreconditions(program: Program): Map[FuncName, Set[FuncName]] = { findFunctionCyclesVia(program, func => func.pres, allSubexpressions) } @@ -225,17 +202,18 @@ object Functions { * formed cycles involves the set of functions `fs`. */ def findFunctionCyclesVia(program: Program, via: Function => Seq[Exp], subs: Function => Seq[Exp] = allSubexpressions) - :Map[Function, Set[Function]] = { + :Map[FuncName, Set[FuncName]] = { def viaSubs(entryFunc: Function)(otherFunc: Function): Seq[Exp] = if (otherFunc == entryFunc) via(otherFunc) else subs(otherFunc) - program.functions.flatMap(func => { + val res = program.functions.flatMap(func => { val graph = getFunctionCallgraph(program, viaSubs(func)) findCycles(graph, func) - }).toMap[Function, Set[Function]] + }) + ListMap.from(res) } /** Returns all cycles formed by functions that (transitively through certain subexpressions) @@ -249,19 +227,20 @@ object Functions { * formed cycles involves the set of functions `fs`. */ def findFunctionCyclesViaOptimized(program: Program, via: Function => Seq[Exp]) - : Map[Function, Set[Function]] = { + : Map[FuncName, Set[FuncName]] = { val graph = getFunctionCallgraph(program, via) - program.functions.flatMap(func => { + val res = program.functions.flatMap(func => { findCycles(graph, func) - }).toMap[Function, Set[Function]] + }) + ListMap.from(res) } - private def findCycles(graph: DefaultDirectedGraph[Function, DefaultEdge], func: Function): Option[(Function, Set[Function])] = { + private def findCycles(graph: DefaultDirectedGraph[FuncName, DefaultEdge], func: Function): Option[(FuncName, Set[FuncName])] = { val cycleDetector = new CycleDetector(graph) - val cycle = cycleDetector.findCyclesContainingVertex(func).asScala + val cycle = cycleDetector.findCyclesContainingVertex(func.name).asScala if (cycle.isEmpty) None else - Some(func -> cycle.toSet) + Some(func.name -> cycle.toSet) } } diff --git a/src/main/scala/viper/silver/plugin/standard/termination/TerminationPlugin.scala b/src/main/scala/viper/silver/plugin/standard/termination/TerminationPlugin.scala index 716295df4..32a1d0483 100644 --- a/src/main/scala/viper/silver/plugin/standard/termination/TerminationPlugin.scala +++ b/src/main/scala/viper/silver/plugin/standard/termination/TerminationPlugin.scala @@ -249,10 +249,10 @@ class TerminationPlugin(@unused reporter: viper.silver.reporter.Reporter, case dc: DecreasesClause => dc }.nonEmpty) if (!hasDecreasesClause) { - val funcCycles = cycles.get(f) + val funcCycles = cycles.get(f.name) val problematicFuncApps = f.posts.flatMap(p => p.shallowCollect { case fa: FuncApp if fa.func(input) == f => fa - case fa: FuncApp if funcCycles.isDefined && funcCycles.get.contains(fa.func(input)) => fa + case fa: FuncApp if funcCycles.isDefined && funcCycles.get.contains(fa.funcname) => fa }).toSet for (fa <- problematicFuncApps) { val calledFunc = fa.func(input)