diff --git a/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala b/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala index 2db37f801349..cf22ac12a879 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala @@ -23,13 +23,13 @@ sealed trait IndexedContext: given ctx: Context def scopeSymbols: List[Symbol] def rename(sym: Symbol): Option[String] - def findSymbol(name: Name): Option[List[Symbol]] + def findSymbol(name: Name, fromPrefix: Option[Type] = None): Option[List[Symbol]] def findSymbolInLocalScope(name: String): Option[List[Symbol]] - final def lookupSym(sym: Symbol): Result = + final def lookupSym(sym: Symbol, fromPrefix: Option[Type] = None): Result = def all(symbol: Symbol): Set[Symbol] = Set(symbol, symbol.companionModule, symbol.companionClass, symbol.companion).filter(_ != NoSymbol) val isRelated = all(sym) ++ all(sym.dealiasType) - findSymbol(sym.name) match + findSymbol(sym.name, fromPrefix) match case Some(symbols) if symbols.exists(isRelated) => Result.InScope case Some(symbols) if symbols.exists(isTermAliasOf(_, sym)) => Result.InScope case Some(symbols) if symbols.map(_.dealiasType).exists(isRelated) => Result.InScope @@ -81,7 +81,7 @@ object IndexedContext: case object Empty extends IndexedContext: given ctx: Context = NoContext - def findSymbol(name: Name): Option[List[Symbol]] = None + def findSymbol(name: Name, fromPrefix: Option[Type]): Option[List[Symbol]] = None def findSymbolInLocalScope(name: String): Option[List[Symbol]] = None def scopeSymbols: List[Symbol] = List.empty def rename(sym: Symbol): Option[String] = None @@ -111,11 +111,31 @@ object IndexedContext: override def findSymbolInLocalScope(name: String): Option[List[Symbol]] = names.get(name).map(_.map(_.symbol).toList).filter(_.nonEmpty) - def findSymbol(name: Name): Option[List[Symbol]] = + def findSymbol(name: Name, fromPrefix: Option[Type]): Option[List[Symbol]] = names .get(name.show) - .map(_.map(_.symbol).toList) - .orElse(defaultScopes(name)) + .map { denots => + def skipThisType(tp: Type): Type = tp match + case ThisType(prefix) => skipThisType(prefix) + case _ => tp + + val filteredDenots = fromPrefix match + case Some(prefix) => + val target = skipThisType(prefix) + denots.filter { denot => + denot.prefix == NoPrefix || + (denot.prefix match + case tref: TermRef => + tref.termSymbol.info <:< target + case otherPrefix => + otherPrefix <:< target + ) + } + case None => denots + + filteredDenots.map(_.symbol).toList + } + .orElse(defaultScopes(name)).filter(_.nonEmpty) def scopeSymbols: List[Symbol] = names.values.flatten.map(_.symbol).toList diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala index 3d4cc45171ae..887d404570c0 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala @@ -94,7 +94,8 @@ final class InferredTypeProvider( tpe match case tref: TypeRef => indexedCtx.lookupSym( - tref.currentSymbol + tref.currentSymbol, + Some(tref.prefix) ) == IndexedContext.Result.InScope case AppliedType(tycon, args) => isInScope(tycon) && args.forall(isInScope) @@ -136,7 +137,6 @@ final class InferredTypeProvider( findNamePos(sourceText, vl, keywordOffset).endPos.toLsp adjustOpt.foreach(adjust => endPos.setEnd(adjust.adjustedEndPos)) val spaceBefore = name.isOperatorName - new TextEdit( endPos, printTypeAscription(optDealias(tpt.typeOpt), spaceBefore) + { diff --git a/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala b/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala index a9cfc9dfb690..530883f3cabf 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala @@ -174,14 +174,13 @@ class ShortenedTypePrinter( res.toPrefixText } - override def toTextPrefixOf(tp: NamedType): Text = controlled { val maybeRenamedPrefix: Option[Text] = findRename(tp) def trimmedPrefix: Text = if !tp.designator.isInstanceOf[Symbol] && tp.typeSymbol == NoSymbol then super.toTextPrefixOf(tp) else - indexedCtx.lookupSym(tp.symbol) match + indexedCtx.lookupSym(tp.symbol, Some(tp.prefix)) match case _ if indexedCtx.rename(tp.symbol).isDefined => Text() // symbol is missing and is accessible statically, we can import it and add proper prefix case Result.Missing if isAccessibleStatically(tp.symbol) => diff --git a/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala index 212b19d31461..23e336244034 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala @@ -990,6 +990,72 @@ class InsertInferredTypeSuite extends BaseCodeActionSuite: |""".stripMargin ) + @Test def `enums` = + checkEdit( + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | def day(d: Day.Value): Unit = ??? + | val <> = + | if (true) Day.Weekday + | else Day.Weekend + |""".stripMargin, + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | def day(d: Day.Value): Unit = ??? + | val d: EnumerationValue.Day.Value = + | if (true) Day.Weekday + | else Day.Weekend + |""".stripMargin + ) + + @Test def `enums2` = + checkEdit( + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | val <> = + | if (true) True + | else False + |""".stripMargin, + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | val b: Value = + | if (true) True + | else False + |""".stripMargin + ) + def checkEdit( original: String, expected: String