Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,35 @@ private static Class<?> getBigArrayType(Class<?> type)
return ObjectBigArray.class;
}

private static Class<?> bigArrayElementType(Class<?> bigArrayType)
{
if (bigArrayType.equals(LongBigArray.class)) {
return long.class;
}
if (bigArrayType.equals(ByteBigArray.class)) {
return byte.class;
}
if (bigArrayType.equals(DoubleBigArray.class)) {
return double.class;
}
if (bigArrayType.equals(BooleanBigArray.class)) {
return boolean.class;
}
if (bigArrayType.equals(IntBigArray.class)) {
return int.class;
}
if (bigArrayType.equals(SliceBigArray.class)) {
return Slice.class;
}
if (bigArrayType.equals(BlockBigArray.class)) {
return Block.class;
}
if (bigArrayType.equals(ObjectBigArray.class)) {
return Object.class;
}
throw new IllegalArgumentException("Unsupported bigArrayType: " + bigArrayType.getName());
}

public static <T extends AccumulatorState> AccumulatorStateSerializer<T> generateStateSerializer(Class<T> clazz)
{
return generateStateSerializer(clazz, ImmutableMap.of());
Expand Down Expand Up @@ -453,24 +482,26 @@ private static ClassDefinition generateInOutGroupedStateClass(Type type, CallSit
type(GroupedAccumulatorState.class),
type(InternalDataAccessor.class));

estimatedSize(definition);

MethodDefinition constructor = definition.declareConstructor(a(PUBLIC));
constructor.getBody()
.append(constructor.getThis())
.invokeConstructor(Object.class);

ImmutableList.Builder<FieldDefinition> fieldDefinitions = ImmutableList.builder();
FieldDefinition groupIdField = definition.declareField(a(PRIVATE), "groupId", long.class);

Class<?> valueElementType = inOutGetterReturnType(type);
FieldDefinition valueField = definition.declareField(a(PRIVATE, FINAL), "value", getBigArrayType(valueElementType));
Class<?> bigArrayType = getBigArrayType(type.getJavaType());
Class<?> valueElementType = bigArrayElementType(bigArrayType);
FieldDefinition valueField = definition.declareField(a(PRIVATE, FINAL), "value", bigArrayType);
fieldDefinitions.add(valueField);
constructor.getBody().append(constructor.getThis().setField(valueField, newInstance(valueField.getType())));
Function<Scope, BytecodeExpression> valueGetter = scope -> scope.getThis().getField(valueField).invoke("get", valueElementType, scope.getThis().getField(groupIdField));

Optional<FieldDefinition> nullField;
Function<Scope, BytecodeExpression> nullGetter;
if (type.getJavaType().isPrimitive()) {
nullField = Optional.of(definition.declareField(a(PRIVATE, FINAL), "valueIdNull", BooleanBigArray.class));
FieldDefinition valueIdNullDefinition = definition.declareField(a(PRIVATE, FINAL), "valueIdNull", BooleanBigArray.class);
nullField = Optional.of(valueIdNullDefinition);
fieldDefinitions.add(valueIdNullDefinition);
constructor.getBody().append(constructor.getThis().setField(nullField.get(), newInstance(BooleanBigArray.class, constantTrue())));
nullGetter = scope -> scope.getThis().getField(nullField.get()).invoke("get", boolean.class, scope.getThis().getField(groupIdField));
}
Expand All @@ -482,6 +513,9 @@ private static ClassDefinition generateInOutGroupedStateClass(Type type, CallSit
constructor.getBody()
.ret();

// Generate getEstimatedSize
estimatedSize(definition, fieldDefinitions.build());

inOutGroupedSetGroupId(definition, groupIdField);
inOutGroupedEnsureCapacity(definition, valueField, nullField);
inOutGroupedCopy(definition, valueField, nullField);
Expand Down Expand Up @@ -535,6 +569,23 @@ private static void estimatedSize(ClassDefinition definition)
.retLong();
}

private static void estimatedSize(ClassDefinition definition, List<FieldDefinition> fieldDefinitions)
{
FieldDefinition instanceSize = generateInstanceSize(definition);

MethodDefinition getEstimatedSize = definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class));
BytecodeBlock body = getEstimatedSize.getBody();
Variable size = getEstimatedSize.getScope().declareVariable("size", body, getStatic(instanceSize));

// add field to size
for (FieldDefinition field : fieldDefinitions) {
body.append(size.set(add(size, getEstimatedSize.getThis().getField(field).invoke("sizeOf", long.class))));
}

// return size
body.append(size.ret());
}

private static void inOutSingleCopy(ClassDefinition definition, FieldDefinition valueField, Optional<FieldDefinition> nullField)
{
MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(AccumulatorState.class));
Expand Down Expand Up @@ -870,8 +921,6 @@ private static <T> Class<? extends T> generateGroupedStateClass(Class<T> clazz,
type(AbstractGroupedAccumulatorState.class),
type(clazz));

FieldDefinition instanceSize = generateInstanceSize(definition);

List<StateField> fields = enumerateFields(clazz, fieldTypes);

// Create constructor
Expand All @@ -893,21 +942,7 @@ private static <T> Class<? extends T> generateGroupedStateClass(Class<T> clazz,
ensureCapacity.getBody().ret();

// Generate getEstimatedSize
MethodDefinition getEstimatedSize = definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class));
BytecodeBlock body = getEstimatedSize.getBody();

Variable size = getEstimatedSize.getScope().declareVariable(long.class, "size");

// initialize size to the size of the instance
body.append(size.set(getStatic(instanceSize)));

// add field to size
for (FieldDefinition field : fieldDefinitions) {
body.append(size.set(add(size, getEstimatedSize.getThis().getField(field).invoke("sizeOf", long.class))));
}

// return size
body.append(size.ret());
estimatedSize(definition, fieldDefinitions);

return defineClass(definition, clazz, classLoader);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.GroupedAccumulatorState;
import io.trino.spi.function.InOut;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
Expand Down Expand Up @@ -238,6 +239,43 @@ public void testComplexSerialization()
assertEquals(deserializedState.getAnotherBlock().getSlice(1, 0, 9), singleState.getAnotherBlock().getSlice(1, 0, 9));
}

@Test
public void testEstimatedInOutStatesInstanceSizes()
{
AccumulatorStateFactory<InOut> factory = StateCompiler.generateInOutStateFactory(BIGINT);
InOut groupedState = factory.createGroupedState();
InOut singleState = factory.createSingleState();

long expectedGroupedSize =
instanceSize(groupedState.getClass()) +
new LongBigArray().sizeOf() + // values, 1024 longs
new BooleanBigArray().sizeOf(); // isNull, 1024 booleans

assertEquals(groupedState.getEstimatedSize(), expectedGroupedSize);
assertEquals(groupedState.getEstimatedSize(), 17576);

assertEquals(singleState.getEstimatedSize(), instanceSize(singleState.getClass()));
assertEquals(singleState.getEstimatedSize(), 24);
}

@Test
public void testEstimatedStateInstanceSizes()
{
AccumulatorStateFactory<TestSimpleState> stateFactory = StateCompiler.generateStateFactory(TestSimpleState.class);
TestSimpleState groupedState = stateFactory.createGroupedState();
TestSimpleState singleState = stateFactory.createSingleState();

long expectedGroupedSize = instanceSize(groupedState.getClass()) +
new LongBigArray().sizeOf() +
new DoubleBigArray().sizeOf();

assertEquals(groupedState.getEstimatedSize(), expectedGroupedSize);
assertEquals(groupedState.getEstimatedSize(), 24752);

assertEquals(singleState.getEstimatedSize(), instanceSize(singleState.getClass()));
assertEquals(singleState.getEstimatedSize(), 32);
}

private static long getComplexStateRetainedSize(TestComplexState state)
{
long retainedSize = instanceSize(state.getClass());
Expand Down Expand Up @@ -359,6 +397,18 @@ public void testComplexStateEstimatedSize()
}
}

public interface TestSimpleState
extends AccumulatorState
{
long getLong();

void setLong(long value);

double getDouble();

void setDouble(double value);
}

public interface TestComplexState
extends AccumulatorState
{
Expand Down