diff --git a/enigma/build.gradle b/enigma/build.gradle index 7d2e26205..a73b993ea 100644 --- a/enigma/build.gradle +++ b/enigma/build.gradle @@ -19,6 +19,7 @@ dependencies { implementation libs.vineflower implementation libs.cfr implementation libs.procyon + implementation libs.javaparser implementation libs.quilt.config diff --git a/enigma/src/main/java/org/quiltmc/enigma/api/translation/representation/entry/ClassEntry.java b/enigma/src/main/java/org/quiltmc/enigma/api/translation/representation/entry/ClassEntry.java index 44da2495d..a7b36ed7b 100644 --- a/enigma/src/main/java/org/quiltmc/enigma/api/translation/representation/entry/ClassEntry.java +++ b/enigma/src/main/java/org/quiltmc/enigma/api/translation/representation/entry/ClassEntry.java @@ -64,7 +64,7 @@ public String getSourceRemapName() { @Override public String getContextualName() { if (this.isInnerClass()) { - return this.parent.getSimpleName() + "$" + this.name; + return this.parent.getContextualName() + "$" + this.name; } return this.getSimpleName(); diff --git a/enigma/src/main/java/org/quiltmc/enigma/impl/source/vineflower/EnigmaTextTokenCollector.java b/enigma/src/main/java/org/quiltmc/enigma/impl/source/vineflower/EnigmaTextTokenCollector.java index a24e818d7..e5b74ccc3 100644 --- a/enigma/src/main/java/org/quiltmc/enigma/impl/source/vineflower/EnigmaTextTokenCollector.java +++ b/enigma/src/main/java/org/quiltmc/enigma/impl/source/vineflower/EnigmaTextTokenCollector.java @@ -1,5 +1,35 @@ package org.quiltmc.enigma.impl.source.vineflower; +import com.github.javaparser.JavaParser; +import com.github.javaparser.ParseResult; +import com.github.javaparser.ParserConfiguration; +import com.github.javaparser.Range; +import com.github.javaparser.ast.CompilationUnit; +import com.github.javaparser.ast.Node; +import com.github.javaparser.ast.body.ConstructorDeclaration; +import com.github.javaparser.ast.body.FieldDeclaration; +import com.github.javaparser.ast.body.InitializerDeclaration; +import com.github.javaparser.ast.body.MethodDeclaration; +import com.github.javaparser.ast.body.TypeDeclaration; +import com.github.javaparser.ast.expr.Expression; +import com.github.javaparser.ast.expr.LambdaExpr; +import com.github.javaparser.ast.expr.MethodReferenceExpr; +import org.jetbrains.java.decompiler.code.CodeConstants; +import org.jetbrains.java.decompiler.code.Instruction; +import org.jetbrains.java.decompiler.code.InstructionSequence; +import org.jetbrains.java.decompiler.main.DecompilerContext; +import org.jetbrains.java.decompiler.main.extern.TextTokenVisitor; +import org.jetbrains.java.decompiler.struct.StructClass; +import org.jetbrains.java.decompiler.struct.StructMethod; +import org.jetbrains.java.decompiler.struct.attr.StructBootstrapMethodsAttribute; +import org.jetbrains.java.decompiler.struct.attr.StructGeneralAttribute; +import org.jetbrains.java.decompiler.struct.consts.ConstantPool; +import org.jetbrains.java.decompiler.struct.consts.LinkConstant; +import org.jetbrains.java.decompiler.struct.consts.PooledConstant; +import org.jetbrains.java.decompiler.struct.gen.FieldDescriptor; +import org.jetbrains.java.decompiler.struct.gen.MethodDescriptor; +import org.jetbrains.java.decompiler.util.Pair; +import org.jetbrains.java.decompiler.util.token.TextRange; import org.quiltmc.enigma.api.source.SourceIndex; import org.quiltmc.enigma.api.source.Token; import org.quiltmc.enigma.api.translation.representation.entry.ClassEntry; @@ -7,24 +37,34 @@ import org.quiltmc.enigma.api.translation.representation.entry.FieldEntry; import org.quiltmc.enigma.api.translation.representation.entry.LocalVariableEntry; import org.quiltmc.enigma.api.translation.representation.entry.MethodEntry; -import org.jetbrains.java.decompiler.main.extern.TextTokenVisitor; -import org.jetbrains.java.decompiler.struct.gen.FieldDescriptor; -import org.jetbrains.java.decompiler.struct.gen.MethodDescriptor; -import org.jetbrains.java.decompiler.util.Pair; -import org.jetbrains.java.decompiler.util.token.TextRange; +import org.quiltmc.enigma.util.LineIndexer; +import org.tinylog.Logger; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Deque; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.UnaryOperator; public class EnigmaTextTokenCollector extends TextTokenVisitor { private String content; - private MethodEntry currentMethod; + private LineIndexer lineIndexer; + private final Deque classStack = new ArrayDeque<>(); + private final Deque methodStack = new ArrayDeque<>(); private final Map> declarations = new HashMap<>(); private final Map, Entry>> references = new HashMap<>(); private final Map tokens = new LinkedHashMap<>(); + private final Map classRanges = new HashMap<>(); + private final List syntheticMethods = new ArrayList<>(); + private final Deque openSynthetic = new ArrayDeque<>(); + private final Map syntheticEntryBySpan = new HashMap<>(); public EnigmaTextTokenCollector(TextTokenVisitor next) { super(next); @@ -80,21 +120,320 @@ public void addTokensToIndex(SourceIndex index, UnaryOperator tokenProces } } + private void parseSource() { + ParserConfiguration config = new ParserConfiguration() + .setStoreTokens(true) + .setLanguageLevel(ParserConfiguration.LanguageLevel.RAW); + + ParseResult parseResult = new JavaParser(config).parse(this.content); + if (!parseResult.isSuccessful()) { + Logger.warn("Failed to parse source: {}", parseResult.getProblems()); + return; + } + + CompilationUnit unit = parseResult.getResult().get(); + List initializers = unit.findAll(InitializerDeclaration.class, InitializerDeclaration::isStatic); + for (InitializerDeclaration decl : initializers) { + TextRange range = this.getTextRangeForNode(decl); + if (range == null) { + continue; + } + + this.syntheticMethods.add(new SyntheticMethodSpan(range, false)); + } + + for (FieldDeclaration decl : unit.findAll(FieldDeclaration.class, FieldDeclaration::isStatic)) { + TextRange range = this.getTextRangeForNode(decl); + if (range == null) { + continue; + } + + this.syntheticMethods.add(new SyntheticMethodSpan(range, false)); + } + + String pkgPrefix = unit.getPackageDeclaration().map(decl -> decl.getNameAsString().replace('.', '/') + "/").orElse(""); + for (TypeDeclaration decl : unit.getTypes()) { + this.addClassAndChildren(decl, pkgPrefix + decl.getNameAsString()); + } + + for (ClassEntry classEntry : this.classRanges.keySet()) { + String[] parts = classEntry.getContextualName().split("\\$"); + TypeDeclaration type = null; + for (TypeDeclaration decl : unit.getTypes()) { + if (decl.getNameAsString().equals(parts[0])) { + type = decl; + break; + } + } + + for (int i = 1; i < parts.length; i++) { + if (type != null) { + TypeDeclaration finalType = type; + String name = parts[i]; + type = type.findFirst(TypeDeclaration.class, t -> t != finalType && t.getNameAsString().equals(name)).orElse(null); + } + } + + if (type == null) { + throw new IllegalStateException("Could not find type " + classEntry.getContextualName() + " in parsed source"); + } + + Map rootNodes = new HashMap<>(); + for (var member : type.getMembers()) { + if (member instanceof TypeDeclaration) { + continue; + } + + if (member instanceof ConstructorDeclaration constructor) { + LambdaNode rootNode = rootNodes.computeIfAbsent("", c -> new LambdaNode()); + this.findLambdasInSource(constructor.getBody(), rootNode); + } else if (member instanceof MethodDeclaration method) { + if (method.getBody().isPresent()) { + LambdaNode rootNode = rootNodes.computeIfAbsent(method.getNameAsString(), c -> new LambdaNode()); + this.findLambdasInSource(method.getBody().get(), rootNode); + } + } else { + LambdaNode rootNode = rootNodes.computeIfAbsent("", c -> new LambdaNode()); + for (LambdaExpr lambda : member.findAll(LambdaExpr.class)) { + if (lambda.findAncestor(LambdaExpr.class).isEmpty()) { + LambdaNode lambdaNode = new LambdaNode(lambda, false); + rootNode.children.add(lambdaNode); + this.findLambdasInSource(lambda, rootNode); + } + } + } + } + + StructClass clazz = DecompilerContext.getStructContext().getClass(classEntry.getFullName()); + if (clazz == null) { + throw new IllegalStateException("Class bytecode not found"); + } + + for (StructMethod method : clazz.getMethods()) { + LambdaNode rootNode = rootNodes.get(method.getName()); + if (rootNode == null) { + continue; + } + + this.pairContext(clazz, method, rootNode); + } + } + } + + private void pairContext(StructClass owner, StructMethod method, LambdaNode rootNode) { + List bytecodeLambdas = extractLambdasFromBytecode(owner, method); + int count = Math.min(bytecodeLambdas.size(), rootNode.children.size()); + for (int i = 0; i < count; i++) { + LambdaNode childNode = rootNode.children.get(i); + if (childNode.isMethodReference) { + continue; + } + + MethodEntry entry = bytecodeLambdas.get(i); + if (childNode.range == null) { + continue; + } + + SyntheticMethodSpan span = new SyntheticMethodSpan(childNode.range, true); + this.syntheticMethods.add(span); + this.syntheticEntryBySpan.put(span, entry); + StructClass entryClass = DecompilerContext.getStructContext().getClass(entry.getParent().getFullName()); + if (entryClass == null) { + continue; + } + + StructMethod entryMethod = entryClass.getMethod(entry.getName(), entry.getDesc().toString()); + this.pairContext(entryClass, entryMethod, childNode); + } + } + + private void findLambdasInSource(Node method, LambdaNode parentNode) { + for (var member : method.getChildNodes()) { + if (member instanceof LambdaExpr lambda) { + LambdaNode lambdaNode = new LambdaNode(lambda, false); + parentNode.children.add(lambdaNode); + this.findLambdasInSource(lambda, lambdaNode); + } else if (member instanceof MethodReferenceExpr methodRef) { + LambdaNode lambdaNode = new LambdaNode(methodRef, true); + parentNode.children.add(lambdaNode); + } else { + this.findLambdasInSource(member, parentNode); + } + } + } + + private static List extractLambdasFromBytecode(StructClass clazz, StructMethod method) { + List lambdas = new ArrayList<>(); + ConstantPool pool = clazz.getPool(); + + StructBootstrapMethodsAttribute bootstrapAttr = clazz.getAttribute(StructGeneralAttribute.ATTRIBUTE_BOOTSTRAP_METHODS); + + if (bootstrapAttr == null) { + return lambdas; + } + + if (!method.containsCode()) { + return lambdas; + } + + try { + method.expandData(clazz); + } catch (IOException e) { + return lambdas; + } + + InstructionSequence seq = method.getInstructionSequence(); + if (seq == null) { + return lambdas; + } + + for (int i = 0; i < seq.length(); i++) { + Instruction instr = seq.getInstr(i); + if (instr.opcode != CodeConstants.opc_invokedynamic) { + continue; + } + + int indyIndex = instr.operand(0); + PooledConstant constant = pool.getConstant(indyIndex); + if (!(constant instanceof LinkConstant link) || link.type != LinkConstant.CONSTANT_InvokeDynamic) { + continue; + } + + int bsmIndex = link.index1; + LinkConstant bootstrapMethod = bootstrapAttr.getMethodReference(bsmIndex); + String methodOwner = bootstrapMethod.classname; + String methodName = bootstrapMethod.elementname; + boolean isLambda = "java/lang/invoke/LambdaMetafactory".equals(methodOwner) + && ("metafactory".equals(methodName) || "altMetafactory".equals(methodName)); + if (!isLambda) { + continue; + } + + List args = bootstrapAttr.getMethodArguments(bsmIndex); + + if (args.size() < 3) { + continue; + } + + PooledConstant implConstant = args.get(1); + if (!(implConstant instanceof LinkConstant implMethod)) { + continue; + } + + String owner = implMethod.classname; + String name = implMethod.elementname; + String descriptor = implMethod.descriptor; + + lambdas.add(getMethodEntry(owner, name, MethodDescriptor.parseDescriptor(descriptor))); + } + + method.releaseResources(); + + return lambdas; + } + + private void addClassAndChildren(TypeDeclaration decl, String name) { + TextRange textRange = this.getTextRangeForNode(decl); + if (textRange == null) { + return; + } + + this.classRanges.put(getClassEntry(name), textRange); + decl.getMembers().forEach(member -> { + if (member instanceof TypeDeclaration child) { + this.addClassAndChildren(child, name + "$" + child.getNameAsString()); + } + }); + } + + private TextRange getTextRangeForNode(Node node) { + Optional rangeOpt = node.getRange(); + if (rangeOpt.isEmpty()) { + Logger.error("No range for node of type {}", node.getClass().getSimpleName()); + return null; + } + + Range range = rangeOpt.get(); + int start = this.lineIndexer.getIndex(range.begin); + int end = this.lineIndexer.getIndex(range.end); + return new TextRange(start, end - start); + } + + private void updateMethodStack(TextRange range) { + while (!this.openSynthetic.isEmpty() && !encloses(this.openSynthetic.peek(), range)) { + SyntheticMethodSpan span = this.openSynthetic.pop(); + this.syntheticEntryBySpan.remove(span); + this.methodStack.pop(); + } + + List enclosing = this.syntheticMethods.stream() + .filter(span -> encloses(span, range)) + .sorted(Comparator.comparingInt(span -> span.range.length).reversed()) + .toList(); + + for (SyntheticMethodSpan method : enclosing) { + if (!this.openSynthetic.contains(method)) { + MethodEntry entry = this.syntheticEntryBySpan.computeIfAbsent(method, this::getSyntheticMethodEntry); + if (this.methodStack.isEmpty() || !this.methodStack.peek().equals(entry)) { + this.methodStack.push(entry); + this.openSynthetic.push(method); + } + } + } + } + + private void pruneExitedClasses(TextRange range) { + if (this.classRanges.isEmpty()) { + return; // Parsing failed + } + + while (!this.classStack.isEmpty() + && (!this.classRanges.containsKey(this.classStack.peek()) + || this.classRanges.get(this.classStack.peek()).getEnd() < range.start)) { + this.classStack.pop(); + } + } + + private static boolean encloses(SyntheticMethodSpan outer, TextRange inner) { + return outer.range.start <= inner.start && outer.range.getEnd() >= inner.getEnd(); + } + + private MethodEntry getSyntheticMethodEntry(SyntheticMethodSpan method) { + if (method.isLambda) { + throw new IllegalStateException("Method entries for lambdas should have already been fetched"); + } else { + if (this.classStack.isEmpty()) { + throw new IllegalStateException("No class on the stack for synthetic method at " + method.range); + } + + return getMethodEntry(this.classStack.peek().getFullName(), "", MethodDescriptor.parseDescriptor("()V")); + } + } + @Override public void start(String content) { this.content = content; - this.currentMethod = null; + this.lineIndexer = new LineIndexer(content); + this.classRanges.clear(); + this.methodStack.clear(); + this.openSynthetic.clear(); + this.syntheticMethods.clear(); + this.syntheticEntryBySpan.clear(); + this.parseSource(); } @Override public void visitClass(TextRange range, boolean declaration, String name) { super.visitClass(range, declaration, name); Token token = this.getToken(range); + this.pruneExitedClasses(range); + this.updateMethodStack(range); if (declaration) { + this.classStack.push(getClassEntry(name)); this.addDeclaration(token, getClassEntry(name)); } else { - this.addReference(token, getClassEntry(name), this.currentMethod); + this.addReference(token, getClassEntry(name), this.methodStack.peek()); } } @@ -102,11 +441,13 @@ public void visitClass(TextRange range, boolean declaration, String name) { public void visitField(TextRange range, boolean declaration, String className, String name, FieldDescriptor descriptor) { super.visitField(range, declaration, className, name, descriptor); Token token = this.getToken(range); + this.pruneExitedClasses(range); + this.updateMethodStack(range); if (declaration) { this.addDeclaration(token, getFieldEntry(className, name, descriptor)); } else { - this.addReference(token, getFieldEntry(className, name, descriptor), this.currentMethod); + this.addReference(token, getFieldEntry(className, name, descriptor), this.methodStack.peek()); } } @@ -114,13 +455,20 @@ public void visitField(TextRange range, boolean declaration, String className, S public void visitMethod(TextRange range, boolean declaration, String className, String name, MethodDescriptor descriptor) { super.visitMethod(range, declaration, className, name, descriptor); Token token = this.getToken(range); + this.pruneExitedClasses(range); + this.updateMethodStack(range); MethodEntry entry = getMethodEntry(className, name, descriptor); if (declaration) { this.addDeclaration(token, entry); - this.currentMethod = entry; + if (!this.methodStack.isEmpty()) { + this.methodStack.pop(); + } + + this.methodStack.push(entry); } else { - this.addReference(token, entry, this.currentMethod); + MethodEntry context = !this.methodStack.isEmpty() ? this.methodStack.peek() : getMethodEntry(className, "", MethodDescriptor.parseDescriptor("()V")); + this.addReference(token, entry, context); } } @@ -128,12 +476,14 @@ public void visitMethod(TextRange range, boolean declaration, String className, public void visitParameter(TextRange range, boolean declaration, String className, String methodName, MethodDescriptor methodDescriptor, int idx, String name) { super.visitParameter(range, declaration, className, methodName, methodDescriptor, idx, name); Token token = this.getToken(range); + this.pruneExitedClasses(range); + this.updateMethodStack(range); MethodEntry parent = getMethodEntry(className, methodName, methodDescriptor); if (declaration) { this.addDeclaration(token, getParameterEntry(parent, idx, name)); } else { - this.addReference(token, getParameterEntry(parent, idx, name), this.currentMethod); + this.addReference(token, getParameterEntry(parent, idx, name), this.methodStack.peek()); } } @@ -141,12 +491,31 @@ public void visitParameter(TextRange range, boolean declaration, String classNam public void visitLocal(TextRange range, boolean declaration, String className, String methodName, MethodDescriptor methodDescriptor, int idx, String name) { super.visitLocal(range, declaration, className, methodName, methodDescriptor, idx, name); Token token = this.getToken(range); + this.pruneExitedClasses(range); + this.updateMethodStack(range); MethodEntry parent = getMethodEntry(className, methodName, methodDescriptor); if (declaration) { this.addDeclaration(token, getVariableEntry(parent, idx, name)); } else { - this.addReference(token, getVariableEntry(parent, idx, name), this.currentMethod); + this.addReference(token, getVariableEntry(parent, idx, name), this.methodStack.peek()); + } + } + + private record SyntheticMethodSpan(TextRange range, boolean isLambda) {} + + class LambdaNode { + final TextRange range; + final boolean isMethodReference; + final List children = new ArrayList<>(); + LambdaNode(Expression lambda, boolean isMethodReference) { + this.range = EnigmaTextTokenCollector.this.getTextRangeForNode(lambda); + this.isMethodReference = isMethodReference; + } + + LambdaNode() { + this.range = null; + this.isMethodReference = false; } } } diff --git a/enigma/src/main/java/org/quiltmc/enigma/util/LineIndexer.java b/enigma/src/main/java/org/quiltmc/enigma/util/LineIndexer.java new file mode 100644 index 000000000..e121895b2 --- /dev/null +++ b/enigma/src/main/java/org/quiltmc/enigma/util/LineIndexer.java @@ -0,0 +1,34 @@ +package org.quiltmc.enigma.util; + +import com.github.javaparser.Position; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class LineIndexer { + private static final Pattern LINE_END = Pattern.compile("\\r\\n?|\\n"); + + private final List indexesByLine = new ArrayList<>(); + private final Matcher lineEndMatcher; + + public LineIndexer(String string) { + // the first line always starts at 0 + this.indexesByLine.add(0); + this.lineEndMatcher = LINE_END.matcher(string); + } + + public int getStartIndex(int line) { + while (line >= this.indexesByLine.size() && this.lineEndMatcher.find()) { + this.indexesByLine.add(this.lineEndMatcher.end()); + } + + return line < this.indexesByLine.size() ? this.indexesByLine.get(line) : -1; + } + + public int getIndex(Position position) { + final int lineIndex = this.getStartIndex(position.line - Position.FIRST_LINE); + return lineIndex < 0 ? lineIndex : lineIndex + position.column - Position.FIRST_COLUMN; + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ac23c5adf..4e4509d24 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -11,6 +11,7 @@ swing_dpi = "0.10" fontchooser = "2.5.2" tinylog = "2.6.2" quilt_config = "1.3.2" +javaparser = "3.27.0" vineflower = "1.11.0" cfr = "0.2.2" @@ -42,6 +43,7 @@ flatlaf_extras = { module = "com.formdev:flatlaf-extras", version.ref = "flatlaf syntaxpain = { module = "org.quiltmc:syntaxpain", version.ref = "syntaxpain" } swing_dpi = { module = "com.github.lukeu:swing-dpi", version.ref = "swing_dpi" } fontchooser = { module = "org.drjekyll:fontchooser", version.ref = "fontchooser" } +javaparser = { module = "com.github.javaparser:javaparser-core", version.ref = "javaparser" } vineflower = { module = "org.vineflower:vineflower", version.ref = "vineflower" } cfr = { module = "net.fabricmc:cfr", version.ref = "cfr" }