diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java index 015305b0a78f..5911d0e4674e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java @@ -141,6 +141,35 @@ private static Class getBigArrayType(Class type) return ObjectBigArray.class; } + private static Class bigArrayElementType(Class bigArrayType) + { + if (bigArrayType.equals(LongBigArray.class)) { + return long.class; + } + if (bigArrayType.equals(ByteBigArray.class)) { + return byte.class; + } + if (bigArrayType.equals(DoubleBigArray.class)) { + return double.class; + } + if (bigArrayType.equals(BooleanBigArray.class)) { + return boolean.class; + } + if (bigArrayType.equals(IntBigArray.class)) { + return int.class; + } + if (bigArrayType.equals(SliceBigArray.class)) { + return Slice.class; + } + if (bigArrayType.equals(BlockBigArray.class)) { + return Block.class; + } + if (bigArrayType.equals(ObjectBigArray.class)) { + return Object.class; + } + throw new IllegalArgumentException("Unsupported bigArrayType: " + bigArrayType.getName()); + } + public static AccumulatorStateSerializer generateStateSerializer(Class clazz) { return generateStateSerializer(clazz, ImmutableMap.of()); @@ -453,24 +482,26 @@ private static ClassDefinition generateInOutGroupedStateClass(Type type, CallSit type(GroupedAccumulatorState.class), type(InternalDataAccessor.class)); - estimatedSize(definition); - MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); constructor.getBody() .append(constructor.getThis()) .invokeConstructor(Object.class); + ImmutableList.Builder fieldDefinitions = ImmutableList.builder(); FieldDefinition groupIdField = definition.declareField(a(PRIVATE), "groupId", long.class); - - Class valueElementType = inOutGetterReturnType(type); - FieldDefinition valueField = definition.declareField(a(PRIVATE, FINAL), "value", getBigArrayType(valueElementType)); + Class bigArrayType = getBigArrayType(type.getJavaType()); + Class valueElementType = bigArrayElementType(bigArrayType); + FieldDefinition valueField = definition.declareField(a(PRIVATE, FINAL), "value", bigArrayType); + fieldDefinitions.add(valueField); constructor.getBody().append(constructor.getThis().setField(valueField, newInstance(valueField.getType()))); Function valueGetter = scope -> scope.getThis().getField(valueField).invoke("get", valueElementType, scope.getThis().getField(groupIdField)); Optional nullField; Function nullGetter; if (type.getJavaType().isPrimitive()) { - nullField = Optional.of(definition.declareField(a(PRIVATE, FINAL), "valueIdNull", BooleanBigArray.class)); + FieldDefinition valueIdNullDefinition = definition.declareField(a(PRIVATE, FINAL), "valueIdNull", BooleanBigArray.class); + nullField = Optional.of(valueIdNullDefinition); + fieldDefinitions.add(valueIdNullDefinition); constructor.getBody().append(constructor.getThis().setField(nullField.get(), newInstance(BooleanBigArray.class, constantTrue()))); nullGetter = scope -> scope.getThis().getField(nullField.get()).invoke("get", boolean.class, scope.getThis().getField(groupIdField)); } @@ -482,6 +513,9 @@ private static ClassDefinition generateInOutGroupedStateClass(Type type, CallSit constructor.getBody() .ret(); + // Generate getEstimatedSize + estimatedSize(definition, fieldDefinitions.build()); + inOutGroupedSetGroupId(definition, groupIdField); inOutGroupedEnsureCapacity(definition, valueField, nullField); inOutGroupedCopy(definition, valueField, nullField); @@ -535,6 +569,23 @@ private static void estimatedSize(ClassDefinition definition) .retLong(); } + private static void estimatedSize(ClassDefinition definition, List fieldDefinitions) + { + FieldDefinition instanceSize = generateInstanceSize(definition); + + MethodDefinition getEstimatedSize = definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class)); + BytecodeBlock body = getEstimatedSize.getBody(); + Variable size = getEstimatedSize.getScope().declareVariable("size", body, getStatic(instanceSize)); + + // add field to size + for (FieldDefinition field : fieldDefinitions) { + body.append(size.set(add(size, getEstimatedSize.getThis().getField(field).invoke("sizeOf", long.class)))); + } + + // return size + body.append(size.ret()); + } + private static void inOutSingleCopy(ClassDefinition definition, FieldDefinition valueField, Optional nullField) { MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(AccumulatorState.class)); @@ -870,8 +921,6 @@ private static Class generateGroupedStateClass(Class clazz, type(AbstractGroupedAccumulatorState.class), type(clazz)); - FieldDefinition instanceSize = generateInstanceSize(definition); - List fields = enumerateFields(clazz, fieldTypes); // Create constructor @@ -893,21 +942,7 @@ private static Class generateGroupedStateClass(Class clazz, ensureCapacity.getBody().ret(); // Generate getEstimatedSize - MethodDefinition getEstimatedSize = definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class)); - BytecodeBlock body = getEstimatedSize.getBody(); - - Variable size = getEstimatedSize.getScope().declareVariable(long.class, "size"); - - // initialize size to the size of the instance - body.append(size.set(getStatic(instanceSize))); - - // add field to size - for (FieldDefinition field : fieldDefinitions) { - body.append(size.set(add(size, getEstimatedSize.getThis().getField(field).invoke("sizeOf", long.class)))); - } - - // return size - body.append(size.ret()); + estimatedSize(definition, fieldDefinitions); return defineClass(definition, clazz, classLoader); } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java index 11364de7e7d3..1603ab0a058c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestStateCompiler.java @@ -31,6 +31,7 @@ import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.function.GroupedAccumulatorState; +import io.trino.spi.function.InOut; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; @@ -238,6 +239,43 @@ public void testComplexSerialization() assertEquals(deserializedState.getAnotherBlock().getSlice(1, 0, 9), singleState.getAnotherBlock().getSlice(1, 0, 9)); } + @Test + public void testEstimatedInOutStatesInstanceSizes() + { + AccumulatorStateFactory factory = StateCompiler.generateInOutStateFactory(BIGINT); + InOut groupedState = factory.createGroupedState(); + InOut singleState = factory.createSingleState(); + + long expectedGroupedSize = + instanceSize(groupedState.getClass()) + + new LongBigArray().sizeOf() + // values, 1024 longs + new BooleanBigArray().sizeOf(); // isNull, 1024 booleans + + assertEquals(groupedState.getEstimatedSize(), expectedGroupedSize); + assertEquals(groupedState.getEstimatedSize(), 17576); + + assertEquals(singleState.getEstimatedSize(), instanceSize(singleState.getClass())); + assertEquals(singleState.getEstimatedSize(), 24); + } + + @Test + public void testEstimatedStateInstanceSizes() + { + AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(TestSimpleState.class); + TestSimpleState groupedState = stateFactory.createGroupedState(); + TestSimpleState singleState = stateFactory.createSingleState(); + + long expectedGroupedSize = instanceSize(groupedState.getClass()) + + new LongBigArray().sizeOf() + + new DoubleBigArray().sizeOf(); + + assertEquals(groupedState.getEstimatedSize(), expectedGroupedSize); + assertEquals(groupedState.getEstimatedSize(), 24752); + + assertEquals(singleState.getEstimatedSize(), instanceSize(singleState.getClass())); + assertEquals(singleState.getEstimatedSize(), 32); + } + private static long getComplexStateRetainedSize(TestComplexState state) { long retainedSize = instanceSize(state.getClass()); @@ -359,6 +397,18 @@ public void testComplexStateEstimatedSize() } } + public interface TestSimpleState + extends AccumulatorState + { + long getLong(); + + void setLong(long value); + + double getDouble(); + + void setDouble(double value); + } + public interface TestComplexState extends AccumulatorState {