From fadba18250a5fbb2070fb67419039b2829a51333 Mon Sep 17 00:00:00 2001 From: Tomasz Godzik Date: Thu, 8 May 2025 19:33:04 +0200 Subject: [PATCH] bugfix: Fix enumeration issues when Value is imported We still add the full prefix in case of conflicts, but that's the current status quo. Making prefixes shorter is something we need to work on separately. --- .../main/dotty/tools/pc/IndexedContext.scala | 34 ++++++++-- .../dotty/tools/pc/InferredTypeProvider.scala | 4 +- .../pc/printer/ShortenedTypePrinter.scala | 3 +- .../tests/edit/InsertInferredTypeSuite.scala | 66 +++++++++++++++++++ 4 files changed, 96 insertions(+), 11 deletions(-) diff --git a/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala b/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala index 2db37f801349..cf22ac12a879 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala @@ -23,13 +23,13 @@ sealed trait IndexedContext: given ctx: Context def scopeSymbols: List[Symbol] def rename(sym: Symbol): Option[String] - def findSymbol(name: Name): Option[List[Symbol]] + def findSymbol(name: Name, fromPrefix: Option[Type] = None): Option[List[Symbol]] def findSymbolInLocalScope(name: String): Option[List[Symbol]] - final def lookupSym(sym: Symbol): Result = + final def lookupSym(sym: Symbol, fromPrefix: Option[Type] = None): Result = def all(symbol: Symbol): Set[Symbol] = Set(symbol, symbol.companionModule, symbol.companionClass, symbol.companion).filter(_ != NoSymbol) val isRelated = all(sym) ++ all(sym.dealiasType) - findSymbol(sym.name) match + findSymbol(sym.name, fromPrefix) match case Some(symbols) if symbols.exists(isRelated) => Result.InScope case Some(symbols) if symbols.exists(isTermAliasOf(_, sym)) => Result.InScope case Some(symbols) if symbols.map(_.dealiasType).exists(isRelated) => Result.InScope @@ -81,7 +81,7 @@ object IndexedContext: case object Empty extends IndexedContext: given ctx: Context = NoContext - def findSymbol(name: Name): Option[List[Symbol]] = None + def findSymbol(name: Name, fromPrefix: Option[Type]): Option[List[Symbol]] = None def findSymbolInLocalScope(name: String): Option[List[Symbol]] = None def scopeSymbols: List[Symbol] = List.empty def rename(sym: Symbol): Option[String] = None @@ -111,11 +111,31 @@ object IndexedContext: override def findSymbolInLocalScope(name: String): Option[List[Symbol]] = names.get(name).map(_.map(_.symbol).toList).filter(_.nonEmpty) - def findSymbol(name: Name): Option[List[Symbol]] = + def findSymbol(name: Name, fromPrefix: Option[Type]): Option[List[Symbol]] = names .get(name.show) - .map(_.map(_.symbol).toList) - .orElse(defaultScopes(name)) + .map { denots => + def skipThisType(tp: Type): Type = tp match + case ThisType(prefix) => skipThisType(prefix) + case _ => tp + + val filteredDenots = fromPrefix match + case Some(prefix) => + val target = skipThisType(prefix) + denots.filter { denot => + denot.prefix == NoPrefix || + (denot.prefix match + case tref: TermRef => + tref.termSymbol.info <:< target + case otherPrefix => + otherPrefix <:< target + ) + } + case None => denots + + filteredDenots.map(_.symbol).toList + } + .orElse(defaultScopes(name)).filter(_.nonEmpty) def scopeSymbols: List[Symbol] = names.values.flatten.map(_.symbol).toList diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala index 3d4cc45171ae..887d404570c0 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/InferredTypeProvider.scala @@ -94,7 +94,8 @@ final class InferredTypeProvider( tpe match case tref: TypeRef => indexedCtx.lookupSym( - tref.currentSymbol + tref.currentSymbol, + Some(tref.prefix) ) == IndexedContext.Result.InScope case AppliedType(tycon, args) => isInScope(tycon) && args.forall(isInScope) @@ -136,7 +137,6 @@ final class InferredTypeProvider( findNamePos(sourceText, vl, keywordOffset).endPos.toLsp adjustOpt.foreach(adjust => endPos.setEnd(adjust.adjustedEndPos)) val spaceBefore = name.isOperatorName - new TextEdit( endPos, printTypeAscription(optDealias(tpt.typeOpt), spaceBefore) + { diff --git a/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala b/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala index a9cfc9dfb690..530883f3cabf 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/printer/ShortenedTypePrinter.scala @@ -174,14 +174,13 @@ class ShortenedTypePrinter( res.toPrefixText } - override def toTextPrefixOf(tp: NamedType): Text = controlled { val maybeRenamedPrefix: Option[Text] = findRename(tp) def trimmedPrefix: Text = if !tp.designator.isInstanceOf[Symbol] && tp.typeSymbol == NoSymbol then super.toTextPrefixOf(tp) else - indexedCtx.lookupSym(tp.symbol) match + indexedCtx.lookupSym(tp.symbol, Some(tp.prefix)) match case _ if indexedCtx.rename(tp.symbol).isDefined => Text() // symbol is missing and is accessible statically, we can import it and add proper prefix case Result.Missing if isAccessibleStatically(tp.symbol) => diff --git a/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala index 212b19d31461..23e336244034 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredTypeSuite.scala @@ -990,6 +990,72 @@ class InsertInferredTypeSuite extends BaseCodeActionSuite: |""".stripMargin ) + @Test def `enums` = + checkEdit( + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | def day(d: Day.Value): Unit = ??? + | val <> = + | if (true) Day.Weekday + | else Day.Weekend + |""".stripMargin, + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | def day(d: Day.Value): Unit = ??? + | val d: EnumerationValue.Day.Value = + | if (true) Day.Weekday + | else Day.Weekend + |""".stripMargin + ) + + @Test def `enums2` = + checkEdit( + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | val <> = + | if (true) True + | else False + |""".stripMargin, + """|object EnumerationValue: + | object Day extends Enumeration { + | type Day = Value + | val Weekday, Weekend = Value + | } + | object Bool extends Enumeration { + | type Bool = Value + | val True, False = Value + | } + | import Bool._ + | val b: Value = + | if (true) True + | else False + |""".stripMargin + ) + def checkEdit( original: String, expected: String