Skip to content

Commit d581e54

Browse files
committed
Fix root treatment of parameterless methods
- Map cap to Result in their original types (was Fresh before) - Map Result to Fresh on access
1 parent 055c394 commit d581e54

File tree

15 files changed

+73
-53
lines changed

15 files changed

+73
-53
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,8 @@ object CaptureSet:
11621162
def unify(root1: root.Result, root2: root.Result)(using Context): Boolean =
11631163
(root1, root2) match
11641164
case (root1 @ root.Result(binder1), root2 @ root.Result(binder2))
1165-
if (binder1 eq binder2)
1165+
if ((binder1 eq binder2)
1166+
|| binder1.isInstanceOf[ExprType] && binder2.isInstanceOf[ExprType])
11661167
&& (root1.rootAnnot.originalBinder ne root2.rootAnnot.originalBinder)
11671168
&& eqResultMap(root1) == null
11681169
&& eqResultMap(root2) == null

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+12-2
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,16 @@ class CheckCaptures extends Recheck, SymTransformer:
556556
if sym.exists && curEnv.isOpen then
557557
markFree(capturedVars(sym).filter(isRetained), tree)
558558

559+
/** If `tp` (possibly after widening singletons) is an ExprType
560+
* of a parameterless method, map Result instances in it to Fresh instances
561+
*/
562+
def mapResultRoots(tp: Type, sym: Symbol)(using Context): Type =
563+
tp.widenSingleton match
564+
case tp: ExprType if sym.is(Method) =>
565+
root.resultToFresh(tp, root.Origin.ResultInstance(tp, sym))
566+
case _ =>
567+
tp
568+
559569
/** Under the sealed policy, disallow the root capability in type arguments.
560570
* Type arguments come either from a TypeApply node or from an AppliedType
561571
* which represents a trait parent in a template.
@@ -618,7 +628,7 @@ class CheckCaptures extends Recheck, SymTransformer:
618628
if pathRef.derivesFrom(defn.Caps_Mutable) && pt.isValueType && !pt.isMutableType then
619629
pathRef = pathRef.readOnly
620630
markFree(sym, pathRef, tree)
621-
super.recheckIdent(tree, pt)
631+
mapResultRoots(super.recheckIdent(tree, pt), tree.symbol)
622632

623633
/** The expected type for the qualifier of a selection. If the selection
624634
* could be part of a capability path or is a a read-only method, we return
@@ -665,7 +675,7 @@ class CheckCaptures extends Recheck, SymTransformer:
665675
|since its capture set ${qualType.captureSet} is read-only""",
666676
tree.srcPos)
667677

668-
val selType = recheckSelection(tree, qualType, name, disambiguate)
678+
val selType = mapResultRoots(recheckSelection(tree, qualType, name, disambiguate), tree.symbol)
669679
val selWiden = selType.widen
670680

671681
// Don't apply the rule

compiler/src/dotty/tools/dotc/cc/Setup.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
323323

324324
try
325325
val tp1 = mapInferred(refine = true)(tp)
326-
val tp2 = root.toResultInResults(_ => assert(false))(tp1)
326+
val tp2 = root.toResultInResults(NoSymbol, _ => assert(false))(tp1)
327327
if tp2 ne tp then capt.println(i"expanded inferred in ${ctx.owner}: $tp --> $tp1 --> $tp2")
328328
tp2
329329
catch case ex: AssertionError =>
@@ -458,7 +458,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
458458

459459
def transform(tp: Type): Type =
460460
val tp1 = toCapturing(tp)
461-
val tp2 = root.toResultInResults(fail, toCapturing.keepFunAliases)(tp1)
461+
val tp2 = root.toResultInResults(sym, fail, toCapturing.keepFunAliases)(tp1)
462462
val snd = if toCapturing.keepFunAliases then "" else " 2nd time"
463463
if tp2 ne tp then capt.println(i"expanded explicit$snd in ${ctx.owner}: $tp --> $tp1 --> $tp2")
464464
tp2
@@ -648,7 +648,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
648648

649649
val paramSymss = sym.paramSymss
650650
def newInfo(using Context) = // will be run in this or next phase
651-
root.toResultInResults(report.error(_, tree.srcPos)):
651+
root.toResultInResults(sym, report.error(_, tree.srcPos)):
652652
if sym.is(Method) then
653653
paramsToCap(methodType(paramSymss, localReturnType))
654654
else tree.tpt.nuType

compiler/src/dotty/tools/dotc/cc/root.scala

+9-7
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ object root:
110110
""
111111

112112
enum Kind:
113-
case Result(binder: MethodType)
113+
case Result(binder: MethodicType)
114114
case Fresh(hidden: CaptureSet.HiddenSet)(val origin: Origin)
115115
case Global
116116

@@ -134,14 +134,14 @@ object root:
134134
ccs.rootId += 1
135135
ccs.rootId
136136

137-
//assert(id != 4)
137+
//assert(id != 4, kind)
138138

139139
override def symbol(using Context) = defn.RootCapabilityAnnot
140140
override def tree(using Context) = New(symbol.typeRef, Nil)
141141
override def derivedAnnotation(tree: Tree)(using Context): Annotation = this
142142

143143
private var myOriginalKind = kind
144-
def originalBinder: MethodType = myOriginalKind.asInstanceOf[Kind.Result].binder
144+
def originalBinder: MethodicType = myOriginalKind.asInstanceOf[Kind.Result].binder
145145

146146
def derivedAnnotation(binder: MethodType)(using Context): Annotation = kind match
147147
case Kind.Result(b) if b ne binder =>
@@ -198,13 +198,13 @@ object root:
198198
type Result = AnnotatedType
199199

200200
object Result:
201-
def apply(binder: MethodType)(using Context): Result =
201+
def apply(binder: MethodicType)(using Context): Result =
202202
val hiddenSet = CaptureSet.HiddenSet(NoSymbol)
203203
val res = AnnotatedType(cap, Annot(Kind.Result(binder)))
204204
hiddenSet.owningCap = res
205205
res
206206

207-
def unapply(tp: Result)(using Context): Option[MethodType] = tp.annot match
207+
def unapply(tp: Result)(using Context): Option[MethodicType] = tp.annot match
208208
case Annot(Kind.Result(binder)) => Some(binder)
209209
case _ => None
210210
end Result
@@ -298,7 +298,7 @@ object root:
298298
* variable bound by `mt`.
299299
* Stop at function or method types since these have been mapped before.
300300
*/
301-
def toResult(tp: Type, mt: MethodType, fail: Message => Unit)(using Context): Type =
301+
def toResult(tp: Type, mt: MethodicType, fail: Message => Unit)(using Context): Type =
302302

303303
abstract class CapMap extends BiTypeMap:
304304
override def mapOver(t: Type): Type = t match
@@ -356,7 +356,7 @@ object root:
356356
end toResult
357357

358358
/** Map global roots in function results to result roots */
359-
def toResultInResults(fail: Message => Unit, keepAliases: Boolean = false)(using Context): TypeMap = new TypeMap with FollowAliasesMap:
359+
def toResultInResults(sym: Symbol, fail: Message => Unit, keepAliases: Boolean = false)(using Context): TypeMap = new TypeMap with FollowAliasesMap:
360360
def apply(t: Type): Type = t match
361361
case defn.RefinedFunctionOf(mt) =>
362362
val mt1 = apply(mt)
@@ -369,6 +369,8 @@ object root:
369369
t.derivedCapturingType(this(parent), refs)
370370
case t: (LazyRef | TypeVar) =>
371371
mapConserveSuper(t)
372+
case t: ExprType if sym.is(Method, butNot = Accessor) =>
373+
t.derivedExprType(toResult(t.resType, t, fail))
372374
case _ =>
373375
try
374376
if keepAliases then mapOver(t)

compiler/src/dotty/tools/dotc/core/Types.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -3721,7 +3721,8 @@ object Types extends TypeUtils {
37213721
// is that most poly types are cyclic via poly params,
37223722
// and therefore two different poly types would never be equal.
37233723

3724-
trait MethodicType extends TermType
3724+
trait MethodicType extends TermType:
3725+
def resType: Type
37253726

37263727
/** A by-name parameter type of the form `=> T`, or the type of a method with no parameter list. */
37273728
abstract case class ExprType(resType: Type)
@@ -4279,7 +4280,7 @@ object Types extends TypeUtils {
42794280
ps.get(elemName) match
42804281
case Some(elemRef) => assert(elemRef eq elem, i"bad $mt")
42814282
case _ =>
4282-
case root.Result(binder) if binder ne mt =>
4283+
case root.Result(binder: MethodType) if binder ne mt =>
42834284
assert(binder.paramNames.toList != mt.paramNames.toList, i"bad $mt")
42844285
case _ =>
42854286
checkRefs(refs)

scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala

+10-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import scala.runtime.Statics
2525
import language.experimental.captureChecking
2626
import annotation.unchecked.uncheckedCaptures
2727
import caps.cap
28-
import caps.unsafe.{unsafeAssumeSeparate, untrackedCaptures}
28+
import caps.unsafe.{unsafeAssumeSeparate, unsafeAssumePure, untrackedCaptures}
2929

3030
/** This class implements an immutable linked list. We call it "lazy"
3131
* because it computes its elements only when they are needed.
@@ -989,7 +989,12 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
989989

990990
private sealed trait State[+A] extends Serializable {
991991
def head: A
992-
def tail: LazyListIterable[A]^
992+
def tail: LazyListIterable[A]/*^*/
993+
// should be ^ but this fails checking. The problem is that
994+
// the ^ in LazyListIterable[A]^ is treated as a result reference.
995+
// But then it cannot be subsumed by
996+
// val tail: LazyListIterable[A]^
997+
// in class State.Cons.
993998
}
994999

9951000
private object State {
@@ -1000,14 +1005,15 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10001005
}
10011006

10021007
@SerialVersionUID(3L)
1003-
final class Cons[A](val head: A, val tail: LazyListIterable[A]^) extends State[A]
1008+
final class Cons[A](val head: A, val tail: LazyListIterable[A]/*^*/) extends State[A]
10041009
}
10051010

10061011
/** Creates a new LazyListIterable. */
10071012
@inline private def newLL[A](state: => State[A]^): LazyListIterable[A]^{state} = new LazyListIterable[A](() => state)
10081013

10091014
/** Creates a new State.Cons. */
1010-
@inline private def sCons[A](hd: A, tl: LazyListIterable[A]^): State[A]^{tl} = new State.Cons[A](hd, tl)
1015+
@inline private def sCons[A](hd: A, tl: LazyListIterable[A]^): State[A]^{tl} =
1016+
new State.Cons[A](hd, tl.unsafeAssumePure)
10111017

10121018
private val anyToMarker: Any => Any = _ => Statics.pfMarker
10131019

tests/neg-custom-args/captures/i15772.check

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:35:34 ---------------------------------------
2020
35 | val boxed2 : Observe[C]^ = box2(c) // error
2121
| ^
22-
| Found: box C^
23-
| Required: box C{val arg: C^?}^?
22+
|Found: box C^
23+
|Required: box C{val arg: C^?}^?
2424
|
25-
| where: ^ refers to a fresh root capability in the result type of method c
25+
|where: ^ refers to a fresh root capability created in value boxed2 when instantiating method c's type -> C^{cap}
2626
|
2727
|
28-
| Note that the universal capability `cap`
29-
| cannot be included in capture set ?
28+
|Note that the universal capability `cap`
29+
|cannot be included in capture set ?
3030
|
3131
| longer explanation available when compiling with `-explain`
3232
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:46:2 ----------------------------------------
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-- [E164] Declaration Error: tests/neg-custom-args/captures/lazyListState.scala:12:39 ----------------------------------
2+
12 | final class Cons[A](val head: A, val tail: LazyListIterable[A]^) extends State[A] // error
3+
| ^
4+
| error overriding method tail in trait State of type -> LazyListIterable[A]^{cap};
5+
| value tail of type LazyListIterable[A]^ has incompatible type
6+
|
7+
| where: ^ refers to a fresh root capability in the type of value tail
8+
| cap is a root capability associated with the result type of -> LazyListIterable[A²]^²
9+
|
10+
| longer explanation available when compiling with `-explain`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class LazyListIterable[+A]
2+
3+
private sealed trait State[+A]:
4+
def head: A
5+
def tail: LazyListIterable[A]^
6+
7+
private object State:
8+
object Empty extends State[Nothing]:
9+
def head: Nothing = throw new NoSuchElementException("head of empty lazy list")
10+
def tail: LazyListIterable[Nothing] = throw new UnsupportedOperationException("tail of empty lazy list")
11+
12+
final class Cons[A](val head: A, val tail: LazyListIterable[A]^) extends State[A] // error
13+

tests/neg-custom-args/captures/lazylist.check

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
22 | def tail: LazyList[Nothing]^ = ??? // error overriding
3838
| ^
3939
| error overriding method tail in class LazyList of type -> lazylists.LazyList[Nothing];
40-
| method tail of type -> lazylists.LazyList[Nothing]^ has incompatible type
40+
| method tail of type -> lazylists.LazyList[Nothing]^{cap} has incompatible type
4141
|
42-
| where: ^ refers to a fresh root capability in the result type of method tail
42+
| where: cap is a root capability associated with the result type of -> lazylists.LazyList[Nothing]^
4343
|
4444
| longer explanation available when compiling with `-explain`

tests/neg-custom-args/captures/linear-buffer.check

-5
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88
| ^^^^^^^^^^^^^
99
| Separation failure: method bar's result type BadBuffer[T]^ hides non-local this of class class BadBuffer.
1010
| The access must be in a @consume method to allow this.
11-
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:6:9 -------------------------------------------------------
12-
6 | def foo = // error
13-
| ^
14-
|Separation failure: method foo's inferred result type BadBuffer[box T^{}]^ hides non-local this of class class BadBuffer.
15-
|The access must be in a @consume method to allow this.
1611
-- Error: tests/neg-custom-args/captures/linear-buffer.scala:19:17 -----------------------------------------------------
1712
19 | val buf3 = app(buf, 3) // error
1813
| ^^^

tests/neg-custom-args/captures/linear-buffer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import language.experimental.captureChecking
33

44
class BadBuffer[T] extends Mutable:
55
mut def append(x: T): BadBuffer[T]^ = this // error
6-
def foo = // error
6+
def foo =
77
def bar: BadBuffer[T]^ = this // error
88
bar
99

tests/neg-custom-args/captures/sep-use2.check

+1-19
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,13 @@
2020
| ^^^^^^^
2121
| Separation failure: method cc's result type Object^ hides non-local parameter c
2222
-- Error: tests/neg-custom-args/captures/sep-use2.scala:18:6 -----------------------------------------------------------
23-
18 | { f(cc) } // error // error
23+
18 | { f(cc) } // error
2424
| ^
2525
| Separation failure: Illegal access to {c} which is hidden by the previous definition
2626
| of method cc with result type Object^.
2727
| This type hides capabilities {c}
2828
|
2929
| where: ^ refers to a fresh root capability in the result type of method cc
30-
-- Error: tests/neg-custom-args/captures/sep-use2.scala:18:8 -----------------------------------------------------------
31-
18 | { f(cc) } // error // error
32-
| ^^
33-
|Separation failure: argument of type (cc : -> Object^)
34-
|to a function of type (x: Object^) ->{c} Object^
35-
|corresponds to capture-polymorphic formal parameter x of type Object^²
36-
|and hides capabilities {cap, c}.
37-
|Some of these overlap with the captures of the function prefix.
38-
|
39-
| Hidden set of current argument : {cap, c}
40-
| Hidden footprint of current argument : {c}
41-
| Capture set of function prefix : {f}
42-
| Footprint set of function prefix : {f, c}
43-
| The two sets overlap at : {c}
44-
|
45-
|where: ^ refers to a fresh root capability in the result type of method cc
46-
| ^² refers to a fresh root capability created in value x1 when checking argument to parameter x of method apply
47-
| cap is a fresh root capability in the result type of method cc
4830
-- Error: tests/neg-custom-args/captures/sep-use2.scala:20:6 -----------------------------------------------------------
4931
20 | { f(c) } // error // error
5032
| ^

tests/neg-custom-args/captures/sep-use2.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test1(@consume c: Object^, f: (x: Object^) => Object^) =
1515
def test2(@consume c: Object^, f: (x: Object^) ->{c} Object^) =
1616
def cc: Object^ = c // error
1717
val x1 =
18-
{ f(cc) } // error // error
18+
{ f(cc) } // error
1919
val x4: Object^ = // ^ hides just c, since the Object^ in the result of `f` is existential
2020
{ f(c) } // error // error
2121

tests/neg-custom-args/captures/unsound-reach-6.check

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@
2525
-- Error: tests/neg-custom-args/captures/unsound-reach-6.scala:19:14 ---------------------------------------------------
2626
19 | val z = f(ys) // error @consume failure
2727
| ^^
28-
|Separation failure: argument to @consume parameter with type (ys : -> List[box () ->{io} Unit]) refers to non-local parameter io
28+
|Separation failure: argument to @consume parameter with type List[box () ->{io} Unit] refers to non-local parameter io

0 commit comments

Comments
 (0)