Skip to content

Commit

Permalink
fix lcg and add constraints filtering by lowerbound
Browse files Browse the repository at this point in the history
  • Loading branch information
auht committed Jan 14, 2024
1 parent 3ad402f commit d40c9b7
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 60 deletions.
11 changes: 10 additions & 1 deletion shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -820,14 +820,23 @@ class ConstraintSolver extends NormalForms { self: Typer =>
val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
lhs.upperBounds ::= newBound // update the bound
lhs.tsc.foreach { case (tsc, i) => tsc.filterUB(i, rhs) }
lhs.lbtsc.foreach {
case (tsc, i) =>
tsc.filterUB(i, rhs)
if (tsc.constraints.isEmpty) reportError()
}
lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound

case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level =>
println(s"NEW $rhs LB (${lhs.level})")
val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) =>
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
rhs.lowerBounds ::= newBound // update the bound
rhs.ubtsc.foreach {
case (tsc, i) =>
tsc.filterLB(i, lhs)
if (tsc.constraints.isEmpty) reportError()
}
rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound


Expand Down
1 change: 0 additions & 1 deletion shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ trait TypeSimplifier { self: Typer =>
.reduceOption(_ &- _).filterNot(_.isTop).toList
else Nil
}

nv

case ComposedType(true, l, r) =>
Expand Down
97 changes: 61 additions & 36 deletions shared/src/main/scala/mlscript/TyperDatatypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
_assignedTo = value
}

var tsc: Opt[(TupleSetConstraints, Int)] = N
var lbtsc: Opt[(TupleSetConstraints, Int)] = N
var ubtsc: Opt[(TupleSetConstraints, Int)] = N

// * Bounds should always be disregarded when `equatedTo` is defined, as they are then irrelevant:
def lowerBounds: List[SimpleType] = { require(assignedTo.isEmpty, this); _lowerBounds }
Expand Down Expand Up @@ -654,7 +655,7 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
def go(ub: ST): Unit = ub match {
case ub: TV =>
ub.upperBounds.foreach(go)
ub.tsc = S(this, index)
ub.lbtsc = S(this, index)
case _ =>
constraints.filterInPlace { constrs =>
val ty = constrs(index)
Expand All @@ -667,65 +668,89 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
if (constraints.sizeCompare(1) === 0) {
constraints.head.zip(tvs).foreach {
case (ty, tv) =>
tv.tsc = N
tv.lbtsc = N
tv.ubtsc = N
constrain(tv, ty)(raise, prov, ctx)
constrain(ty, tv)(raise, prov, ctx)
}
}
}
def filterLB(index: Int, lb: ST)(implicit raise: Raise, ctx: Ctx): Unit = {
constraints.filterInPlace { constrs =>
val ty = constrs(index)
val dnf = DNF.mk(MaxLevel, Nil, lb & ty.neg(), true)
dnf.isBot || dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty))
}
println(s"TSC filterLB: $tvs in $constraints")
if (constraints.sizeCompare(1) === 0) {
constraints.head.zip(tvs).foreach {
case (ty, tv) =>
tv.lbtsc = N
tv.ubtsc = N
constrain(tv, ty)(raise, prov, ctx)
constrain(ty, tv)(raise, prov, ctx)
}
}
}
}
object TupleSetConstraints {
def lcgField(a: FieldType, b: FieldType)
def lcgField(first: FieldType, rest: Ls[FieldType])
(implicit prov: TypeProvenance, lvl: Level)
: (FieldType, Ls[TV], Ls[Ls[ST]]) = {
val (ub, tvs, constrs) = lcg(a.ub, b.ub)
if (a.lb.isEmpty && b.lb.isEmpty) {
val (ub, tvs, constrs) = lcg(first.ub, rest.map(_.ub))
if (first.lb.isEmpty && rest.forall(_.lb.isEmpty)) {
(FieldType(N, ub)(prov), tvs, constrs)
} else {
val (lb, ltvs, lconstrs) = lcg(a.lb.getOrElse(BotType), b.lb.getOrElse(BotType))
val (lb, ltvs, lconstrs) = lcg(first.lb.getOrElse(BotType), rest.map(_.lb.getOrElse(BotType)))
(FieldType(S(lb), ub)(prov), tvs ++ ltvs, constrs ++ lconstrs)
}
}
def lcg(a: ST, b: ST)
def lcg(first: ST, rest: Ls[ST])
(implicit prov: TypeProvenance, lvl: Level)
: (ST, Ls[TV], Ls[Ls[ST]]) = (a, b) match {
case (_, b: ProvType) => lcg(a, b.underlying)
case (a: ProvType, _) => lcg(a.underlying, b)
case (a: FT, b: FT) => lcgFunction(a, b)
case (a: ArrayType, b: ArrayType) =>
val (t, tvs, constrs) = lcgField(a.inner, b.inner)
: (ST, Ls[TV], Ls[Ls[ST]]) = first match {
case a: FunctionType if rest.forall(_.isInstanceOf[FunctionType]) =>
val (lhss, rhss) = rest.collect {
case FunctionType(lhs, rhs) => lhs -> rhs
}.unzip
val (lhs, ltvs, lconstrs) = lcg(a.lhs, lhss)
val (rhs, rtvs, rconstrs) = lcg(a.rhs, rhss)
(FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
case a: ArrayType if rest.forall(_.isInstanceOf[ArrayType]) =>
val inners = rest.collect { case b: ArrayType => b.inner }
val (t, tvs, constrs) = lcgField(a.inner, inners)
(ArrayType(t)(prov), tvs, constrs)
case (a: TupleType, b: TupleType) if a.fields.sizeCompare(b.fields.size) === 0 =>
val (fts, tvss, constrss) = a.fields.map(_._2).zip(b.fields.map(_._2)).map {
case (a, b) => lcgField(a, b)
}.unzip3
case a: TupleType if rest.forall { case b: TupleType => a.fields.sizeCompare(b.fields.size) === 0; case _ => false } =>
val fields = rest.collect { case TupleType(fields) => fields.map(_._2) }
val (fts, tvss, constrss) = a.fields.map(_._2).zip(fields.transpose).map { case (a, bs) => lcgField(a, bs) }.unzip3
(TupleType(fts.map(N -> _))(prov), tvss.flatten, constrss.flatten)
case (a: TR, b: TR) if a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0 =>
val (ts, tvss, constrss) = a.targs.zip(b.targs).map {
case (a, b) => lcg(a, b)
}.unzip3
case a: TR if rest.forall { case b: TR => a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0; case _ => false } =>
val targs = rest.collect { case b: TR => b.targs }
val (ts, tvss, constrss) = a.targs.zip(targs.transpose).map { case (a, bs) => lcg(a, bs) }.unzip3
(TypeRef(a.defn, ts)(prov), tvss.flatten, constrss.flatten)
case (a: TV, b: TV) if a.compare(b) === 0 => (a, Nil, Nil)
case (a: ExtrType, b: ExtrType) if a.pol === b.pol => (a, Nil, Nil)
case a: TV if rest.forall { case b: TV => a.compare(b) === 0; case _ => false } => (a, Nil, Nil)
case a if rest.forall(_ === a) => (a, Nil, Nil)
case _ =>
val tv = freshVar(prov, N)
(tv, List(tv), List(List(a, b)))
(tv, List(tv), List(first :: rest))
}
def lcgFunction(a: FT, b: FT)
(implicit prov: TypeProvenance, lvl: Level)
: (FT, Ls[TV], Ls[Ls[ST]]) = {
val (lhs, ltvs, lconstrs) = lcg(a.lhs, b.lhs)
val (rhs, rtvs, rconstrs) = lcg(a.rhs, b.rhs)
def lcgFunction(first: FunctionType, rest: Ls[FunctionType])(implicit prov: TypeProvenance, lvl: Level)
: (FunctionType, Ls[TV], Ls[Ls[ST]]) = {
val (lhss, rhss) = rest.map {
case FunctionType(lhs, rhs) => lhs -> rhs
}.unzip
val (lhs, ltvs, lconstrs) = lcg(first.lhs, lhss)
val (rhs, rtvs, rconstrs) = lcg(first.rhs, rhss)
(FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
}
def mk(ov: Overload)(implicit lvl: Level): FunctionType = {
val (t, tvs, constrs) =
ov.alts.tail.foldLeft((ov.alts.head, Nil: Ls[TV], Nil: Ls[Ls[ST]])) {
case ((a, tvs, constrs), b) => lcgFunction(a, b)(ov.prov, lvl)
}
// val (t, tvs, constrs) = lcgFunction(ov.alts.head, ov.alts.tail)(ov.prov, lvl)
def unwrap(t: ST): ST = t.map(unwrap)
val f = ov.mapAlts(unwrap)(unwrap)
val (t, tvs, constrs) = lcgFunction(f.alts.head, f.alts.tail)(ov.prov, lvl)
val tsc = new TupleSetConstraints(MutSet.empty ++ constrs.transpose, tvs)(ov.prov)
tvs.zipWithIndex.foreach { case (tv, i) => tv.tsc = S((tsc, i)) }
tvs.zipWithIndex.foreach { case (tv, i) =>
tv.lbtsc = S((tsc, i))
tv.ubtsc = S((tsc, i))
}
println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}")
t
}
Expand Down
19 changes: 15 additions & 4 deletions shared/src/main/scala/mlscript/TyperHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -983,13 +983,24 @@ abstract class TyperHelpers { Typer: Typer =>
def getVars: SortedSet[TypeVariable] = getVarsImpl(includeBounds = true)

def showBounds: String =
getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty).map {
getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty || (tv.lbtsc.fold(false)(!_._1.tvs.contains(tv)))).map {
case tv @ AssignedVariable(ty) => "\n\t\t" + tv.toString + " := " + ty
case tv => ("\n\t\t" + tv.toString
+ (if (tv.lowerBounds.isEmpty) "" else " :> " + tv.lowerBounds.mkString(" | "))
+ (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & ")))
}.mkString

+ (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & "))
+ tv.lbtsc.fold(""){ case (tsc, i) => " :> " + tsc.tvs(i) } )
}.mkString + {
val visited: MutSet[TV] = MutSet.empty
getVars.iterator.filter(tv => tv.lbtsc.fold(false)(_._1.tvs.contains(tv))).map {
case tv if visited.contains(tv) => ""
case tv =>
visited ++= tv.lbtsc.fold(Nil: Ls[TV])(_._1.tvs)
tv.lbtsc.fold("") { case (tsc, _) => ("\n\t\t[ "
+ tsc.tvs.mkString(", ")
+ " ] in { " + tsc.constraints.mkString(", ") + " }")
}
}.mkString
}
}


Expand Down
62 changes: 44 additions & 18 deletions shared/src/test/diff/nu/HeungTung.mls
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,16 @@ fun g = h
//│ fun g: (Bool | Int) -> (Int | false | true)

// * In one step
:e // TODO: argument of union type
fun g: (Int | Bool) -> (Int | Bool)
fun g = f
//│ ╔══[ERROR] Type mismatch in definition:
//│ ║ l.71: fun g = f
//│ ║ ^^^^^
//│ ╟── expression of type `Int | false | true` does not match type `?a`
//│ ╟── Note: constraint arises from function type:
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ fun g: Int -> Int & Bool -> Bool
//│ fun g: (Bool | Int) -> (Int | false | true)

Expand All @@ -88,9 +96,11 @@ fun j = i
fun j: (Int & Bool) -> (Int & Bool)
fun j = f
//│ ╔══[ERROR] Type mismatch in definition:
//│ ║ l.89: fun j = f
//│ ║ l.97: fun j = f
//│ ║ ^^^^^
//│ ╙── expression of type `Int` does not match type `nothing`
//│ ╟── type `?a` does not match type `nothing`
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ fun j: Int -> Int & Bool -> Bool
//│ fun j: nothing -> nothing

Expand All @@ -106,23 +116,30 @@ fun g = f
// * With match-type-based constraint solving, we could return Int here

f(0)
//│ Int | false | true
//│ Int
//│ res
//│ = 0

// f(0) : case 0 of { Int => Int; Bool => Bool } == Int


x => f(x)
//│ (Bool | Int) -> (Int | false | true)
//│ anything -> nothing
//│ res
//│ = [Function: res]

// : forall 'a: 'a -> case 'a of { Int => Int; Bool => Bool } where 'a <: Int | Bool


:e
f(if true then 0 else false)
//│ Int | false | true
//│ ╔══[ERROR] Type mismatch in application:
//│ ║ l.134: f(if true then 0 else false)
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── expression of type `0 | false` does not match type `?a`
//│ ╟── Note: constraint arises from function type:
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ error
//│ res
//│ = 0

Expand All @@ -132,12 +149,21 @@ f(if true then 0 else false)
:w
f(refined if true then 0 else false) // this one can be precise again!
//│ ╔══[WARNING] Paren-less applications should use the 'of' keyword
//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╔══[ERROR] identifier not found: refined
//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ╙── ^^^^^^^
//│ Int | false | true
//│ ╔══[ERROR] Type mismatch in application:
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── application of type `error` does not match type `?a`
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ╟── Note: constraint arises from function type:
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
//│ ╙── ^^^^^^^^^^^^^^
//│ error
//│ Code generation encountered an error:
//│ unresolved symbol refined

Expand Down Expand Up @@ -193,7 +219,7 @@ type T = List[Int]
:e // TODO application types
type Res = M(T)
//│ ╔══[ERROR] Wrong number of type arguments – expected 0, found 1
//│ ║ l.194: type Res = M(T)
//│ ║ l.220: type Res = M(T)
//│ ╙── ^^^^
//│ type Res = M

Expand All @@ -216,21 +242,21 @@ fun f: Int -> Int
fun f: Bool -> Bool
fun f = id
//│ ╔══[ERROR] A type signature for 'f' was already given
//│ ║ l.216: fun f: Bool -> Bool
//│ ║ l.242: fun f: Bool -> Bool
//│ ╙── ^^^^^^^^^^^^^^^^^^^
//│ fun f: forall 'a. 'a -> 'a
//│ fun f: Int -> Int

:e // TODO support
f: (Int -> Int) & (Bool -> Bool)
//│ ╔══[ERROR] Type mismatch in type ascription:
//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool)
//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool)
//│ ║ ^
//│ ╟── type `Bool` is not an instance of `Int`
//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool)
//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool)
//│ ║ ^^^^
//│ ╟── Note: constraint arises from type reference:
//│ ║ l.215: fun f: Int -> Int
//│ ║ l.241: fun f: Int -> Int
//│ ╙── ^^^
//│ Int -> Int & Bool -> Bool
//│ res
Expand Down Expand Up @@ -297,14 +323,14 @@ fun test(x) = refined if x is
A then 0
B then 1
//│ ╔══[WARNING] Paren-less applications should use the 'of' keyword
//│ ║ l.296: fun test(x) = refined if x is
//│ ║ l.322: fun test(x) = refined if x is
//│ ║ ^^^^^^^^^^^^^^^
//│ ║ l.297: A then 0
//│ ║ l.323: A then 0
//│ ║ ^^^^^^^^^^
//│ ║ l.298: B then 1
//│ ║ l.324: B then 1
//│ ╙── ^^^^^^^^^^
//│ ╔══[ERROR] identifier not found: refined
//│ ║ l.296: fun test(x) = refined if x is
//│ ║ l.322: fun test(x) = refined if x is
//│ ╙── ^^^^^^^
//│ fun test: (A | B) -> error
//│ Code generation encountered an error:
Expand Down

0 comments on commit d40c9b7

Please sign in to comment.