Skip to content

Commit 107a2f1

Browse files
rochalaNPCRUSYummy-Yums
authored
Use untpd.Tree instead of tpd.Tree for SelectionRangeProvider (#22702)
Thanks to @NPCRUS @Yummy-Yums for taking a part in a spree. Sorry for taking that long to create a PR, but I was unavailable. Fixes #22566 --------- Co-authored-by: NPCRUS <[email protected]> Co-authored-by: Yummy-Yums <[email protected]>
1 parent d87bbb1 commit 107a2f1

File tree

3 files changed

+151
-27
lines changed

3 files changed

+151
-27
lines changed

presentation-compiler/src/main/dotty/tools/pc/SelectionRangeProvider.scala

+27-26
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import java.util as ju
66
import scala.jdk.CollectionConverters._
77
import scala.meta.pc.OffsetParams
88

9-
import dotty.tools.dotc.ast.tpd
9+
import dotty.tools.dotc.ast.untpd.*
10+
import dotty.tools.dotc.ast.NavigateAST
1011
import dotty.tools.dotc.core.Contexts.Context
1112
import dotty.tools.dotc.interactive.Interactive
1213
import dotty.tools.dotc.interactive.InteractiveDriver
@@ -23,10 +24,7 @@ import org.eclipse.lsp4j.SelectionRange
2324
* @param compiler Metals Global presentation compiler wrapper.
2425
* @param params offset params converted from the selectionRange params.
2526
*/
26-
class SelectionRangeProvider(
27-
driver: InteractiveDriver,
28-
params: ju.List[OffsetParams]
29-
):
27+
class SelectionRangeProvider(driver: InteractiveDriver, params: ju.List[OffsetParams]):
3028

3129
/**
3230
* Get the seletion ranges for the provider params
@@ -44,10 +42,13 @@ class SelectionRangeProvider(
4442
val source = SourceFile.virtual(filePath.toString, text)
4543
driver.run(uri, source)
4644
val pos = driver.sourcePosition(param)
47-
val path =
48-
Interactive.pathTo(driver.openedTrees(uri), pos)(using ctx)
45+
val unit = driver.compilationUnits(uri)
4946

50-
val bareRanges = path
47+
val untpdPath: List[Tree] = NavigateAST
48+
.pathTo(pos.span, List(unit.untpdTree), true).collect:
49+
case untpdTree: Tree => untpdTree
50+
51+
val bareRanges = untpdPath
5152
.flatMap(selectionRangesFromTree(pos))
5253

5354
val comments =
@@ -78,31 +79,31 @@ class SelectionRangeProvider(
7879
end selectionRange
7980

8081
/** Given a tree, create a seq of [[SelectionRange]]s corresponding to that tree. */
81-
private def selectionRangesFromTree(pos: SourcePosition)(tree: tpd.Tree)(using Context) =
82+
private def selectionRangesFromTree(pos: SourcePosition)(tree: Tree)(using Context) =
8283
def toSelectionRange(srcPos: SourcePosition) =
8384
val selectionRange = new SelectionRange()
8485
selectionRange.setRange(srcPos.toLsp)
8586
selectionRange
8687

87-
val treeSelectionRange = toSelectionRange(tree.sourcePos)
88+
val treeSelectionRange = Seq(toSelectionRange(tree.sourcePos))
89+
90+
def allArgsSelectionRange(args: List[Tree]): Option[SelectionRange] =
91+
args match
92+
case Nil => None
93+
case list =>
94+
val srcPos = list.head.sourcePos
95+
val lastSpan = list.last.span
96+
val allArgsSrcPos = SourcePosition(srcPos.source, srcPos.span union lastSpan, srcPos.outer)
97+
if allArgsSrcPos.contains(pos) then Some(toSelectionRange(allArgsSrcPos))
98+
else None
8899

89100
tree match
90-
case tpd.DefDef(name, paramss, tpt, rhs) =>
91-
// If source position is within a parameter list, add a selection range covering that whole list.
92-
val selectedParams =
93-
paramss
94-
.iterator
95-
.flatMap: // parameter list to a sourcePosition covering the whole list
96-
case Seq(param) => Some(param.sourcePos)
97-
case params @ Seq(head, tail*) =>
98-
val srcPos = head.sourcePos
99-
val lastSpan = tail.last.span
100-
Some(SourcePosition(srcPos.source, srcPos.span union lastSpan, srcPos.outer))
101-
case Seq() => None
102-
.find(_.contains(pos))
103-
.map(toSelectionRange)
104-
selectedParams ++ Seq(treeSelectionRange)
105-
case _ => Seq(treeSelectionRange)
101+
case DefDef(_, paramss, _, _) => paramss.flatMap(allArgsSelectionRange) ++ treeSelectionRange
102+
case Apply(_, args) => allArgsSelectionRange(args) ++ treeSelectionRange
103+
case TypeApply(_, args) => allArgsSelectionRange(args) ++ treeSelectionRange
104+
case UnApply(_, _, pattern) => allArgsSelectionRange(pattern) ++ treeSelectionRange
105+
case Function(args, body) => allArgsSelectionRange(args) ++ treeSelectionRange
106+
case _ => treeSelectionRange
106107

107108
private def setParent(
108109
child: SelectionRange,

presentation-compiler/test/dotty/tools/pc/tests/SelectionRangeSuite.scala

+120-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite:
7575
| b <- Some(2)
7676
| } yield a + b
7777
|}""".stripMargin,
78+
"""|object Main extends App {
79+
| val total = for {
80+
| >>region>>a <- Some(1)<<region<<
81+
| b <- Some(2)
82+
| } yield a + b
83+
|}""".stripMargin,
7884
"""|object Main extends App {
7985
| val total = >>region>>for {
8086
| a <- Some(1)
@@ -102,7 +108,7 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite:
102108
)
103109
)
104110

105-
@Test def `function params` =
111+
@Test def `function-params-1` =
106112
check(
107113
"""|object Main extends App {
108114
| def func(a@@: Int, b: Int) =
@@ -124,6 +130,32 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite:
124130
)
125131
)
126132

133+
@Test def `function-params-2` =
134+
check(
135+
"""|object Main extends App {
136+
| val func = (a@@: Int, b: Int) =>
137+
| a + b
138+
|}""".stripMargin,
139+
List[String](
140+
"""|object Main extends App {
141+
| val func = (>>region>>a: Int<<region<<, b: Int) =>
142+
| a + b
143+
|}""".stripMargin,
144+
"""|object Main extends App {
145+
| val func = (>>region>>a: Int, b: Int<<region<<) =>
146+
| a + b
147+
|}""".stripMargin,
148+
"""|object Main extends App {
149+
| val func = >>region>>(a: Int, b: Int) =>
150+
| a + b<<region<<
151+
|}""".stripMargin,
152+
"""|object Main extends App {
153+
| >>region>>val func = (a: Int, b: Int) =>
154+
| a + b<<region<<
155+
|}""".stripMargin
156+
)
157+
)
158+
127159
@Test def `def - type params` =
128160
check(
129161
"object Main extends App { def foo[Type@@ <: T1, B](hi: Int, b: Int, c:Int) = ??? }",
@@ -133,3 +165,90 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite:
133165
"object Main extends App { >>region>>def foo[Type <: T1, B](hi: Int, b: Int, c:Int) = ???<<region<< }"
134166
)
135167
)
168+
169+
170+
@Test def `arithmetic` =
171+
check(
172+
"""|object Main extends App {
173+
| def x = 12 * (34 + 5@@6)
174+
|}""".stripMargin,
175+
List(
176+
"""|object Main extends App {
177+
| def x = 12 * (34 + >>region>>56<<region<<)
178+
|}""".stripMargin,
179+
"""|object Main extends App {
180+
| def x = 12 * (>>region>>34 + 56<<region<<)
181+
|}""".stripMargin,
182+
"""|object Main extends App {
183+
| def x = 12 * >>region>>(34 + 56)<<region<<
184+
|}""".stripMargin,
185+
"""|object Main extends App {
186+
| def x = >>region>>12 * (34 + 56)<<region<<
187+
|}""".stripMargin
188+
)
189+
)
190+
191+
@Test def `function` =
192+
check(
193+
"val hello = (aaa: Int, bb@@b: Int, ccc: Int) => ???",
194+
List(
195+
"val hello = (aaa: Int, >>region>>bbb: Int<<region<<, ccc: Int) => ???",
196+
"val hello = (>>region>>aaa: Int, bbb: Int, ccc: Int<<region<<) => ???",
197+
"val hello = >>region>>(aaa: Int, bbb: Int, ccc: Int) => ???<<region<<",
198+
">>region>>val hello = (aaa: Int, bbb: Int, ccc: Int) => ???<<region<<",
199+
)
200+
)
201+
202+
@Test def `defdef` =
203+
check(
204+
"def hello(aaa: Int, bb@@b: Int, ccc: Int) = ???",
205+
List(
206+
"def hello(aaa: Int, >>region>>bbb: Int<<region<<, ccc: Int) = ???",
207+
"def hello(>>region>>aaa: Int, bbb: Int, ccc: Int<<region<<) = ???",
208+
">>region>>def hello(aaa: Int, bbb: Int, ccc: Int) = ???<<region<<",
209+
)
210+
)
211+
212+
@Test def `apply` =
213+
check(
214+
"def hello = List(111, 2@@22, 333)",
215+
List(
216+
"def hello = List(111, >>region>>222<<region<<, 333)",
217+
"def hello = List(>>region>>111, 222, 333<<region<<)",
218+
"def hello = >>region>>List(111, 222, 333)<<region<<",
219+
">>region>>def hello = List(111, 222, 333)<<region<<",
220+
)
221+
)
222+
223+
@Test def `type-apply` =
224+
check(
225+
"def hello = Map[String, I@@nt]()",
226+
List(
227+
"def hello = Map[String, >>region>>Int<<region<<]()",
228+
"def hello = Map[>>region>>String, Int<<region<<]()",
229+
"def hello = >>region>>Map[String, Int]<<region<<()",
230+
"def hello = >>region>>Map[String, Int]()<<region<<",
231+
">>region>>def hello = Map[String, Int]()<<region<<",
232+
)
233+
)
234+
235+
@Test def `unapply` =
236+
check(
237+
"val List(aaa, b@@bb, ccc) = List(111, 222, 333)",
238+
List(
239+
"val List(aaa, >>region>>bbb<<region<<, ccc) = List(111, 222, 333)",
240+
"val List(>>region>>aaa, bbb, ccc<<region<<) = List(111, 222, 333)",
241+
"val >>region>>List(aaa, bbb, ccc)<<region<< = List(111, 222, 333)",
242+
">>region>>val List(aaa, bbb, ccc) = List(111, 222, 333)<<region<<",
243+
)
244+
)
245+
246+
@Test def `single` =
247+
check(
248+
"def hello = List(2@@22)",
249+
List(
250+
"def hello = List(>>region>>222<<region<<)",
251+
"def hello = >>region>>List(222)<<region<<",
252+
">>region>>def hello = List(222)<<region<<",
253+
)
254+
)

project/Build.scala

+4
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,10 @@ object Build {
15121512
ivyConfigurations += SourceDeps.hide,
15131513
transitiveClassifiers := Seq("sources"),
15141514
scalacOptions ++= Seq("-source", "3.3"), // To avoid fatal migration warnings
1515+
publishLocal := publishLocal.dependsOn( // It is best to publish all together. It is not rare to make changes in both compiler / presentation compiler and it can get misaligned
1516+
`scala3-compiler-bootstrapped` / publishLocal,
1517+
`scala3-library-bootstrapped` / publishLocal,
1518+
).value,
15151519
Compile / scalacOptions ++= Seq("-Yexplicit-nulls", "-Wsafe-init"),
15161520
Compile / sourceGenerators += Def.task {
15171521
val s = streams.value

0 commit comments

Comments
 (0)