diff --git a/compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/truffle/test/ByteArraySupportPartialEvaluationTest.java b/compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/truffle/test/ByteArraySupportPartialEvaluationTest.java index 3873a69bd3f0..2543f87a7d3a 100644 --- a/compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/truffle/test/ByteArraySupportPartialEvaluationTest.java +++ b/compiler/src/jdk.graal.compiler.test/src/jdk/graal/compiler/truffle/test/ByteArraySupportPartialEvaluationTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -69,6 +69,21 @@ public int execute(VirtualFrame frame) { } } + static class GetShortUnalignedNonStableNode extends AbstractTestNode { + @CompilationFinal(dimensions = 0) byte[] bytes; + final int offset; + + GetShortUnalignedNonStableNode(String hex, int offset) { + this.bytes = hexToBytes(hex); + this.offset = offset; + } + + @Override + public int execute(VirtualFrame frame) { + return BYTES.getShortUnaligned(bytes, offset); + } + } + static class GetIntNode extends AbstractTestNode { @CompilationFinal(dimensions = 1) byte[] bytes; final int offset; @@ -99,6 +114,21 @@ public int execute(VirtualFrame frame) { } } + static class GetIntUnalignedNonStableNode extends AbstractTestNode { + @CompilationFinal(dimensions = 0) byte[] bytes; + final int offset; + + GetIntUnalignedNonStableNode(String hex, int offset) { + this.bytes = hexToBytes(hex); + this.offset = offset; + } + + @Override + public int execute(VirtualFrame frame) { + return BYTES.getIntUnaligned(bytes, offset); + } + } + static class GetLongNode extends LongNode { @CompilationFinal(dimensions = 1) byte[] bytes; final int offset; @@ -129,6 +159,21 @@ public long execute(VirtualFrame frame) { } } + static class GetLongUnalignedNonStableNode extends LongNode { + @CompilationFinal(dimensions = 0) byte[] bytes; + final int offset; + + GetLongUnalignedNonStableNode(String hex, int offset) { + this.bytes = hexToBytes(hex); + this.offset = offset; + } + + @Override + public long execute(VirtualFrame frame) { + return BYTES.getLongUnaligned(bytes, offset); + } + } + private static byte[] hexToBytes(String s) { int len = s.length(); byte[] data = new byte[len / 2]; @@ -222,4 +267,44 @@ public void testGetLongUnaligned() { assertPartialEvalEquals(constLongRootNode(0x1122334455667788L), new LongRootNode(new GetLongUnalignedNode("00000000000000008877665544332211", 8))); assertPartialEvalEquals(constLongRootNode(0x1122334455667788L), new LongRootNode(new GetLongUnalignedNode("008877665544332211", 1))); } + + @Test + public void testGetUnalignedFromNonStableArray() { + assertPartialEvalEquals(readShortRootNode(), new RootTestNode("getShortUnaligned", new GetShortUnalignedNonStableNode("0089abcdef", 1))); + assertPartialEvalEquals(readIntRootNode(), new RootTestNode("getIntUnaligned", new GetIntUnalignedNonStableNode("000089abcdef", 2))); + assertPartialEvalEquals(readLongRootNode(), new LongRootNode(new GetLongUnalignedNonStableNode("008877665544332211", 1))); + } + + private static RootTestNode readShortRootNode() { + return new RootTestNode("readShort", new AbstractTestNode() { + private final byte[] bytes = new byte[8]; + + @Override + public int execute(VirtualFrame frame) { + return BYTES.getShort(bytes, 0); + } + }); + } + + private static RootTestNode readIntRootNode() { + return new RootTestNode("readInt", new AbstractTestNode() { + private final byte[] bytes = new byte[8]; + + @Override + public int execute(VirtualFrame frame) { + return BYTES.getInt(bytes, 0); + } + }); + } + + private static LongRootNode readLongRootNode() { + return new LongRootNode(new LongNode() { + private final byte[] bytes = new byte[8]; + + @Override + public long execute(VirtualFrame frame) { + return BYTES.getLong(bytes, 0); + } + }); + } } diff --git a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/nodes/ObjectLocationIdentity.java b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/nodes/ObjectLocationIdentity.java index c1c9c25571d3..b5d81b80d3a1 100644 --- a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/nodes/ObjectLocationIdentity.java +++ b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/nodes/ObjectLocationIdentity.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -35,6 +35,13 @@ /** * A {@link LocationIdentity} wrapping an object. + * + * Used by Truffle unsafe accesses to associate unique location identities with DynamicObject and + * FrameWithoutBoxing accesses, backed by non-null object references (compared by identity). The + * compiler can assume that accesses with different location identities (except for "any") do not + * interfere with each other (or when they do are constrained by memory barriers), even when they + * may access the same relative memory address (array or field offset) of objects of the same array + * or instance class. */ public final class ObjectLocationIdentity extends LocationIdentity implements JavaConstantFormattable { diff --git a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/substitutions/TruffleGraphBuilderPlugins.java b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/substitutions/TruffleGraphBuilderPlugins.java index 8b8771469608..48690aa498ad 100644 --- a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/substitutions/TruffleGraphBuilderPlugins.java +++ b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/truffle/substitutions/TruffleGraphBuilderPlugins.java @@ -28,6 +28,7 @@ import static jdk.graal.compiler.replacements.PEGraphDecoder.Options.MaximumLoopExplosionCount; import java.lang.reflect.Type; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; @@ -40,14 +41,17 @@ import com.oracle.truffle.compiler.TruffleCompilationTask; +import jdk.graal.compiler.core.common.NumUtil; import jdk.graal.compiler.core.common.calc.CanonicalCondition; import jdk.graal.compiler.core.common.memory.MemoryOrderMode; +import jdk.graal.compiler.core.common.type.IntegerStamp; import jdk.graal.compiler.core.common.type.ObjectStamp; import jdk.graal.compiler.core.common.type.Stamp; import jdk.graal.compiler.core.common.type.StampFactory; import jdk.graal.compiler.core.common.type.StampPair; import jdk.graal.compiler.core.common.type.TypeReference; import jdk.graal.compiler.debug.DebugContext; +import jdk.graal.compiler.debug.GraalError; import jdk.graal.compiler.graph.Node; import jdk.graal.compiler.lir.gen.ArithmeticLIRGeneratorTool.RoundingMode; import jdk.graal.compiler.nodes.CallTargetNode; @@ -61,6 +65,7 @@ import jdk.graal.compiler.nodes.InvokeNode; import jdk.graal.compiler.nodes.LogicConstantNode; import jdk.graal.compiler.nodes.LogicNode; +import jdk.graal.compiler.nodes.NamedLocationIdentity; import jdk.graal.compiler.nodes.NodeView; import jdk.graal.compiler.nodes.PiArrayNode; import jdk.graal.compiler.nodes.PiNode; @@ -133,6 +138,7 @@ import jdk.vm.ci.meta.DeoptimizationReason; import jdk.vm.ci.meta.JavaConstant; import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.MemoryAccessProvider; import jdk.vm.ci.meta.MetaAccessProvider; import jdk.vm.ci.meta.ResolvedJavaField; import jdk.vm.ci.meta.ResolvedJavaMethod; @@ -172,6 +178,7 @@ public static void registerInvocationPlugins(InvocationPlugins plugins, KnownTru registerDynamicObjectPlugins(plugins, types, canDelayIntrinsification, providers.getConstantReflection()); registerBufferPlugins(plugins, types, canDelayIntrinsification); registerMemorySegmentPlugins(plugins, types, canDelayIntrinsification); + registerByteArraySupportPlugins(plugins, canDelayIntrinsification); } private static void registerTruffleSafepointPlugins(InvocationPlugins plugins, KnownTruffleTypes types, boolean canDelayIntrinsification) { @@ -1332,18 +1339,17 @@ static class CustomizedUnsafeStorePlugin extends RequiredInvocationPlugin { @Override public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode object, ValueNode offset, ValueNode value, ValueNode location) { - ValueNode locationArgument = location; - if (locationArgument.isConstant()) { + if (location.isConstant()) { LocationIdentity locationIdentity; boolean forceLocation; - if (locationArgument.isNullConstant()) { + if (location.isNullConstant()) { locationIdentity = LocationIdentity.any(); forceLocation = false; - } else if (locationArgument.asJavaConstant().equals(anyConstant)) { + } else if (location.asJavaConstant().equals(anyConstant)) { locationIdentity = LocationIdentity.any(); forceLocation = true; } else { - locationIdentity = ObjectLocationIdentity.create(locationArgument.asJavaConstant()); + locationIdentity = ObjectLocationIdentity.create(location.asJavaConstant()); forceLocation = true; } b.add(new RawStoreNode(object, offset, value, kind, locationIdentity, true, null, forceLocation)); @@ -1377,7 +1383,7 @@ static void logPerformanceWarningLocationNotConstant(ValueNode location, Resolve debug.dump(DebugContext.VERBOSE_LEVEL, graph, "perf warn: Location argument is not a partial evaluation constant: %s", location); } } catch (Throwable t) { - debug.handle(t); + throw debug.handle(t); } } } @@ -1411,7 +1417,7 @@ static void logPerformanceWarningUnsafeCastArgNotConst(ResolvedJavaMethod target debug.dump(DebugContext.VERBOSE_LEVEL, graph, "perf warn: unsafeCast arguments could not reduce to a constant: %s, %s, %s", type, nonNull, isExactType); } } catch (Throwable t) { - debug.handle(t); + throw debug.handle(t); } } } @@ -1419,8 +1425,7 @@ static void logPerformanceWarningUnsafeCastArgNotConst(ResolvedJavaMethod target static BailoutException failPEConstant(GraphBuilderContext b, ValueNode value) { StringBuilder sb = new StringBuilder(); sb.append(value); - if (value instanceof ValuePhiNode) { - ValuePhiNode valuePhi = (ValuePhiNode) value; + if (value instanceof ValuePhiNode valuePhi) { sb.append(" ("); for (Node n : valuePhi.inputs()) { sb.append(n); @@ -1443,8 +1448,7 @@ private PEConstantPlugin(boolean canDelayIntrinsification, Type... argumentTypes @Override public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, ValueNode value) { ValueNode curValue = value; - if (curValue instanceof BoxNode) { - BoxNode boxNode = (BoxNode) curValue; + if (curValue instanceof BoxNode boxNode) { curValue = boxNode.getValue(); } if (curValue.isConstant()) { @@ -1457,4 +1461,121 @@ public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Rec } } + + private static void registerByteArraySupportPlugins(InvocationPlugins plugins, boolean canDelayIntrinsification) { + Registration r = new Registration(plugins, "com.oracle.truffle.api.memory.UnsafeByteArraySupport"); + r.register(new UnsafeGetUnalignedPlugin("unsafeGetShortUnaligned", JavaKind.Short, canDelayIntrinsification)); + r.register(new UnsafeGetUnalignedPlugin("unsafeGetIntUnaligned", JavaKind.Int, canDelayIntrinsification)); + r.register(new UnsafeGetUnalignedPlugin("unsafeGetLongUnaligned", JavaKind.Long, canDelayIntrinsification)); + } + + private static class UnsafeGetUnalignedPlugin extends OptionalInvocationPlugin { + private final boolean canDelayIntrinsification; + private final JavaKind resultKind; + + UnsafeGetUnalignedPlugin(String name, JavaKind resultKind, boolean canDelayIntrinsification) { + super(name, byte[].class, long.class); + this.canDelayIntrinsification = canDelayIntrinsification; + this.resultKind = resultKind; + assert resultKind == JavaKind.Short || resultKind == JavaKind.Int || resultKind == JavaKind.Long : resultKind; + GraalError.guarantee(ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN, "only supported on little-endian architecture"); + } + + @Override + public boolean apply(GraphBuilderContext b, ResolvedJavaMethod targetMethod, Receiver receiver, + ValueNode bufferNode, ValueNode byteOffsetNode) { + if (bufferNode instanceof ConstantNode bufferConstNode && byteOffsetNode.isConstant()) { + if (bufferConstNode.getStableDimension() == 1) { // implies non-null + JavaConstant bufferConst = bufferConstNode.asJavaConstant(); + long byteOffset = byteOffsetNode.asJavaConstant().asLong(); + JavaConstant value = readUnaligned(b, resultKind, bufferConst, byteOffset); + if (value != null && (bufferConstNode.isDefaultStable() || !value.isDefaultForKind())) { + b.addPush(resultKind, ConstantNode.forPrimitive(value, b.getGraph())); + return true; + } + } + } else if (canDelayIntrinsification) { + return false; + } + b.addPush(resultKind, new RawLoadNode(bufferNode, byteOffsetNode, resultKind, NamedLocationIdentity.getArrayLocation(JavaKind.Byte), MemoryOrderMode.PLAIN)); + return true; + } + + /** + * Reads a short, int, or long value from a potentially unaligned offset in a byte[] array. + * Performs a single aligned read if the address is aligned, otherwise combines the results + * of multiple reads of the next narrower naturally aligned width or individual bytes. + * + * @param resultKind value kind, either short, int, or long + * @param base byte[] array constant, with stable dimensions = 1 + * @param byteOffset byte[] index, not including array base offset + * @return result value constant or {@code null} if out of bounds + */ + @SuppressWarnings("fallthrough") + private static JavaConstant readUnaligned(GraphBuilderContext b, JavaKind resultKind, JavaConstant base, long byteOffset) { + ConstantReflectionProvider constantReflection = b.getConstantReflection(); + MemoryAccessProvider memoryAccessProvider = constantReflection.getMemoryAccessProvider(); + long displacement = b.getMetaAccess().getArrayBaseOffset(JavaKind.Byte) + byteOffset; + int resultBytes = resultKind.getByteCount(); + if (displacement % resultBytes == 0) { + // Already aligned, so we can read the value directly. + IntegerStamp accessStamp = StampFactory.forInteger(resultKind.getBitCount()); + return (JavaConstant) accessStamp.readConstant(memoryAccessProvider, base, displacement); + } + + // Figure out if we can read the value in wider-than-byte aligned parts. + JavaKind alignedKind = null; + switch (resultKind) { + case Long: + if (displacement % Integer.BYTES == 0) { + alignedKind = JavaKind.Int; + break; + } + // fallthrough + case Int: + if (displacement % Short.BYTES == 0) { + alignedKind = JavaKind.Short; + break; + } + break; + } + if (alignedKind != null) { + long value = 0; + long mask = NumUtil.getNbitNumberLong(alignedKind.getBitCount()); + IntegerStamp accessStamp = StampFactory.forInteger(alignedKind.getBitCount()); + for (int byteCount = 0; byteCount < resultBytes; byteCount += alignedKind.getByteCount()) { + var part = (JavaConstant) accessStamp.readConstant(memoryAccessProvider, base, displacement + byteCount); + if (part == null) { + /* + * Should not normally happen if base+displacement is aligned and in bounds; + * but in the unexpected case that the read fails, handle it gracefully. + */ + return null; + } + value |= ((part.asLong() & mask) << (byteCount * Byte.SIZE)); + } + return JavaConstant.forPrimitive(resultKind, value); + } + + // Displacement is odd, so we have to read the value byte-by-byte. + assert displacement % 2 != 0 : displacement; + long value = 0; + int byteOffsetAsInt = NumUtil.safeToInt(byteOffset); + for (int byteCount = 0; byteCount < resultBytes; byteCount += 2) { + JavaConstant b0 = constantReflection.readArrayElement(base, byteOffsetAsInt + byteCount); + JavaConstant b1 = constantReflection.readArrayElement(base, byteOffsetAsInt + byteCount + 1); + if (b0 == null || b1 == null) { + /* + * Byte offset is out of bounds. This is not necessarily an error since it + * depends on control flow / bounds checks if this read is actually reachable, + * so we must not fail compilation. We can either deoptimize here or fall back + * to a normal unsafe read. + */ + return null; + } + value |= (b0.asInt() & 0xffL | ((b1.asInt() & 0xffL) << Byte.SIZE)) << (byteCount * Byte.SIZE); + } + return JavaConstant.forPrimitive(resultKind, value); + } + } } diff --git a/truffle/src/com.oracle.truffle.api/src/com/oracle/truffle/api/memory/UnsafeByteArraySupport.java b/truffle/src/com.oracle.truffle.api/src/com/oracle/truffle/api/memory/UnsafeByteArraySupport.java index 7d19a4898c21..a489741bcab9 100644 --- a/truffle/src/com.oracle.truffle.api/src/com/oracle/truffle/api/memory/UnsafeByteArraySupport.java +++ b/truffle/src/com.oracle.truffle.api/src/com/oracle/truffle/api/memory/UnsafeByteArraySupport.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * The Universal Permissive License (UPL), Version 1.0 @@ -41,8 +41,6 @@ package com.oracle.truffle.api.memory; -import sun.misc.Unsafe; - import java.lang.reflect.Field; import java.nio.ByteOrder; import java.security.AccessController; @@ -50,18 +48,14 @@ import com.oracle.truffle.api.CompilerDirectives; +import sun.misc.Unsafe; + /** * Implementation of {@link ByteArraySupport} using {@link Unsafe}. *
* Bytes ordering is native endianness ({@link ByteOrder#nativeOrder}).
*/
final class UnsafeByteArraySupport extends ByteArraySupport {
- /**
- * Partial evaluation does not constant-fold unaligned accesses, so in compiled code we
- * decompose unaligned accesses into multiple aligned accesses that can be constant-folded. This
- * optimization is only tested on little-endian platforms.
- */
- private static final boolean OPTIMIZED_UNALIGNED_SUPPORTED = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN;
@SuppressWarnings("deprecation") private static final Unsafe UNSAFE = AccessController.doPrivileged(new PrivilegedAction