From a9b57d03c9b15146d3a37b14c9ad405a4d86b94a Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Tue, 4 Feb 2025 18:47:09 -0800 Subject: [PATCH 01/17] more types --- .../kernel/internal/util/VectorUtils.java | 139 ++++++++++++++++-- 1 file changed, 127 insertions(+), 12 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 1f7daf2913e..5d2e9d881c4 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -23,10 +23,13 @@ import io.delta.kernel.data.Row; import io.delta.kernel.internal.data.StructRow; import io.delta.kernel.types.*; + +import java.math.BigDecimal; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; public final class VectorUtils { @@ -71,14 +74,12 @@ public static Map toJavaMap(MapValue mapValue) { } /** - * Creates an {@link ArrayValue} from list of strings. The type {@code array(string)} is a common - * occurrence in Delta Log schema. We don't have any non-string array type in Delta Log. If we end - * up needing to support other types, we can make this generic. + * Creates an {@link ArrayValue} from list of object. * - * @param values list of strings + * @param values list of object * @return an {@link ArrayValue} with the given values of type {@link StringType} */ - public static ArrayValue stringArrayValue(List values) { + public static ArrayValue buildArrayValue(List values, DataType dataType) { if (values == null) { return null; } @@ -90,7 +91,7 @@ public int getSize() { @Override public ColumnVector getElements() { - return stringVector(values); + return buildColumnVector(values, dataType); } }; } @@ -117,27 +118,28 @@ public int getSize() { @Override public ColumnVector getKeys() { - return stringVector(keys); + return buildColumnVector(keys, StringType.STRING); } @Override public ColumnVector getValues() { - return stringVector(values); + return buildColumnVector(values, StringType.STRING); } }; } /** - * Utility method to create a {@link ColumnVector} for given list of strings. + * Utility method to create a {@link ColumnVector} for given list of object, the object should be + * primitive type or an Row instance. * * @param values list of strings * @return a {@link ColumnVector} with the given values of type {@link StringType} */ - public static ColumnVector stringVector(List values) { + public static ColumnVector buildColumnVector(List values, DataType dataType) { return new ColumnVector() { @Override public DataType getDataType() { - return StringType.STRING; + return dataType; } @Override @@ -156,10 +158,123 @@ public boolean isNullAt(int rowId) { return values.get(rowId) == null; } + @Override + public int getInt(int rowId) { + checkArgument(IntegerType.INTEGER.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Integer); + return (Integer) values.get(rowId); + } + + @Override + public long getLong(int rowId) { + checkArgument(LongType.LONG.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Long); + return (Long) values.get(rowId); + } + + @Override + public float getFloat(int rowId) { + checkArgument(FloatType.FLOAT.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Float); + return (Float) values.get(rowId); + } + + @Override + public double getDouble(int rowId) { + checkArgument(DoubleType.DOUBLE.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Float); + return (Double) values.get(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + checkArgument(DoubleType.DOUBLE.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof byte[]); + return (byte[]) values.get(rowId); + } + + @Override + public BigDecimal getDecimal(int rowId) { + checkArgument(dataType instanceof DecimalType); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof BigDecimal); + return (BigDecimal) values.get(rowId); + } + + @Override + public ArrayValue getArray(int rowId) { + checkArgument(dataType instanceof ArrayValue); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof ArrayValue); + return (ArrayValue) values.get(rowId); + } + @Override public String getString(int rowId) { + checkArgument(StringType.STRING.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + Object value = values.get(rowId); + checkArgument(value instanceof String); + return (String) values.get(rowId); + } + + @Override + public boolean getBoolean(int rowId) { + checkArgument(BooleanType.BOOLEAN.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + return (Boolean) values.get(rowId); + } + + @Override + public MapValue getMap(int rowId) { + checkArgument(dataType instanceof MapType); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - return values.get(rowId); + Object value = values.get(rowId); + checkArgument(value instanceof MapType); + return (MapValue) values.get(rowId); + } + + @Override + public ColumnVector getChild(int ordinal) { + checkArgument(dataType instanceof StructType); + checkArgument(ordinal < ((StructType) dataType).length()); + DataType childDatatype = ((StructType) dataType).at(ordinal).getDataType(); + return buildColumnVector( + values.stream() + .map( + e -> { + checkArgument(e instanceof Row); + Row row = (Row) e; + if (row.isNullAt(ordinal)) { + return null; + } + if (childDatatype.equals(StringType.STRING)) { + return row.getString(ordinal); + } else if (childDatatype.equals(LongType.LONG)) { + return row.getLong(ordinal); + } else if (childDatatype.equals(IntegerType.INTEGER)) { + return row.getInt(ordinal); + } else if (childDatatype.equals(BooleanType.BOOLEAN)) { + return row.getBoolean(ordinal); + } else if (childDatatype instanceof MapType) { + return row.getMap(ordinal); + } + return row.getStruct(ordinal); + }) + .collect(Collectors.toList()), + childDatatype); } }; } From 8f6ff11f3c1e486d5c611104d5927f3cf3ee7d86 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Tue, 4 Feb 2025 19:49:43 -0800 Subject: [PATCH 02/17] all types --- .../kernel/internal/util/VectorUtils.java | 97 +++++++++++++------ 1 file changed, 66 insertions(+), 31 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 5d2e9d881c4..abddb7f11f7 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -23,7 +23,6 @@ import io.delta.kernel.data.Row; import io.delta.kernel.internal.data.StructRow; import io.delta.kernel.types.*; - import java.math.BigDecimal; import java.util.ArrayList; import java.util.HashMap; @@ -158,6 +157,27 @@ public boolean isNullAt(int rowId) { return values.get(rowId) == null; } + @Override + public boolean getBoolean(int rowId) { + checkArgument(BooleanType.BOOLEAN.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + return (Boolean) values.get(rowId); + } + + @Override + public byte getByte(int rowId) { + checkArgument(ByteType.BYTE.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + return (byte) values.get(rowId); + } + + @Override + public short getShort(int rowId) { + checkArgument(ShortType.SHORT.equals(dataType)); + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + return (short) values.get(rowId); + } + @Override public int getInt(int rowId) { checkArgument(IntegerType.INTEGER.equals(dataType)); @@ -194,15 +214,6 @@ public double getDouble(int rowId) { return (Double) values.get(rowId); } - @Override - public byte[] getBinary(int rowId) { - checkArgument(DoubleType.DOUBLE.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof byte[]); - return (byte[]) values.get(rowId); - } - @Override public BigDecimal getDecimal(int rowId) { checkArgument(dataType instanceof DecimalType); @@ -213,28 +224,30 @@ public BigDecimal getDecimal(int rowId) { } @Override - public ArrayValue getArray(int rowId) { - checkArgument(dataType instanceof ArrayValue); + public String getString(int rowId) { + checkArgument(StringType.STRING.equals(dataType)); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); - checkArgument(value instanceof ArrayValue); - return (ArrayValue) values.get(rowId); + checkArgument(value instanceof String); + return (String) values.get(rowId); } @Override - public String getString(int rowId) { - checkArgument(StringType.STRING.equals(dataType)); + public byte[] getBinary(int rowId) { + checkArgument(BinaryType.BINARY.equals(dataType)); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); - checkArgument(value instanceof String); - return (String) values.get(rowId); + checkArgument(value instanceof byte[]); + return (byte[]) values.get(rowId); } @Override - public boolean getBoolean(int rowId) { - checkArgument(BooleanType.BOOLEAN.equals(dataType)); + public ArrayValue getArray(int rowId) { + checkArgument(dataType instanceof ArrayValue); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - return (Boolean) values.get(rowId); + Object value = values.get(rowId); + checkArgument(value instanceof ArrayValue); + return (ArrayValue) values.get(rowId); } @Override @@ -260,18 +273,40 @@ public ColumnVector getChild(int ordinal) { if (row.isNullAt(ordinal)) { return null; } - if (childDatatype.equals(StringType.STRING)) { - return row.getString(ordinal); - } else if (childDatatype.equals(LongType.LONG)) { - return row.getLong(ordinal); - } else if (childDatatype.equals(IntegerType.INTEGER)) { - return row.getInt(ordinal); - } else if (childDatatype.equals(BooleanType.BOOLEAN)) { - return row.getBoolean(ordinal); - } else if (childDatatype instanceof MapType) { + if (childDatatype instanceof BooleanType) { + return row.getBoolean(ordinal); + } else if (childDatatype instanceof ByteType) { + return row.getByte(ordinal); + } else if (childDatatype instanceof ShortType) { + return row.getShort(ordinal); + } else if (childDatatype instanceof IntegerType + || childDatatype instanceof DateType) { + // DateType data is stored internally as the number of days since 1970-01-01 + return row.getInt(ordinal); + } else if (childDatatype instanceof LongType + || childDatatype instanceof TimestampType) { + // TimestampType data is stored internally as the number of microseconds + // since the unix epoch + return row.getLong(ordinal); + } else if (childDatatype instanceof FloatType) { + return row.getFloat(ordinal); + } else if (childDatatype instanceof DoubleType) { + return row.getDouble(ordinal); + } else if (childDatatype instanceof StringType) { + return row.getString(ordinal); + } else if (childDatatype instanceof BinaryType) { + return row.getBinary(ordinal); + } else if (childDatatype instanceof StructType) { + return row.getStruct(ordinal); + } else if (childDatatype instanceof DecimalType) { + return row.getDecimal(ordinal); + } else if (childDatatype instanceof ArrayType) { + return row.getArray(ordinal); + } else if (dataType instanceof MapType) { return row.getMap(ordinal); + } else { + throw new UnsupportedOperationException("unsupported data type"); } - return row.getStruct(ordinal); }) .collect(Collectors.toList()), childDatatype); From 42ba000d1dd384cac8bf9fba7fe1c00d132fb1b4 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Tue, 4 Feb 2025 20:11:15 -0800 Subject: [PATCH 03/17] impl --- .../internal/TransactionBuilderImpl.java | 5 +-- .../kernel/internal/actions/Protocol.java | 6 ++-- .../kernel/internal/util/VectorUtils.java | 32 +++++++++++-------- .../io/delta/kernel/TransactionSuite.scala | 2 +- .../internal/util/ColumnMappingSuite.scala | 2 +- .../internal/util/VectorUtilsSuite.scala | 26 +++++++++++++++ 6 files changed, 53 insertions(+), 20 deletions(-) create mode 100644 kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionBuilderImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionBuilderImpl.java index fd16f04ce12..2841eb62b7b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionBuilderImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionBuilderImpl.java @@ -22,7 +22,7 @@ import static io.delta.kernel.internal.util.ColumnMapping.isColumnMappingModeEnabled; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import static io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames; -import static io.delta.kernel.internal.util.VectorUtils.stringArrayValue; +import static io.delta.kernel.internal.util.VectorUtils.buildArrayValue; import static io.delta.kernel.internal.util.VectorUtils.stringStringMapValue; import static java.util.Objects.requireNonNull; @@ -40,6 +40,7 @@ import io.delta.kernel.internal.util.ColumnMapping.ColumnMappingMode; import io.delta.kernel.internal.util.SchemaUtils; import io.delta.kernel.internal.util.Tuple2; +import io.delta.kernel.types.StringType; import io.delta.kernel.types.StructType; import java.util.*; import org.slf4j.Logger; @@ -284,7 +285,7 @@ private Metadata getInitialMetadata() { new Format(), /* format */ schema.get().toJson(), /* schemaString */ schema.get(), /* schema */ - stringArrayValue(partitionColumnsCasePreserving), /* partitionColumns */ + buildArrayValue(partitionColumnsCasePreserving, StringType.STRING), /* partitionColumns */ Optional.of(currentTimeMillis), /* createdTime */ stringStringMapValue(Collections.emptyMap()) /* configuration */); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java index e71036935dd..7ffdc61e5cb 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java @@ -15,7 +15,7 @@ */ package io.delta.kernel.internal.actions; -import static io.delta.kernel.internal.util.VectorUtils.stringArrayValue; +import static io.delta.kernel.internal.util.VectorUtils.buildArrayValue; import io.delta.kernel.data.*; import io.delta.kernel.internal.TableFeatures; @@ -105,8 +105,8 @@ public Row toRow() { Map protocolMap = new HashMap<>(); protocolMap.put(0, minReaderVersion); protocolMap.put(1, minWriterVersion); - protocolMap.put(2, stringArrayValue(readerFeatures)); - protocolMap.put(3, stringArrayValue(writerFeatures)); + protocolMap.put(2, buildArrayValue(readerFeatures, StringType.STRING)); + protocolMap.put(3, buildArrayValue(writerFeatures, StringType.STRING)); return new GenericRow(Protocol.FULL_SCHEMA, protocolMap); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index abddb7f11f7..57392cb0585 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -161,21 +161,27 @@ public boolean isNullAt(int rowId) { public boolean getBoolean(int rowId) { checkArgument(BooleanType.BOOLEAN.equals(dataType)); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - return (Boolean) values.get(rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Boolean); + return (Boolean) value; } @Override public byte getByte(int rowId) { checkArgument(ByteType.BYTE.equals(dataType)); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - return (byte) values.get(rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Byte); + return (Byte) value; } @Override public short getShort(int rowId) { checkArgument(ShortType.SHORT.equals(dataType)); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - return (short) values.get(rowId); + Object value = values.get(rowId); + checkArgument(value instanceof Short); + return (Short) value; } @Override @@ -184,7 +190,7 @@ public int getInt(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof Integer); - return (Integer) values.get(rowId); + return (Integer) value; } @Override @@ -193,7 +199,7 @@ public long getLong(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof Long); - return (Long) values.get(rowId); + return (Long) value; } @Override @@ -202,7 +208,7 @@ public float getFloat(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof Float); - return (Float) values.get(rowId); + return (Float) value; } @Override @@ -210,8 +216,8 @@ public double getDouble(int rowId) { checkArgument(DoubleType.DOUBLE.equals(dataType)); checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); - checkArgument(value instanceof Float); - return (Double) values.get(rowId); + checkArgument(value instanceof Double); + return (Double) value; } @Override @@ -220,7 +226,7 @@ public BigDecimal getDecimal(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof BigDecimal); - return (BigDecimal) values.get(rowId); + return (BigDecimal) value; } @Override @@ -229,7 +235,7 @@ public String getString(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof String); - return (String) values.get(rowId); + return (String) value; } @Override @@ -238,7 +244,7 @@ public byte[] getBinary(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof byte[]); - return (byte[]) values.get(rowId); + return (byte[]) value; } @Override @@ -247,7 +253,7 @@ public ArrayValue getArray(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof ArrayValue); - return (ArrayValue) values.get(rowId); + return (ArrayValue) value; } @Override @@ -256,7 +262,7 @@ public MapValue getMap(int rowId) { checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); Object value = values.get(rowId); checkArgument(value instanceof MapType); - return (MapValue) values.get(rowId); + return (MapValue) value; } @Override diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/TransactionSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/TransactionSuite.scala index 5500eea8dba..d56dfd8f346 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/TransactionSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/TransactionSuite.scala @@ -198,7 +198,7 @@ object TransactionSuite extends VectorTestUtils with MockEngineUtils { new Format(), DataTypeJsonSerDe.serializeDataType(schema), schema, - VectorUtils.stringArrayValue(partitionCols.asJava), // partitionColumns + VectorUtils.buildArrayValue(partitionCols.asJava, StringType.STRING), // partitionColumns Optional.empty(), // createdTime stringStringMapValue(configurationMap.asJava) // configurationMap ) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/ColumnMappingSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/ColumnMappingSuite.scala index 0e1c4fafaec..d552df918bd 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/ColumnMappingSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/ColumnMappingSuite.scala @@ -473,7 +473,7 @@ class ColumnMappingSuite extends AnyFunSuite { new Format, schema.toJson, schema, - VectorUtils.stringArrayValue(Collections.emptyList()), + VectorUtils.buildArrayValue(Collections.emptyList(), StringType.STRING), Optional.empty(), VectorUtils.stringStringMapValue(Collections.emptyMap())) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala new file mode 100644 index 00000000000..e2f278c9fd4 --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -0,0 +1,26 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.internal.util +import org.scalatest.funsuite.AnyFunSuite + +class VectorUtilsSuite extends AnyFunSuite{ + + test("test build column vector from list of primitives") { + + } + +} From 3360fa7fc5cb8bc01ea7212cd66159ec5d2d9f40 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Wed, 5 Feb 2025 10:51:47 -0800 Subject: [PATCH 04/17] refactor test utils --- .../internal/util/VectorUtilsSuite.scala | 13 ++++- .../delta/kernel/test/VectorTestUtils.scala | 49 +++++++++++++++++++ .../expressions/ExpressionSuiteBase.scala | 44 ----------------- .../utils/DefaultVectorTestUtils.scala | 3 ++ 4 files changed, 63 insertions(+), 46 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala index e2f278c9fd4..459a52670b7 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (2024) The Delta Lake Project Authors. + * Copyright (2025) The Delta Lake Project Authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,12 +15,21 @@ */ package io.delta.kernel.internal.util + +import io.delta.kernel.test.VectorTestUtils +import io.delta.kernel.types.BooleanType import org.scalatest.funsuite.AnyFunSuite -class VectorUtilsSuite extends AnyFunSuite{ +import java.lang.{Boolean => BooleanJ} +import java.util + +class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { test("test build column vector from list of primitives") { + val childColumn = booleanVector(Seq[BooleanJ](true, false, null)) + VectorUtils.buildColumnVector(util.Arrays.asList(true, false, null), BooleanType.BOOLEAN) + } } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala index 7396522909d..979ed6d3307 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala @@ -21,6 +21,7 @@ import io.delta.kernel.types._ import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ} import scala.collection.JavaConverters._ +import org.scalatest.Assertions.convertToEqualizer trait VectorTestUtils { @@ -133,4 +134,52 @@ trait VectorTestUtils { override def getBoolean(rowId: Int): Boolean = rowId == selectRowId } + + protected def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + assert(actual.getDataType === expected.getDataType) + assert(actual.getSize === expected.getSize) + Seq.range(0, actual.getSize).foreach { rowId => + assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) + if (!actual.isNullAt(rowId)) { + assert( + actual.getBoolean(rowId) === expected.getBoolean(rowId), + s"unexpected value at $rowId" + ) + } + } + } + + protected def checkTimestampVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + assert(actual.getDataType === TimestampType.TIMESTAMP) + assert(actual.getDataType === expected.getDataType) + assert(actual.getSize === expected.getSize) + Seq.range(0, actual.getSize).foreach { rowId => + assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) + if (!actual.isNullAt(rowId)) { + assert( + actual.getLong(rowId) === expected.getLong(rowId), + s"unexpected value at $rowId: " + + s"expected: ${expected.getLong(rowId)} " + + s"actual: ${actual.getLong(rowId)} " + ) + } + } + } + + protected def checkStringVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + assert(actual.getDataType === StringType.STRING) + assert(actual.getDataType === expected.getDataType) + assert(actual.getSize === expected.getSize) + Seq.range(0, actual.getSize).foreach { rowId => + assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) + if (!actual.isNullAt(rowId)) { + assert( + actual.getString(rowId) === expected.getString(rowId), + s"unexpected value at $rowId: " + + s"expected: ${expected.getString(rowId)} " + + s"actual: ${actual.getString(rowId)} " + ) + } + } + } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala index d4e3ac4f315..11d4961caf1 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala @@ -65,48 +65,4 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils { new Predicate(symbol, left, right) } - protected def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - assert(actual.getDataType === expected.getDataType) - assert(actual.getSize === expected.getSize) - Seq.range(0, actual.getSize).foreach { rowId => - assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) - if (!actual.isNullAt(rowId)) { - assert( - actual.getBoolean(rowId) === expected.getBoolean(rowId), - s"unexpected value at $rowId" - ) - } - } - } - - protected def checkTimestampVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - assert(actual.getSize === expected.getSize) - for (rowId <- 0 until actual.getSize) { - if (expected.isNullAt(rowId)) { - assert(actual.isNullAt(rowId), s"Expected null at row $rowId") - } else { - val expectedValue = getValueAsObject(expected, rowId).asInstanceOf[Long] - val actualValue = getValueAsObject(actual, rowId).asInstanceOf[Long] - assert(actualValue === expectedValue, s"Unexpected value at row $rowId") - } - } - } - - protected def checkStringVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - assert(actual.getDataType === StringType.STRING) - assert(actual.getDataType === expected.getDataType) - assert(actual.getSize === expected.getSize) - Seq.range(0, actual.getSize).foreach { rowId => - assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) - if (!actual.isNullAt(rowId)) { - assert( - actual.getString(rowId) === expected.getString(rowId), - s"unexpected value at $rowId: " + - s"expected: ${expected.getString(rowId)} " + - s"actual: ${actual.getString(rowId)} " - ) - } - } - } - } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala index c5c8e8b6ff1..6284c5dc9cb 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala @@ -18,9 +18,12 @@ package io.delta.kernel.defaults.utils import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch import io.delta.kernel.data.{ColumnVector, ColumnarBatch} +import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getValueAsObject import io.delta.kernel.test.VectorTestUtils import io.delta.kernel.types._ + + trait DefaultVectorTestUtils extends VectorTestUtils { /** * Returns a [[ColumnarBatch]] with each given vector is a top-level column col_i where i is From 16f256c6166cc9c6fe5fafef6eccd0b0b51d9bfe Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Thu, 6 Feb 2025 15:53:37 -0800 Subject: [PATCH 05/17] save --- .../internal/util/VectorUtilsSuite.scala | 7 +- .../delta/kernel/test/VectorTestUtils.scala | 64 ++++++++----------- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala index 459a52670b7..6ca9cee8fa8 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -26,9 +26,12 @@ import java.util class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { test("test build column vector from list of primitives") { + checkVectors(booleanVector(Seq[BooleanJ](true, false, null)), + VectorUtils.buildColumnVector(util.Arrays.asList(true, false, null), BooleanType.BOOLEAN), + BooleanType.BOOLEAN, + (vec, id) => vec.getBoolean(id) + ) - val childColumn = booleanVector(Seq[BooleanJ](true, false, null)) - VectorUtils.buildColumnVector(util.Arrays.asList(true, false, null), BooleanType.BOOLEAN) } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala index 979ed6d3307..0edfd63e5a6 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala @@ -19,7 +19,7 @@ import io.delta.kernel.data.{ColumnVector, MapValue} import io.delta.kernel.internal.util.VectorUtils import io.delta.kernel.types._ -import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ} +import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ, Byte => ByteJ} import scala.collection.JavaConverters._ import org.scalatest.Assertions.convertToEqualizer @@ -39,6 +39,20 @@ trait VectorTestUtils { } } + protected def byteVector(values: Seq[ByteJ]): ColumnVector = { + new ColumnVector { + override def getDataType: DataType = ByteType.BYTE + + override def getSize: Int = values.length + + override def close(): Unit = {} + + override def isNullAt(rowId: Int): Boolean = values(rowId) == null + + override def getByte(rowId: Int): Byte = values(rowId) + } + } + protected def timestampVector(values: Seq[Long]): ColumnVector = { new ColumnVector { override def getDataType: DataType = TimestampType.TIMESTAMP @@ -135,49 +149,27 @@ trait VectorTestUtils { override def getBoolean(rowId: Int): Boolean = rowId == selectRowId } - protected def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - assert(actual.getDataType === expected.getDataType) - assert(actual.getSize === expected.getSize) - Seq.range(0, actual.getSize).foreach { rowId => - assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) - if (!actual.isNullAt(rowId)) { - assert( - actual.getBoolean(rowId) === expected.getBoolean(rowId), - s"unexpected value at $rowId" - ) - } - } - } + protected def checkVectors[T]( + actual: ColumnVector, + expected: ColumnVector, + expectedType: DataType, + getValue: (ColumnVector, Int) => T, + errorMessageFn: (Int, T, T) => String = (rowId: Int, exp: T, act: T) => + s"unexpected value at $rowId" + ): Unit = { - protected def checkTimestampVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - assert(actual.getDataType === TimestampType.TIMESTAMP) + assert(actual.getDataType === expectedType) assert(actual.getDataType === expected.getDataType) assert(actual.getSize === expected.getSize) - Seq.range(0, actual.getSize).foreach { rowId => - assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) - if (!actual.isNullAt(rowId)) { - assert( - actual.getLong(rowId) === expected.getLong(rowId), - s"unexpected value at $rowId: " + - s"expected: ${expected.getLong(rowId)} " + - s"actual: ${actual.getLong(rowId)} " - ) - } - } - } - protected def checkStringVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - assert(actual.getDataType === StringType.STRING) - assert(actual.getDataType === expected.getDataType) - assert(actual.getSize === expected.getSize) Seq.range(0, actual.getSize).foreach { rowId => assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) if (!actual.isNullAt(rowId)) { + val actualValue = getValue(actual, rowId) + val expectedValue = getValue(expected, rowId) assert( - actual.getString(rowId) === expected.getString(rowId), - s"unexpected value at $rowId: " + - s"expected: ${expected.getString(rowId)} " + - s"actual: ${actual.getString(rowId)} " + actualValue === expectedValue, + errorMessageFn(rowId, expectedValue, actualValue) ) } } From c39c4f0136b413c8d39169b2bf51cc3a8894afe0 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Thu, 6 Feb 2025 16:52:55 -0800 Subject: [PATCH 06/17] refactor test --- .../delta/kernel/test/VectorTestUtils.scala | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala index 0edfd63e5a6..080fa0ca387 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala @@ -18,10 +18,10 @@ package io.delta.kernel.test import io.delta.kernel.data.{ColumnVector, MapValue} import io.delta.kernel.internal.util.VectorUtils import io.delta.kernel.types._ +import org.scalatest.Assertions.convertToEqualizer -import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ, Byte => ByteJ} +import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ} import scala.collection.JavaConverters._ -import org.scalatest.Assertions.convertToEqualizer trait VectorTestUtils { @@ -39,20 +39,6 @@ trait VectorTestUtils { } } - protected def byteVector(values: Seq[ByteJ]): ColumnVector = { - new ColumnVector { - override def getDataType: DataType = ByteType.BYTE - - override def getSize: Int = values.length - - override def close(): Unit = {} - - override def isNullAt(rowId: Int): Boolean = values(rowId) == null - - override def getByte(rowId: Int): Byte = values(rowId) - } - } - protected def timestampVector(values: Seq[Long]): ColumnVector = { new ColumnVector { override def getDataType: DataType = TimestampType.TIMESTAMP @@ -149,13 +135,38 @@ trait VectorTestUtils { override def getBoolean(rowId: Int): Boolean = rowId == selectRowId } - protected def checkVectors[T]( + protected def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + checkVectors( + actual, + expected, + BooleanType.BOOLEAN, + (vec, id) => vec.getBoolean(id) + ) + } + + protected def checkTimestampVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + checkVectors( + actual, + expected, + TimestampType.TIMESTAMP, + (vec, id) => vec.getLong(id) + ) + } + + protected def checkStringVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + checkVectors( + actual, + expected, + StringType.STRING, + (vec, id) => vec.getString(id) + ) + } + + private def checkVectors[T]( actual: ColumnVector, expected: ColumnVector, expectedType: DataType, - getValue: (ColumnVector, Int) => T, - errorMessageFn: (Int, T, T) => String = (rowId: Int, exp: T, act: T) => - s"unexpected value at $rowId" + getValue: (ColumnVector, Int) => T ): Unit = { assert(actual.getDataType === expectedType) @@ -169,7 +180,7 @@ trait VectorTestUtils { val expectedValue = getValue(expected, rowId) assert( actualValue === expectedValue, - errorMessageFn(rowId, expectedValue, actualValue) + s"unexpected value at $rowId: expected: $expected actual: $actual" ) } } From 3f535f70c1d2a1d7ca045491071f512e0e47888b Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 14:47:36 -0800 Subject: [PATCH 07/17] refactor --- .../kernel/internal/util/VectorUtils.java | 214 ++++++++---------- 1 file changed, 95 insertions(+), 119 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 57392cb0585..3889de84900 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -72,28 +72,6 @@ public static Map toJavaMap(MapValue mapValue) { return values; } - /** - * Creates an {@link ArrayValue} from list of object. - * - * @param values list of object - * @return an {@link ArrayValue} with the given values of type {@link StringType} - */ - public static ArrayValue buildArrayValue(List values, DataType dataType) { - if (values == null) { - return null; - } - return new ArrayValue() { - @Override - public int getSize() { - return values.size(); - } - - @Override - public ColumnVector getElements() { - return buildColumnVector(values, dataType); - } - }; - } /** * Creates a {@link MapValue} from map of string keys and string values. The type {@code @@ -127,12 +105,32 @@ public ColumnVector getValues() { }; } + /** + * Creates an {@link ArrayValue} from list of objects. + */ + public static ArrayValue buildArrayValue(List values, DataType dataType) { + if (values == null) { + return null; + } + return new ArrayValue() { + @Override + public int getSize() { + return values.size(); + } + + @Override + public ColumnVector getElements() { + return buildColumnVector(values, dataType); + } + }; + } + /** * Utility method to create a {@link ColumnVector} for given list of object, the object should be * primitive type or an Row instance. * * @param values list of strings - * @return a {@link ColumnVector} with the given values of type {@link StringType} + * @return a {@link ColumnVector} with the given values type. */ public static ColumnVector buildColumnVector(List values, DataType dataType) { return new ColumnVector() { @@ -153,169 +151,147 @@ public void close() { @Override public boolean isNullAt(int rowId) { - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + validateRowId(rowId); return values.get(rowId) == null; } @Override public boolean getBoolean(int rowId) { checkArgument(BooleanType.BOOLEAN.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Boolean); - return (Boolean) value; + return (Boolean) getValidatedValue(rowId, Boolean.class); } @Override public byte getByte(int rowId) { checkArgument(ByteType.BYTE.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Byte); - return (Byte) value; + return (Byte) getValidatedValue(rowId, Byte.class); } @Override public short getShort(int rowId) { checkArgument(ShortType.SHORT.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Short); - return (Short) value; + return (Short) getValidatedValue(rowId, Short.class); } @Override public int getInt(int rowId) { checkArgument(IntegerType.INTEGER.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Integer); - return (Integer) value; + return (Integer) getValidatedValue(rowId, Integer.class); } @Override public long getLong(int rowId) { checkArgument(LongType.LONG.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Long); - return (Long) value; + return (Long) getValidatedValue(rowId, Long.class); } @Override public float getFloat(int rowId) { checkArgument(FloatType.FLOAT.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Float); - return (Float) value; + return (Float) getValidatedValue(rowId, Float.class); } @Override public double getDouble(int rowId) { checkArgument(DoubleType.DOUBLE.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Double); - return (Double) value; + return (Double) getValidatedValue(rowId, Double.class); } @Override public BigDecimal getDecimal(int rowId) { checkArgument(dataType instanceof DecimalType); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof BigDecimal); - return (BigDecimal) value; + return (BigDecimal) getValidatedValue(rowId, BigDecimal.class); } @Override public String getString(int rowId) { checkArgument(StringType.STRING.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof String); - return (String) value; + return (String) getValidatedValue(rowId, String.class); } @Override public byte[] getBinary(int rowId) { checkArgument(BinaryType.BINARY.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof byte[]); - return (byte[]) value; + return (byte[]) getValidatedValue(rowId, byte[].class); } @Override public ArrayValue getArray(int rowId) { - checkArgument(dataType instanceof ArrayValue); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof ArrayValue); - return (ArrayValue) value; + checkArgument(dataType instanceof ArrayType); + return (ArrayValue) getValidatedValue(rowId, ArrayValue.class); } @Override public MapValue getMap(int rowId) { checkArgument(dataType instanceof MapType); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof MapType); - return (MapValue) value; + return (MapValue) getValidatedValue(rowId, MapValue.class); } @Override public ColumnVector getChild(int ordinal) { checkArgument(dataType instanceof StructType); checkArgument(ordinal < ((StructType) dataType).length()); + DataType childDatatype = ((StructType) dataType).at(ordinal).getDataType(); - return buildColumnVector( - values.stream() - .map( - e -> { - checkArgument(e instanceof Row); - Row row = (Row) e; - if (row.isNullAt(ordinal)) { - return null; - } - if (childDatatype instanceof BooleanType) { - return row.getBoolean(ordinal); - } else if (childDatatype instanceof ByteType) { - return row.getByte(ordinal); - } else if (childDatatype instanceof ShortType) { - return row.getShort(ordinal); - } else if (childDatatype instanceof IntegerType - || childDatatype instanceof DateType) { - // DateType data is stored internally as the number of days since 1970-01-01 - return row.getInt(ordinal); - } else if (childDatatype instanceof LongType - || childDatatype instanceof TimestampType) { - // TimestampType data is stored internally as the number of microseconds - // since the unix epoch - return row.getLong(ordinal); - } else if (childDatatype instanceof FloatType) { - return row.getFloat(ordinal); - } else if (childDatatype instanceof DoubleType) { - return row.getDouble(ordinal); - } else if (childDatatype instanceof StringType) { - return row.getString(ordinal); - } else if (childDatatype instanceof BinaryType) { - return row.getBinary(ordinal); - } else if (childDatatype instanceof StructType) { - return row.getStruct(ordinal); - } else if (childDatatype instanceof DecimalType) { - return row.getDecimal(ordinal); - } else if (childDatatype instanceof ArrayType) { - return row.getArray(ordinal); - } else if (dataType instanceof MapType) { - return row.getMap(ordinal); - } else { - throw new UnsupportedOperationException("unsupported data type"); - } - }) - .collect(Collectors.toList()), - childDatatype); + List childValues = extractChildValues(ordinal, childDatatype); + + return buildColumnVector(childValues, childDatatype); + } + + private void validateRowId(int rowId) { + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + } + + private Object getValidatedValue(int rowId, Class expectedType) { + validateRowId(rowId); + Object value = values.get(rowId); + checkArgument(expectedType.isInstance(value), + "Value must be of type %s", expectedType.getSimpleName()); + return value; + } + + + private List extractChildValues(int ordinal, DataType childDatatype) { + return values.stream() + .map(e -> extractChildValue(e, ordinal, childDatatype)) + .collect(Collectors.toList()); + } + + private Object extractChildValue(Object element, int ordinal, DataType childDatatype) { + checkArgument(element instanceof Row); + Row row = (Row) element; + + if (row.isNullAt(ordinal)) { + return null; + } + + return extractTypedValue(row, ordinal, childDatatype); + } + + private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) { + // Primitive Types + if (childDatatype instanceof BooleanType) return row.getBoolean(ordinal); + if (childDatatype instanceof ByteType) return row.getByte(ordinal); + if (childDatatype instanceof ShortType) return row.getShort(ordinal); + if (childDatatype instanceof IntegerType || + childDatatype instanceof DateType) return row.getInt(ordinal); + if (childDatatype instanceof LongType || + childDatatype instanceof TimestampType) return row.getLong(ordinal); + if (childDatatype instanceof FloatType) return row.getFloat(ordinal); + if (childDatatype instanceof DoubleType) return row.getDouble(ordinal); + + // Complex Types + if (childDatatype instanceof StringType) return row.getString(ordinal); + if (childDatatype instanceof BinaryType) return row.getBinary(ordinal); + if (childDatatype instanceof DecimalType) return row.getDecimal(ordinal); + + // Nested Types + if (childDatatype instanceof StructType) return row.getStruct(ordinal); + if (childDatatype instanceof ArrayType) return row.getArray(ordinal); + if (childDatatype instanceof MapType) return row.getMap(ordinal); + + throw new UnsupportedOperationException( + String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName())); } }; } From e4d11c1939f72fd0da5a4fdf8bc938d8f6161eff Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:20:06 -0800 Subject: [PATCH 08/17] add test --- .../kernel/internal/util/VectorUtils.java | 76 ++++++---- .../internal/util/VectorUtilsSuite.scala | 132 ++++++++++++++++-- 2 files changed, 170 insertions(+), 38 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 3889de84900..91e7d6cde62 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -72,7 +72,6 @@ public static Map toJavaMap(MapValue mapValue) { return values; } - /** * Creates a {@link MapValue} from map of string keys and string values. The type {@code * map(string -> string)} is a common occurrence in Delta Log schema. @@ -105,9 +104,7 @@ public ColumnVector getValues() { }; } - /** - * Creates an {@link ArrayValue} from list of objects. - */ + /** Creates an {@link ArrayValue} from list of objects. */ public static ArrayValue buildArrayValue(List values, DataType dataType) { if (values == null) { return null; @@ -175,13 +172,13 @@ public short getShort(int rowId) { @Override public int getInt(int rowId) { - checkArgument(IntegerType.INTEGER.equals(dataType)); + checkArgument(IntegerType.INTEGER.equals(dataType) || DateType.DATE.equals(dataType)); return (Integer) getValidatedValue(rowId, Integer.class); } @Override public long getLong(int rowId) { - checkArgument(LongType.LONG.equals(dataType)); + checkArgument(LongType.LONG.equals(dataType) || TimestampType.TIMESTAMP.equals(dataType)); return (Long) getValidatedValue(rowId, Long.class); } @@ -245,16 +242,17 @@ private void validateRowId(int rowId) { private Object getValidatedValue(int rowId, Class expectedType) { validateRowId(rowId); Object value = values.get(rowId); - checkArgument(expectedType.isInstance(value), - "Value must be of type %s", expectedType.getSimpleName()); + checkArgument( + expectedType.isInstance(value), + "Value must be of type %s", + expectedType.getSimpleName()); return value; } - private List extractChildValues(int ordinal, DataType childDatatype) { return values.stream() - .map(e -> extractChildValue(e, ordinal, childDatatype)) - .collect(Collectors.toList()); + .map(e -> extractChildValue(e, ordinal, childDatatype)) + .collect(Collectors.toList()); } private Object extractChildValue(Object element, int ordinal, DataType childDatatype) { @@ -270,28 +268,52 @@ private Object extractChildValue(Object element, int ordinal, DataType childData private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) { // Primitive Types - if (childDatatype instanceof BooleanType) return row.getBoolean(ordinal); - if (childDatatype instanceof ByteType) return row.getByte(ordinal); - if (childDatatype instanceof ShortType) return row.getShort(ordinal); - if (childDatatype instanceof IntegerType || - childDatatype instanceof DateType) return row.getInt(ordinal); - if (childDatatype instanceof LongType || - childDatatype instanceof TimestampType) return row.getLong(ordinal); - if (childDatatype instanceof FloatType) return row.getFloat(ordinal); - if (childDatatype instanceof DoubleType) return row.getDouble(ordinal); + if (childDatatype instanceof BooleanType) { + return row.getBoolean(ordinal); + } + if (childDatatype instanceof ByteType) { + return row.getByte(ordinal); + } + if (childDatatype instanceof ShortType) { + return row.getShort(ordinal); + } + if (childDatatype instanceof IntegerType || childDatatype instanceof DateType) { + return row.getInt(ordinal); + } + if (childDatatype instanceof LongType || childDatatype instanceof TimestampType) { + return row.getLong(ordinal); + } + if (childDatatype instanceof FloatType) { + return row.getFloat(ordinal); + } + if (childDatatype instanceof DoubleType) { + return row.getDouble(ordinal); + } // Complex Types - if (childDatatype instanceof StringType) return row.getString(ordinal); - if (childDatatype instanceof BinaryType) return row.getBinary(ordinal); - if (childDatatype instanceof DecimalType) return row.getDecimal(ordinal); + if (childDatatype instanceof StringType) { + return row.getString(ordinal); + } + if (childDatatype instanceof BinaryType) { + return row.getBinary(ordinal); + } + if (childDatatype instanceof DecimalType) { + return row.getDecimal(ordinal); + } // Nested Types - if (childDatatype instanceof StructType) return row.getStruct(ordinal); - if (childDatatype instanceof ArrayType) return row.getArray(ordinal); - if (childDatatype instanceof MapType) return row.getMap(ordinal); + if (childDatatype instanceof StructType) { + return row.getStruct(ordinal); + } + if (childDatatype instanceof ArrayType) { + return row.getArray(ordinal); + } + if (childDatatype instanceof MapType) { + return row.getMap(ordinal); + } throw new UnsupportedOperationException( - String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName())); + String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName())); } }; } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala index 6ca9cee8fa8..4bfb1e1f41c 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -16,23 +16,133 @@ package io.delta.kernel.internal.util +import java.sql.{Date, Timestamp} import io.delta.kernel.test.VectorTestUtils -import io.delta.kernel.types.BooleanType +import io.delta.kernel.types.{ + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, + TimestampType +} + +import java.lang.{ + Boolean => BooleanJ, + Byte => ByteJ, + Double => DoubleJ, + Float => FloatJ, + Integer => IntegerJ, + Long => LongJ, + Short => ShortJ +} +import scala.collection.JavaConverters._ import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.Tables.Table -import java.lang.{Boolean => BooleanJ} -import java.util +import java.math.BigDecimal class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { - test("test build column vector from list of primitives") { - checkVectors(booleanVector(Seq[BooleanJ](true, false, null)), - VectorUtils.buildColumnVector(util.Arrays.asList(true, false, null), BooleanType.BOOLEAN), - BooleanType.BOOLEAN, - (vec, id) => vec.getBoolean(id) + Table( + ("values", "dataType"), + (List[ByteJ](1.toByte, 2.toByte, 3.toByte, null), ByteType.BYTE), + (List[ShortJ](1.toShort, 2.toShort, 3.toShort, null), ShortType.SHORT), + (List[IntegerJ](1, 2, 3, null), IntegerType.INTEGER), + (List[LongJ](1L, 2L, 3L, null), LongType.LONG), + (List[FloatJ](1.0f, 2.0f, 3.0f, null), FloatType.FLOAT), + (List[DoubleJ](1.0, 2.0, 3.0, null), DoubleType.DOUBLE), + (List[Array[Byte]]("one".getBytes, "two".getBytes, "three".getBytes, null), BinaryType.BINARY), + (List[BooleanJ](true, false, false, null), BooleanType.BOOLEAN), + ( + List[BigDecimal](new BigDecimal("1"), new BigDecimal("2"), new BigDecimal("3"), null), + new DecimalType(10, 2) + ), + (List[String]("one", "two", "three", null), StringType.STRING), + ( + List[IntegerJ](10, 20, 30, null), + DateType.DATE + ), + ( + List[LongJ]( + Timestamp.valueOf("2023-01-01 00:00:00").getTime, + Timestamp.valueOf("2023-01-02 00:00:00").getTime, + Timestamp.valueOf("2023-01-03 00:00:00").getTime, + null + ), + TimestampType.TIMESTAMP ) + ).foreach( + testCase => + test(s"handle ${testCase._2} array correctly") { + val values = testCase._1 + val dataType = testCase._2 + val columnVector = VectorUtils.buildColumnVector(values.asJava, dataType) + assert(columnVector.getSize == 4) - - } - + dataType match { + case ByteType.BYTE => + assert(columnVector.getByte(0) == 1.toByte) + assert(columnVector.getByte(1) == 2.toByte) + assert(columnVector.getByte(2) == 3.toByte) + case ShortType.SHORT => + assert(columnVector.getShort(0) == 1.toShort) + assert(columnVector.getShort(1) == 2.toShort) + assert(columnVector.getShort(2) == 3.toShort) + case IntegerType.INTEGER => + assert(columnVector.getInt(0) == 1) + assert(columnVector.getInt(1) == 2) + assert(columnVector.getInt(2) == 3) + case LongType.LONG => + assert(columnVector.getLong(0) == 1L) + assert(columnVector.getLong(1) == 2L) + assert(columnVector.getLong(2) == 3L) + case FloatType.FLOAT => + assert(columnVector.getFloat(0) == 1.0f) + assert(columnVector.getFloat(1) == 2.0f) + assert(columnVector.getFloat(2) == 3.0f) + case DoubleType.DOUBLE => + assert(columnVector.getDouble(0) == 1.0) + assert(columnVector.getDouble(1) == 2.0) + assert(columnVector.getDouble(2) == 3.0) + case BooleanType.BOOLEAN => + assert(columnVector.getBoolean(0)) + assert(!columnVector.getBoolean(1)) + assert(!columnVector.getBoolean(2)) + case _: DecimalType => + assert(columnVector.getDecimal(0) == new BigDecimal("1")) + assert(columnVector.getDecimal(1) == new BigDecimal("2")) + assert(columnVector.getDecimal(2) == new BigDecimal("3")) + case BinaryType.BINARY => + assert(columnVector.getBinary(0) sameElements "one".getBytes) + assert(columnVector.getBinary(1) sameElements "two".getBytes) + assert(columnVector.getBinary(2) sameElements "three".getBytes) + case StringType.STRING => + assert(columnVector.getString(0) == "one") + assert(columnVector.getString(1) == "two") + assert(columnVector.getString(2) == "three") + case DateType.DATE => + assert(columnVector.getInt(0) == 10) + assert(columnVector.getInt(1) == 20) + assert(columnVector.getInt(2) == 30) + case TimestampType.TIMESTAMP => + assert( + columnVector.getLong(0) == Timestamp.valueOf("2023-01-01 00:00:00").getTime + ) + assert( + columnVector.getLong(1) == Timestamp.valueOf("2023-01-02 00:00:00").getTime + ) + assert( + columnVector.getLong(2) == Timestamp.valueOf("2023-01-03 00:00:00").getTime + ) + } + assert(columnVector.isNullAt(3)) + } + ) } From dd2083312b28d4e98cabfef168f338493eec6443 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:21:22 -0800 Subject: [PATCH 09/17] remove empty line --- .../io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala index 6284c5dc9cb..d210008c1d7 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala @@ -22,8 +22,6 @@ import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getValueAsObject import io.delta.kernel.test.VectorTestUtils import io.delta.kernel.types._ - - trait DefaultVectorTestUtils extends VectorTestUtils { /** * Returns a [[ColumnarBatch]] with each given vector is a top-level column col_i where i is From f6c80218ad85c058ec34ddb333f2652a8363c41a Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:24:01 -0800 Subject: [PATCH 10/17] revert unnecessary change --- .../delta/kernel/test/VectorTestUtils.scala | 51 ------------------- .../expressions/ExpressionSuiteBase.scala | 45 ++++++++++++++++ 2 files changed, 45 insertions(+), 51 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala index 080fa0ca387..40d044d51e6 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala @@ -134,55 +134,4 @@ trait VectorTestUtils { override def getBoolean(rowId: Int): Boolean = rowId == selectRowId } - - protected def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - checkVectors( - actual, - expected, - BooleanType.BOOLEAN, - (vec, id) => vec.getBoolean(id) - ) - } - - protected def checkTimestampVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - checkVectors( - actual, - expected, - TimestampType.TIMESTAMP, - (vec, id) => vec.getLong(id) - ) - } - - protected def checkStringVectors(actual: ColumnVector, expected: ColumnVector): Unit = { - checkVectors( - actual, - expected, - StringType.STRING, - (vec, id) => vec.getString(id) - ) - } - - private def checkVectors[T]( - actual: ColumnVector, - expected: ColumnVector, - expectedType: DataType, - getValue: (ColumnVector, Int) => T - ): Unit = { - - assert(actual.getDataType === expectedType) - assert(actual.getDataType === expected.getDataType) - assert(actual.getSize === expected.getSize) - - Seq.range(0, actual.getSize).foreach { rowId => - assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) - if (!actual.isNullAt(rowId)) { - val actualValue = getValue(actual, rowId) - val expectedValue = getValue(expected, rowId) - assert( - actualValue === expectedValue, - s"unexpected value at $rowId: expected: $expected actual: $actual" - ) - } - } - } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala index 11d4961caf1..0abd5343ae8 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala @@ -65,4 +65,49 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils { new Predicate(symbol, left, right) } + protected def checkBooleanVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + assert(actual.getDataType === expected.getDataType) + assert(actual.getSize === expected.getSize) + Seq.range(0, actual.getSize).foreach { rowId => + assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) + if (!actual.isNullAt(rowId)) { + assert( + actual.getBoolean(rowId) === expected.getBoolean(rowId), + s"unexpected value at $rowId" + ) + } + } + } + + protected def checkTimestampVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + assert(actual.getSize === expected.getSize) + for (rowId <- 0 until actual.getSize) { + if (expected.isNullAt(rowId)) { + assert(actual.isNullAt(rowId), s"Expected null at row $rowId") + } else { + val expectedValue = getValueAsObject(expected, rowId).asInstanceOf[Long] + val actualValue = getValueAsObject(actual, rowId).asInstanceOf[Long] + assert(actualValue === expectedValue, s"Unexpected value at row $rowId") + } + } + } + + protected def checkStringVectors(actual: ColumnVector, expected: ColumnVector): Unit = { + assert(actual.getDataType === StringType.STRING) + assert(actual.getDataType === expected.getDataType) + assert(actual.getSize === expected.getSize) + Seq.range(0, actual.getSize).foreach { rowId => + assert(actual.isNullAt(rowId) === expected.isNullAt(rowId)) + if (!actual.isNullAt(rowId)) { + assert( + actual.getString(rowId) === expected.getString(rowId), + s"unexpected value at $rowId: " + + s"expected: ${expected.getString(rowId)} " + + s"actual: ${actual.getString(rowId)} " + ) + } + } + } + + } From 7da6af7d7f5444da68f8af077151c8753debff8f Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:26:15 -0800 Subject: [PATCH 11/17] remove empty line --- .../defaults/internal/expressions/ExpressionSuiteBase.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala index 0abd5343ae8..001f49c7fff 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala @@ -102,12 +102,11 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils { assert( actual.getString(rowId) === expected.getString(rowId), s"unexpected value at $rowId: " + - s"expected: ${expected.getString(rowId)} " + - s"actual: ${actual.getString(rowId)} " + s"expected: ${expected.getString(rowId)} " + + s"actual: ${actual.getString(rowId)} " ) } } } - } From fbeb2ee8f2643eee76c232e022bd4fd34d487960 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:27:08 -0800 Subject: [PATCH 12/17] remove empty line --- .../defaults/internal/expressions/ExpressionSuiteBase.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala index 001f49c7fff..d4e3ac4f315 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala @@ -102,8 +102,8 @@ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils { assert( actual.getString(rowId) === expected.getString(rowId), s"unexpected value at $rowId: " + - s"expected: ${expected.getString(rowId)} " + - s"actual: ${actual.getString(rowId)} " + s"expected: ${expected.getString(rowId)} " + + s"actual: ${actual.getString(rowId)} " ) } } From 5d56ea5eaec365a3503ff80d51c0ebe3754f5fea Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:28:04 -0800 Subject: [PATCH 13/17] revert line --- .../src/test/scala/io/delta/kernel/test/VectorTestUtils.scala | 1 - .../io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala index 40d044d51e6..7396522909d 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/VectorTestUtils.scala @@ -18,7 +18,6 @@ package io.delta.kernel.test import io.delta.kernel.data.{ColumnVector, MapValue} import io.delta.kernel.internal.util.VectorUtils import io.delta.kernel.types._ -import org.scalatest.Assertions.convertToEqualizer import java.lang.{Boolean => BooleanJ, Double => DoubleJ, Float => FloatJ} import scala.collection.JavaConverters._ diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala index d210008c1d7..c5c8e8b6ff1 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/DefaultVectorTestUtils.scala @@ -18,7 +18,6 @@ package io.delta.kernel.defaults.utils import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch import io.delta.kernel.data.{ColumnVector, ColumnarBatch} -import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getValueAsObject import io.delta.kernel.test.VectorTestUtils import io.delta.kernel.types._ From 41ac0c0c94a8f26ed7c48a3b680b664b630015ea Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 19:58:55 -0800 Subject: [PATCH 14/17] all test case --- .../internal/util/VectorUtilsSuite.scala | 277 ++++++++++++++++++ 1 file changed, 277 insertions(+) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala index 4bfb1e1f41c..17e0b66e20c 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -16,9 +16,13 @@ package io.delta.kernel.internal.util +import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} +import io.delta.kernel.internal.data.GenericRow + import java.sql.{Date, Timestamp} import io.delta.kernel.test.VectorTestUtils import io.delta.kernel.types.{ + ArrayType, BinaryType, BooleanType, ByteType, @@ -28,8 +32,10 @@ import io.delta.kernel.types.{ FloatType, IntegerType, LongType, + MapType, ShortType, StringType, + StructType, TimestampType } @@ -47,6 +53,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.prop.Tables.Table import java.math.BigDecimal +import java.util class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { @@ -145,4 +152,274 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { assert(columnVector.isNullAt(3)) } ) + + test(s"handle array of struct correctly") { + val structType = + new StructType().add("name", StringType.STRING).add("value", IntegerType.INTEGER) + + val arrayType = new ArrayType(structType, true) + + def row(name: String, value: Integer): Row = { + val map = new util.HashMap[Integer, AnyRef] + map.put(0, name) + map.put(1, value) + new GenericRow(structType, map) + } + + val values = List[ArrayValue]( + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[Row]( + row("a1", 1), + row("a2", 2) + ).asJava, + structType + ) + }, + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[Row]( + row("b1", 3), + row("b2", 4) + ).asJava, + structType + ) + }, + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[Row]( + row("c1", 5), + row("c2", 6) + ).asJava, + structType + ) + }, + null + ) + + val columnVector = VectorUtils.buildColumnVector(values.asJava, arrayType) + + // Test size + assert(columnVector.getSize == 4) + + // Test first array + val array0 = columnVector.getArray(0) + val struct0 = array0.getElements + assert(struct0.getSize == 2) + + val nameVector0 = struct0.getChild(0) + val valueVector0 = struct0.getChild(1) + assert(nameVector0.getString(0) == "a1") + assert(valueVector0.getInt(0) == 1) + assert(nameVector0.getString(1) == "a2") + assert(valueVector0.getInt(1) == 2) + + // Test second array + val array1 = columnVector.getArray(1) + val struct1 = array1.getElements + assert(struct1.getSize == 2) + + val nameVector1 = struct1.getChild(0) + val valueVector1 = struct1.getChild(1) + assert(nameVector1.getString(0) == "b1") + assert(valueVector1.getInt(0) == 3) + assert(nameVector1.getString(1) == "b2") + assert(valueVector1.getInt(1) == 4) + + // Test third array + val array2 = columnVector.getArray(2) + val struct2 = array2.getElements + assert(struct2.getSize == 2) + + val nameVector2 = struct2.getChild(0) + val valueVector2 = struct2.getChild(1) + assert(nameVector2.getString(0) == "c1") + assert(valueVector2.getInt(0) == 5) + assert(nameVector2.getString(1) == "c2") + assert(valueVector2.getInt(1) == 6) + + // Test null value + assert(columnVector.isNullAt(3)) + } + + test(s"handle array of map correctly") { + val mapType = new MapType(StringType.STRING, IntegerType.INTEGER, true) + val arrayType = new ArrayType(mapType, true) + + val values = List[ArrayValue]( + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[MapValue]( + new MapValue { + override def getSize: Int = 2 + override def getKeys: ColumnVector = + VectorUtils.buildColumnVector(List("a1", "a2").asJava, StringType.STRING) + override def getValues: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](1, 2).asJava, IntegerType.INTEGER) + }, + new MapValue { + override def getSize: Int = 2 + override def getKeys: ColumnVector = + VectorUtils.buildColumnVector(List("a3", "a4").asJava, StringType.STRING) + override def getValues: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](3, 4).asJava, IntegerType.INTEGER) + } + ).asJava, + mapType + ) + }, + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[MapValue]( + new MapValue { + override def getSize: Int = 2 + override def getKeys: ColumnVector = + VectorUtils.buildColumnVector(List("b1", "b2").asJava, StringType.STRING) + override def getValues: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](5, 6).asJava, IntegerType.INTEGER) + }, + new MapValue { + override def getSize: Int = 2 + override def getKeys: ColumnVector = + VectorUtils.buildColumnVector(List("b3", "b4").asJava, StringType.STRING) + override def getValues: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](7, 8).asJava, IntegerType.INTEGER) + } + ).asJava, + mapType + ) + }, + null + ) + + val columnVector = VectorUtils.buildColumnVector(values.asJava, arrayType) + + // Test size + assert(columnVector.getSize == 3) + + // Test first array + val firstArray = columnVector.getArray(0) + val firstArrayMaps = firstArray.getElements + assert(firstArrayMaps.getSize == 2) + + val firstArrayFirstMap = firstArrayMaps.getMap(0) + assert(firstArrayFirstMap.getKeys.getString(0) == "a1") + assert(firstArrayFirstMap.getKeys.getString(1) == "a2") + assert(firstArrayFirstMap.getValues.getInt(0) == 1) + assert(firstArrayFirstMap.getValues.getInt(1) == 2) + + val firstArraySecondMap = firstArrayMaps.getMap(1) + assert(firstArraySecondMap.getKeys.getString(0) == "a3") + assert(firstArraySecondMap.getKeys.getString(1) == "a4") + assert(firstArraySecondMap.getValues.getInt(0) == 3) + assert(firstArraySecondMap.getValues.getInt(1) == 4) + + // Test second array + val secondArray = columnVector.getArray(1) + val secondArrayMaps = secondArray.getElements + assert(secondArrayMaps.getSize == 2) + + val secondArrayFirstMap = secondArrayMaps.getMap(0) + assert(secondArrayFirstMap.getKeys.getString(0) == "b1") + assert(secondArrayFirstMap.getKeys.getString(1) == "b2") + assert(secondArrayFirstMap.getValues.getInt(0) == 5) + assert(secondArrayFirstMap.getValues.getInt(1) == 6) + + val secondArraySecondMap = secondArrayMaps.getMap(1) + assert(secondArraySecondMap.getKeys.getString(0) == "b3") + assert(secondArraySecondMap.getKeys.getString(1) == "b4") + assert(secondArraySecondMap.getValues.getInt(0) == 7) + assert(secondArraySecondMap.getValues.getInt(1) == 8) + + // Test null value + assert(columnVector.isNullAt(2)) + } + + test(s"handle array of array correctly") { + val innerArrayType = new ArrayType(IntegerType.INTEGER, true) + val outerArrayType = new ArrayType(innerArrayType, true) + + val values = List[ArrayValue]( + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[ArrayValue]( + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](1, 2).asJava, IntegerType.INTEGER) + }, + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](3, 4).asJava, IntegerType.INTEGER) + } + ).asJava, + innerArrayType + ) + }, + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = VectorUtils.buildColumnVector( + List[ArrayValue]( + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](5, 6).asJava, IntegerType.INTEGER) + }, + new ArrayValue { + override def getSize: Int = 2 + override def getElements: ColumnVector = + VectorUtils.buildColumnVector(List[IntegerJ](7, 8).asJava, IntegerType.INTEGER) + } + ).asJava, + innerArrayType + ) + }, + null + ) + + val columnVector = VectorUtils.buildColumnVector(values.asJava, outerArrayType) + + // Test size + assert(columnVector.getSize == 3) + + // Test first outer array + val firstOuterArray = columnVector.getArray(0) + val firstOuterArrayElements = firstOuterArray.getElements + assert(firstOuterArrayElements.getSize == 2) + + val firstOuterArrayFirstInner = firstOuterArrayElements.getArray(0) + val firstOuterArrayFirstInnerElements = firstOuterArrayFirstInner.getElements + assert(firstOuterArrayFirstInnerElements.getInt(0) == 1) + assert(firstOuterArrayFirstInnerElements.getInt(1) == 2) + + val firstOuterArraySecondInner = firstOuterArrayElements.getArray(1) + val firstOuterArraySecondInnerElements = firstOuterArraySecondInner.getElements + assert(firstOuterArraySecondInnerElements.getInt(0) == 3) + assert(firstOuterArraySecondInnerElements.getInt(1) == 4) + + // Test second outer array + val secondOuterArray = columnVector.getArray(1) + val secondOuterArrayElements = secondOuterArray.getElements + assert(secondOuterArrayElements.getSize == 2) + + val secondOuterArrayFirstInner = secondOuterArrayElements.getArray(0) + val secondOuterArrayFirstInnerElements = secondOuterArrayFirstInner.getElements + assert(secondOuterArrayFirstInnerElements.getInt(0) == 5) + assert(secondOuterArrayFirstInnerElements.getInt(1) == 6) + + val secondOuterArraySecondInner = secondOuterArrayElements.getArray(1) + val secondOuterArraySecondInnerElements = secondOuterArraySecondInner.getElements + assert(secondOuterArraySecondInnerElements.getInt(0) == 7) + assert(secondOuterArraySecondInnerElements.getInt(1) == 8) + + // Test null value + assert(columnVector.isNullAt(2)) + } } From 24d1682c15bafeac943b609194ba760a6782db73 Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 14 Feb 2025 20:15:38 -0800 Subject: [PATCH 15/17] fix test --- .../internal/tablefeatures/TableFeaturesSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/tablefeatures/TableFeaturesSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/tablefeatures/TableFeaturesSuite.scala index 5af3569865f..81d870d15c1 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/tablefeatures/TableFeaturesSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/tablefeatures/TableFeaturesSuite.scala @@ -20,7 +20,7 @@ import io.delta.kernel.exceptions.KernelException import io.delta.kernel.internal.actions.{Format, Metadata, Protocol} import io.delta.kernel.internal.tablefeatures.TableFeatures.{TABLE_FEATURES, validateWriteSupportedTable} import io.delta.kernel.internal.util.InternalUtils.singletonStringColumnVector -import io.delta.kernel.internal.util.VectorUtils.stringVector +import io.delta.kernel.internal.util.VectorUtils.buildColumnVector import io.delta.kernel.types._ import org.scalatest.funsuite.AnyFunSuite @@ -299,8 +299,11 @@ class TableFeaturesSuite extends AnyFunSuite { Optional.empty(), new MapValue() { // conf override def getSize = tblProps.size - override def getKeys: ColumnVector = stringVector(tblProps.toSeq.map(_._1).asJava) - override def getValues: ColumnVector = stringVector(tblProps.toSeq.map(_._2).asJava) + + override def getKeys: ColumnVector = + buildColumnVector(tblProps.toSeq.map(_._1).asJava, StringType.STRING) + override def getValues: ColumnVector = + buildColumnVector(tblProps.toSeq.map(_._2).asJava, StringType.STRING) } ) } From 314e21d50a77c1cf939f76bf5ad3733f243a074e Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 28 Feb 2025 10:38:23 -0800 Subject: [PATCH 16/17] fix conflict --- .../java/io/delta/kernel/internal/actions/Protocol.java | 6 +++--- .../kernel/internal/checksum/ChecksumWriterSuite.scala | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java index 1e08823ca6f..7a0c9238bee 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java @@ -18,7 +18,7 @@ import static io.delta.kernel.internal.tablefeatures.TableFeatures.TABLE_FEATURES; import static io.delta.kernel.internal.tablefeatures.TableFeatures.TABLE_FEATURES_MIN_WRITER_VERSION; import static io.delta.kernel.internal.util.Preconditions.checkArgument; -import static io.delta.kernel.internal.util.VectorUtils.stringArrayValue; +import static io.delta.kernel.internal.util.VectorUtils.buildArrayValue; import static java.lang.String.format; import static java.util.Collections.emptySet; import static java.util.Collections.unmodifiableSet; @@ -151,10 +151,10 @@ public Row toRow() { protocolMap.put(0, minReaderVersion); protocolMap.put(1, minWriterVersion); if (supportsReaderFeatures) { - protocolMap.put(2, stringArrayValue(new ArrayList<>(readerFeatures))); + protocolMap.put(2, buildArrayValue(new ArrayList<>(readerFeatures), StringType.STRING)); } if (supportsWriterFeatures) { - protocolMap.put(3, stringArrayValue(new ArrayList<>(writerFeatures))); + protocolMap.put(3, buildArrayValue(new ArrayList<>(writerFeatures), StringType.STRING)); } return new GenericRow(Protocol.FULL_SCHEMA, protocolMap); diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala index 1e73f99de19..a759fdf0a6a 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala @@ -17,18 +17,16 @@ package io.delta.kernel.internal.checksum import java.util import java.util.{Collections, Optional} - import io.delta.kernel.data.Row import io.delta.kernel.internal.actions.{Format, Metadata, Protocol} import io.delta.kernel.internal.checksum.CRCInfo.CRC_FILE_SCHEMA import io.delta.kernel.internal.data.GenericRow import io.delta.kernel.internal.fs.Path import io.delta.kernel.internal.util.VectorUtils -import io.delta.kernel.internal.util.VectorUtils.{stringArrayValue, stringStringMapValue} +import io.delta.kernel.internal.util.VectorUtils.{buildArrayValue, stringStringMapValue} import io.delta.kernel.test.{BaseMockJsonHandler, MockEngineUtils} -import io.delta.kernel.types.StructType +import io.delta.kernel.types.{StringType, StructType} import io.delta.kernel.utils.CloseableIterator - import org.scalatest.funsuite.AnyFunSuite /** @@ -162,7 +160,7 @@ class ChecksumWriterSuite extends AnyFunSuite with MockEngineUtils { new Format("parquet", Collections.emptyMap()), "schemaString", new StructType(), - stringArrayValue(util.Arrays.asList("c3")), + buildArrayValue(util.Arrays.asList("c3"), StringType.STRING), Optional.of(123), stringStringMapValue(new util.HashMap[String, String]() { put("delta.appendOnly", "true") From b0c8c29ab6b8b8b9e50823b8e4ebd56f42de633c Mon Sep 17 00:00:00 2001 From: Xin Huang Date: Fri, 28 Feb 2025 11:02:40 -0800 Subject: [PATCH 17/17] fix scala fmt --- .../checksum/ChecksumWriterSuite.scala | 2 + .../internal/util/VectorUtilsSuite.scala | 229 +++++++----------- 2 files changed, 91 insertions(+), 140 deletions(-) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala index a759fdf0a6a..8ac6ded9e70 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/checksum/ChecksumWriterSuite.scala @@ -17,6 +17,7 @@ package io.delta.kernel.internal.checksum import java.util import java.util.{Collections, Optional} + import io.delta.kernel.data.Row import io.delta.kernel.internal.actions.{Format, Metadata, Protocol} import io.delta.kernel.internal.checksum.CRCInfo.CRC_FILE_SCHEMA @@ -27,6 +28,7 @@ import io.delta.kernel.internal.util.VectorUtils.{buildArrayValue, stringStringM import io.delta.kernel.test.{BaseMockJsonHandler, MockEngineUtils} import io.delta.kernel.types.{StringType, StructType} import io.delta.kernel.utils.CloseableIterator + import org.scalatest.funsuite.AnyFunSuite /** diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala index 17e0b66e20c..8944fe48f3f 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -16,45 +16,21 @@ package io.delta.kernel.internal.util +import java.lang.{Boolean => BooleanJ, Byte => ByteJ, Double => DoubleJ, Float => FloatJ, Integer => IntegerJ, Long => LongJ, Short => ShortJ} +import java.math.BigDecimal +import java.sql.{Date, Timestamp} +import java.util + +import scala.collection.JavaConverters._ + import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} import io.delta.kernel.internal.data.GenericRow - -import java.sql.{Date, Timestamp} import io.delta.kernel.test.VectorTestUtils -import io.delta.kernel.types.{ - ArrayType, - BinaryType, - BooleanType, - ByteType, - DateType, - DecimalType, - DoubleType, - FloatType, - IntegerType, - LongType, - MapType, - ShortType, - StringType, - StructType, - TimestampType -} +import io.delta.kernel.types.{ArrayType, BinaryType, BooleanType, ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampType} -import java.lang.{ - Boolean => BooleanJ, - Byte => ByteJ, - Double => DoubleJ, - Float => FloatJ, - Integer => IntegerJ, - Long => LongJ, - Short => ShortJ -} -import scala.collection.JavaConverters._ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.prop.Tables.Table -import java.math.BigDecimal -import java.util - class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { Table( @@ -69,89 +45,79 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { (List[BooleanJ](true, false, false, null), BooleanType.BOOLEAN), ( List[BigDecimal](new BigDecimal("1"), new BigDecimal("2"), new BigDecimal("3"), null), - new DecimalType(10, 2) - ), + new DecimalType(10, 2)), (List[String]("one", "two", "three", null), StringType.STRING), ( List[IntegerJ](10, 20, 30, null), - DateType.DATE - ), + DateType.DATE), ( List[LongJ]( Timestamp.valueOf("2023-01-01 00:00:00").getTime, Timestamp.valueOf("2023-01-02 00:00:00").getTime, Timestamp.valueOf("2023-01-03 00:00:00").getTime, - null - ), - TimestampType.TIMESTAMP - ) - ).foreach( - testCase => - test(s"handle ${testCase._2} array correctly") { - val values = testCase._1 - val dataType = testCase._2 - val columnVector = VectorUtils.buildColumnVector(values.asJava, dataType) - assert(columnVector.getSize == 4) - - dataType match { - case ByteType.BYTE => - assert(columnVector.getByte(0) == 1.toByte) - assert(columnVector.getByte(1) == 2.toByte) - assert(columnVector.getByte(2) == 3.toByte) - case ShortType.SHORT => - assert(columnVector.getShort(0) == 1.toShort) - assert(columnVector.getShort(1) == 2.toShort) - assert(columnVector.getShort(2) == 3.toShort) - case IntegerType.INTEGER => - assert(columnVector.getInt(0) == 1) - assert(columnVector.getInt(1) == 2) - assert(columnVector.getInt(2) == 3) - case LongType.LONG => - assert(columnVector.getLong(0) == 1L) - assert(columnVector.getLong(1) == 2L) - assert(columnVector.getLong(2) == 3L) - case FloatType.FLOAT => - assert(columnVector.getFloat(0) == 1.0f) - assert(columnVector.getFloat(1) == 2.0f) - assert(columnVector.getFloat(2) == 3.0f) - case DoubleType.DOUBLE => - assert(columnVector.getDouble(0) == 1.0) - assert(columnVector.getDouble(1) == 2.0) - assert(columnVector.getDouble(2) == 3.0) - case BooleanType.BOOLEAN => - assert(columnVector.getBoolean(0)) - assert(!columnVector.getBoolean(1)) - assert(!columnVector.getBoolean(2)) - case _: DecimalType => - assert(columnVector.getDecimal(0) == new BigDecimal("1")) - assert(columnVector.getDecimal(1) == new BigDecimal("2")) - assert(columnVector.getDecimal(2) == new BigDecimal("3")) - case BinaryType.BINARY => - assert(columnVector.getBinary(0) sameElements "one".getBytes) - assert(columnVector.getBinary(1) sameElements "two".getBytes) - assert(columnVector.getBinary(2) sameElements "three".getBytes) - case StringType.STRING => - assert(columnVector.getString(0) == "one") - assert(columnVector.getString(1) == "two") - assert(columnVector.getString(2) == "three") - case DateType.DATE => - assert(columnVector.getInt(0) == 10) - assert(columnVector.getInt(1) == 20) - assert(columnVector.getInt(2) == 30) - case TimestampType.TIMESTAMP => - assert( - columnVector.getLong(0) == Timestamp.valueOf("2023-01-01 00:00:00").getTime - ) - assert( - columnVector.getLong(1) == Timestamp.valueOf("2023-01-02 00:00:00").getTime - ) - assert( - columnVector.getLong(2) == Timestamp.valueOf("2023-01-03 00:00:00").getTime - ) - } - assert(columnVector.isNullAt(3)) + null), + TimestampType.TIMESTAMP)).foreach(testCase => + test(s"handle ${testCase._2} array correctly") { + val values = testCase._1 + val dataType = testCase._2 + val columnVector = VectorUtils.buildColumnVector(values.asJava, dataType) + assert(columnVector.getSize == 4) + + dataType match { + case ByteType.BYTE => + assert(columnVector.getByte(0) == 1.toByte) + assert(columnVector.getByte(1) == 2.toByte) + assert(columnVector.getByte(2) == 3.toByte) + case ShortType.SHORT => + assert(columnVector.getShort(0) == 1.toShort) + assert(columnVector.getShort(1) == 2.toShort) + assert(columnVector.getShort(2) == 3.toShort) + case IntegerType.INTEGER => + assert(columnVector.getInt(0) == 1) + assert(columnVector.getInt(1) == 2) + assert(columnVector.getInt(2) == 3) + case LongType.LONG => + assert(columnVector.getLong(0) == 1L) + assert(columnVector.getLong(1) == 2L) + assert(columnVector.getLong(2) == 3L) + case FloatType.FLOAT => + assert(columnVector.getFloat(0) == 1.0f) + assert(columnVector.getFloat(1) == 2.0f) + assert(columnVector.getFloat(2) == 3.0f) + case DoubleType.DOUBLE => + assert(columnVector.getDouble(0) == 1.0) + assert(columnVector.getDouble(1) == 2.0) + assert(columnVector.getDouble(2) == 3.0) + case BooleanType.BOOLEAN => + assert(columnVector.getBoolean(0)) + assert(!columnVector.getBoolean(1)) + assert(!columnVector.getBoolean(2)) + case _: DecimalType => + assert(columnVector.getDecimal(0) == new BigDecimal("1")) + assert(columnVector.getDecimal(1) == new BigDecimal("2")) + assert(columnVector.getDecimal(2) == new BigDecimal("3")) + case BinaryType.BINARY => + assert(columnVector.getBinary(0) sameElements "one".getBytes) + assert(columnVector.getBinary(1) sameElements "two".getBytes) + assert(columnVector.getBinary(2) sameElements "three".getBytes) + case StringType.STRING => + assert(columnVector.getString(0) == "one") + assert(columnVector.getString(1) == "two") + assert(columnVector.getString(2) == "three") + case DateType.DATE => + assert(columnVector.getInt(0) == 10) + assert(columnVector.getInt(1) == 20) + assert(columnVector.getInt(2) == 30) + case TimestampType.TIMESTAMP => + assert( + columnVector.getLong(0) == Timestamp.valueOf("2023-01-01 00:00:00").getTime) + assert( + columnVector.getLong(1) == Timestamp.valueOf("2023-01-02 00:00:00").getTime) + assert( + columnVector.getLong(2) == Timestamp.valueOf("2023-01-03 00:00:00").getTime) } - ) + assert(columnVector.isNullAt(3)) + }) test(s"handle array of struct correctly") { val structType = @@ -172,33 +138,26 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { override def getElements: ColumnVector = VectorUtils.buildColumnVector( List[Row]( row("a1", 1), - row("a2", 2) - ).asJava, - structType - ) + row("a2", 2)).asJava, + structType) }, new ArrayValue { override def getSize: Int = 2 override def getElements: ColumnVector = VectorUtils.buildColumnVector( List[Row]( row("b1", 3), - row("b2", 4) - ).asJava, - structType - ) + row("b2", 4)).asJava, + structType) }, new ArrayValue { override def getSize: Int = 2 override def getElements: ColumnVector = VectorUtils.buildColumnVector( List[Row]( row("c1", 5), - row("c2", 6) - ).asJava, - structType - ) + row("c2", 6)).asJava, + structType) }, - null - ) + null) val columnVector = VectorUtils.buildColumnVector(values.asJava, arrayType) @@ -267,10 +226,8 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { VectorUtils.buildColumnVector(List("a3", "a4").asJava, StringType.STRING) override def getValues: ColumnVector = VectorUtils.buildColumnVector(List[IntegerJ](3, 4).asJava, IntegerType.INTEGER) - } - ).asJava, - mapType - ) + }).asJava, + mapType) }, new ArrayValue { override def getSize: Int = 2 @@ -289,13 +246,10 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { VectorUtils.buildColumnVector(List("b3", "b4").asJava, StringType.STRING) override def getValues: ColumnVector = VectorUtils.buildColumnVector(List[IntegerJ](7, 8).asJava, IntegerType.INTEGER) - } - ).asJava, - mapType - ) + }).asJava, + mapType) }, - null - ) + null) val columnVector = VectorUtils.buildColumnVector(values.asJava, arrayType) @@ -358,10 +312,8 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { override def getSize: Int = 2 override def getElements: ColumnVector = VectorUtils.buildColumnVector(List[IntegerJ](3, 4).asJava, IntegerType.INTEGER) - } - ).asJava, - innerArrayType - ) + }).asJava, + innerArrayType) }, new ArrayValue { override def getSize: Int = 2 @@ -376,13 +328,10 @@ class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { override def getSize: Int = 2 override def getElements: ColumnVector = VectorUtils.buildColumnVector(List[IntegerJ](7, 8).asJava, IntegerType.INTEGER) - } - ).asJava, - innerArrayType - ) + }).asJava, + innerArrayType) }, - null - ) + null) val columnVector = VectorUtils.buildColumnVector(values.asJava, outerArrayType)