Skip to content

Commit

Permalink
Clean up DiffTests and fix its use in compiler subproject
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK committed Jul 30, 2024
1 parent dd9a0cd commit 4a5a038
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 85 deletions.
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ lazy val compiler = crossProject(JSPlatform, JVMPlatform).in(file("compiler"))
sourceDirectory := baseDirectory.value.getParentFile()/"shared",
watchSources += WatchSource(
baseDirectory.value.getParentFile()/"shared"/"test"/"diff", "*.mls", NothingFilter),
watchSources += WatchSource(
baseDirectory.value.getParentFile()/"shared"/"test"/"diff-ir", "*.mls", NothingFilter),
)
.dependsOn(mlscript % "compile->compile;test->test")

Expand Down
20 changes: 6 additions & 14 deletions compiler/shared/test/scala/mlscript/compiler/Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import scala.collection.mutable.StringBuilder
import mlscript.compiler.TreeDebug
import simpledef.SimpleDef

class DiffTestCompiler extends DiffTests {
import DiffTestCompiler.*
import DiffTestCompiler.*

class DiffTestCompiler extends DiffTests(State) {

override def postProcess(mode: ModeType, basePath: List[Str], testName: Str, unit: TypingUnit, output: Str => Unit, raise: Diagnostic => Unit): (List[Str], Option[TypingUnit]) =
val outputBuilder = StringBuilder()

Expand Down Expand Up @@ -47,21 +49,11 @@ class DiffTestCompiler extends DiffTests {
}
None

override protected lazy val files = allFiles.filter { file =>
val fileName = file.baseName
validExt(file.ext) && filter(file.relativeTo(pwd))
}
}

object DiffTestCompiler {

private val pwd = os.pwd
private val dir = pwd/"compiler"/"shared"/"test"/"diff"

private val allFiles = os.walk(dir).filter(_.toIO.isFile)

private val validExt = Set("fun", "mls")

private def filter(file: os.RelPath) = DiffTests.filter(file)
lazy val State =
new DiffTests.State(DiffTests.pwd/"compiler"/"shared"/"test"/"diff")

}
19 changes: 5 additions & 14 deletions compiler/shared/test/scala/mlscript/compiler/TestIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import mlscript.compiler.ir._
import scala.collection.mutable.StringBuilder
import mlscript.compiler.optimizer.TailRecOpt

class IRDiffTestCompiler extends DiffTests {
import IRDiffTestCompiler.*
import IRDiffTestCompiler.*

class IRDiffTestCompiler extends DiffTests(State) {

override def postProcess(mode: ModeType, basePath: List[Str], testName: Str, unit: TypingUnit, output: Str => Unit, raise: Diagnostic => Unit): (List[Str], Option[TypingUnit]) =
val outputBuilder = StringBuilder()
Expand Down Expand Up @@ -58,21 +59,11 @@ class IRDiffTestCompiler extends DiffTests {

(outputBuilder.toString().linesIterator.toList, None)

override protected lazy val files = allFiles.filter { file =>
val fileName = file.baseName
validExt(file.ext) && filter(file.relativeTo(pwd))
}
}

object IRDiffTestCompiler {

private val pwd = os.pwd
private val dir = pwd/"compiler"/"shared"/"test"/"diff-ir"

private val allFiles = os.walk(dir).filter(_.toIO.isFile)

private val validExt = Set("fun", "mls")

private def filter(file: os.RelPath) = DiffTests.filter(file)
lazy val State =
new DiffTests.State(DiffTests.pwd/"compiler"/"shared"/"test"/"diff-ir")

}
95 changes: 38 additions & 57 deletions shared/src/test/scala/mlscript/DiffTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.scalatest.{funsuite, ParallelTestExecution}
import org.scalatest.time._
import org.scalatest.concurrent.{TimeLimitedTests, Signaler}
import pretyper.PreTyper
import os.Path

abstract class ModeType {
def expectTypeErrors: Bool
Expand Down Expand Up @@ -50,12 +51,15 @@ abstract class ModeType {
def nolift: Bool
}

class DiffTests
class DiffTests(state: DiffTests.State)
extends funsuite.AnyFunSuite
with ParallelTestExecution
with TimeLimitedTests
{

def this() = this(DiffTests.State)

import state._

/** Hook for dependent projects, like the monomorphizer. */
def postProcess(mode: ModeType, basePath: Ls[Str], testName: Str, unit: TypingUnit, output: Str => Unit, raise: Diagnostic => Unit): (Ls[Str], Option[TypingUnit]) = (Nil, None)
Expand All @@ -65,14 +69,12 @@ class DiffTests
@SuppressWarnings(Array("org.wartremover.warts.RedundantIsInstanceOf"))
private val inParallel = isInstanceOf[ParallelTestExecution]

import DiffTests._

// scala test will not execute a test if the test class has constructor parameters.
// override this to get the correct paths of test files.
protected lazy val files = allFiles.filter { file =>
val fileName = file.baseName
// validExt(file.ext) && filter(fileName)
validExt(file.ext) && filter(file.relativeTo(pwd))
validExt(file.ext) && filter(file.relativeTo(DiffTests.pwd))
}

val timeLimit = TimeLimit
Expand Down Expand Up @@ -240,7 +242,7 @@ class DiffTests
case "p" => mode.copy(showParse = true)
case "d" => mode.copy(dbg = true)
case "dp" => mode.copy(dbgParsing = true)
case DebugUCSFlags(x) => mode.copy(dbgUCS = mode.dbgUCS.fold(S(x))(y => S(y ++ x)))
case DiffTests.DebugUCSFlags(x) => mode.copy(dbgUCS = mode.dbgUCS.fold(S(x))(y => S(y ++ x)))
case "ds" => mode.copy(dbgSimplif = true)
case "dl" => mode.copy(dbgLifting = true)
case "dd" => mode.copy(dbgDefunc = true)
Expand Down Expand Up @@ -1139,61 +1141,40 @@ class DiffTests

object DiffTests {

private val TimeLimit =
if (sys.env.get("CI").isDefined) Span(60, Seconds)
else Span(30, Seconds)

private val pwd = os.pwd
private val dir = pwd/"shared"/"src"/"test"/"diff"

private val allFiles = os.walk(dir).filter(_.toIO.isFile)
val pwd: Path = os.pwd

private val validExt = Set("fun", "mls")
lazy val State = new State(pwd/"shared"/"src"/"test"/"diff")

// Aggregate unstaged modified files to only run the tests on them, if there are any
private val modified: Set[os.RelPath] =
try os.proc("git", "status", "--porcelain", dir).call().out.lines().iterator.flatMap { gitStr =>
println(" [git] " + gitStr)
val prefix = gitStr.take(2)
val filePath = os.RelPath(gitStr.drop(3))
if (prefix =:= "A " || prefix =:= "M " || prefix =:= "R " || prefix =:= "D ")
N // * Disregard modified files that are staged
else S(filePath)
}.toSet catch {
case err: Throwable => System.err.println("/!\\ git command failed with: " + err)
Set.empty
}

// Allow overriding which specific tests to run, sometimes easier for development:
private val focused = Set[Str](
// "LetRec"
// "Ascribe",
// "Repro",
// "RecursiveTypes",
// "Simple",
// "Inherit",
// "Basics",
// "Paper",
// "Negations",
// "RecFuns",
// "With",
// "Annoying",
// "Tony",
// "Lists",
// "Traits",
// "BadTraits",
// "TraitMatching",
// "Subsume",
// "Methods",
).map(os.RelPath(_))
// private def filter(name: Str): Bool =
def filter(file: os.RelPath): Bool = {
if (focused.nonEmpty) focused(file) else modified(file) || modified.isEmpty &&
true
// name.startsWith("new/")
// file.segments.toList.init.lastOption.contains("parser")
class State(val dir: Path) {

val TimeLimit: Span =
if (sys.env.get("CI").isDefined) Span(60, Seconds)
else Span(30, Seconds)

val allFiles: IndexedSeq[Path] = os.walk(dir).filter(_.toIO.isFile)

val validExt: Set[String] = Set("fun", "mls")

// Aggregate unstaged modified files to only run the tests on them, if there are any
val modified: Set[os.RelPath] =
try os.proc("git", "status", "--porcelain", dir).call().out.lines().iterator.flatMap { gitStr =>
println(" [git] " + gitStr)
val prefix = gitStr.take(2)
val filePath = os.RelPath(gitStr.drop(3))
if (prefix =:= "A " || prefix =:= "M " || prefix =:= "R " || prefix =:= "D ")
N // * Disregard modified files that are staged
else S(filePath)
}.toSet catch {
case err: Throwable => System.err.println("/!\\ git command failed with: " + err)
Set.empty
}

// private def filter(name: Str): Bool =
def filter(file: os.RelPath): Bool =
modified(file) || modified.isEmpty

}

object DebugUCSFlags {
// E.g. "ducs", "ducs:foo", "ducs:foo,bar", "ducs:a.b.c,foo"
private val pattern = "^ducs(?::(\\s*(?:[A-Za-z\\.-]+)(?:,\\s*[A-Za-z\\.-]+)*))?$".r
Expand Down

0 comments on commit 4a5a038

Please sign in to comment.