Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel]Generalize the stringArrayValue and stringVector in VectorUtils #4161

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import static io.delta.kernel.internal.util.ColumnMapping.isColumnMappingModeEnabled;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;
import static io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames;
import static io.delta.kernel.internal.util.VectorUtils.stringArrayValue;
import static io.delta.kernel.internal.util.VectorUtils.buildArrayValue;
import static io.delta.kernel.internal.util.VectorUtils.stringStringMapValue;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;
Expand All @@ -44,6 +44,7 @@
import io.delta.kernel.internal.util.ColumnMapping.ColumnMappingMode;
import io.delta.kernel.internal.util.SchemaUtils;
import io.delta.kernel.internal.util.Tuple2;
import io.delta.kernel.types.StringType;
import io.delta.kernel.types.StructType;
import java.util.*;
import org.slf4j.Logger;
Expand Down Expand Up @@ -309,7 +310,7 @@ private Metadata getInitialMetadata() {
new Format(), /* format */
schema.get().toJson(), /* schemaString */
schema.get(), /* schema */
stringArrayValue(partitionColumnsCasePreserving), /* partitionColumns */
buildArrayValue(partitionColumnsCasePreserving, StringType.STRING), /* partitionColumns */
Optional.of(currentTimeMillis), /* createdTime */
stringStringMapValue(Collections.emptyMap()) /* configuration */);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import static io.delta.kernel.internal.tablefeatures.TableFeatures.TABLE_FEATURES;
import static io.delta.kernel.internal.tablefeatures.TableFeatures.TABLE_FEATURES_MIN_WRITER_VERSION;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;
import static io.delta.kernel.internal.util.VectorUtils.stringArrayValue;
import static io.delta.kernel.internal.util.VectorUtils.buildArrayValue;
import static java.lang.String.format;
import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet;
Expand Down Expand Up @@ -151,10 +151,10 @@ public Row toRow() {
protocolMap.put(0, minReaderVersion);
protocolMap.put(1, minWriterVersion);
if (supportsReaderFeatures) {
protocolMap.put(2, stringArrayValue(new ArrayList<>(readerFeatures)));
protocolMap.put(2, buildArrayValue(new ArrayList<>(readerFeatures), StringType.STRING));
}
if (supportsWriterFeatures) {
protocolMap.put(3, stringArrayValue(new ArrayList<>(writerFeatures)));
protocolMap.put(3, buildArrayValue(new ArrayList<>(writerFeatures), StringType.STRING));
}

return new GenericRow(Protocol.FULL_SCHEMA, protocolMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import io.delta.kernel.data.Row;
import io.delta.kernel.internal.data.StructRow;
import io.delta.kernel.types.*;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public final class VectorUtils {

Expand Down Expand Up @@ -70,31 +72,6 @@ public static <K, V> Map<K, V> toJavaMap(MapValue mapValue) {
return values;
}

/**
 * Wraps a list of strings as an {@link ArrayValue} whose elements are exposed through {@link
 * #stringVector(List)}.
 *
 * <p>The list is not copied: the size and the element vector are both computed from {@code
 * values} each time they are requested.
 *
 * @param values list of strings, may be {@code null}
 * @return an {@link ArrayValue} backed by {@code values}, or {@code null} when {@code values} is
 *     {@code null}
 */
public static ArrayValue stringArrayValue(List<String> values) {
  if (values != null) {
    return new ArrayValue() {
      @Override
      public ColumnVector getElements() {
        return stringVector(values);
      }

      @Override
      public int getSize() {
        return values.size();
      }
    };
  }
  return null;
}

/**
* Creates a {@link MapValue} from map of string keys and string values. The type {@code
* map(string -> string)} is a common occurrence in Delta Log schema.
Expand All @@ -117,27 +94,46 @@ public int getSize() {

/** Returns the map keys as a string-typed {@link ColumnVector}. */
@Override
public ColumnVector getKeys() {
  // Single return through the generalized builder; the earlier stringVector(keys) return made
  // this statement unreachable.
  return buildColumnVector(keys, StringType.STRING);
}

/** Returns the map values as a string-typed {@link ColumnVector}. */
@Override
public ColumnVector getValues() {
  // Single return through the generalized builder; the earlier stringVector(values) return made
  // this statement unreachable.
  return buildColumnVector(values, StringType.STRING);
}
};
}

/**
 * Creates an {@link ArrayValue} view over a list of objects.
 *
 * <p>The list is not copied: the size is read from {@code values} on every call and the element
 * vector is built on demand via {@link #buildColumnVector(List, DataType)}.
 *
 * @param values the array elements, may be {@code null}
 * @param dataType the data type passed to the element vector
 * @return an {@link ArrayValue} backed by {@code values}, or {@code null} when {@code values} is
 *     {@code null}
 */
public static ArrayValue buildArrayValue(List<?> values, DataType dataType) {
  return values == null
      ? null
      : new ArrayValue() {
        @Override
        public ColumnVector getElements() {
          return buildColumnVector(values, dataType);
        }

        @Override
        public int getSize() {
          return values.size();
        }
      };
}

/**
 * Utility method to create a {@link ColumnVector} for the given list of objects. Each object
 * should be of a primitive type or a {@link Row} instance.
 *
 * @param values list of values backing the vector
 * @param dataType data type of the given values
 * @return a {@link ColumnVector} with the given values and data type.
 */
public static ColumnVector stringVector(List<String> values) {
public static ColumnVector buildColumnVector(List<?> values, DataType dataType) {
return new ColumnVector() {
/** Returns the data type this vector was built with. */
@Override
public DataType getDataType() {
  // Report the caller-supplied type; a hard-coded StringType.STRING return ahead of this line
  // would shadow it and make every vector claim to be a string vector.
  return dataType;
}

@Override
Expand All @@ -152,14 +148,172 @@ public void close() {

/**
 * Tells whether the element at {@code rowId} is {@code null}.
 *
 * @param rowId index into the backing list; range-checked by {@code validateRowId}
 * @return {@code true} when the backing list holds {@code null} at {@code rowId}
 */
@Override
public boolean isNullAt(int rowId) {
  // One range check through the shared helper is enough; the inline checkArgument duplicated it.
  validateRowId(rowId);
  return values.get(rowId) == null;
}

/**
 * Reads the element at {@code rowId} as a boolean. The vector's data type must equal {@code
 * BooleanType.BOOLEAN} and the stored element must be a {@link Boolean}.
 */
@Override
public boolean getBoolean(int rowId) {
  checkArgument(BooleanType.BOOLEAN.equals(dataType));
  final Object element = getValidatedValue(rowId, Boolean.class);
  return (Boolean) element;
}

/**
 * Reads the element at {@code rowId} as a byte. The vector's data type must equal {@code
 * ByteType.BYTE} and the stored element must be a {@link Byte}.
 */
@Override
public byte getByte(int rowId) {
  checkArgument(ByteType.BYTE.equals(dataType));
  final Object element = getValidatedValue(rowId, Byte.class);
  return (Byte) element;
}

/**
 * Reads the element at {@code rowId} as a short. The vector's data type must equal {@code
 * ShortType.SHORT} and the stored element must be a {@link Short}.
 */
@Override
public short getShort(int rowId) {
  checkArgument(ShortType.SHORT.equals(dataType));
  final Object element = getValidatedValue(rowId, Short.class);
  return (Short) element;
}

/**
 * Reads the element at {@code rowId} as an int. Accepted when the vector's data type equals
 * {@code IntegerType.INTEGER} or {@code DateType.DATE}; the stored element must be an {@link
 * Integer}.
 */
@Override
public int getInt(int rowId) {
  final boolean intBacked =
      IntegerType.INTEGER.equals(dataType) || DateType.DATE.equals(dataType);
  checkArgument(intBacked);
  final Object element = getValidatedValue(rowId, Integer.class);
  return (Integer) element;
}

/**
 * Reads the element at {@code rowId} as a long. Accepted when the vector's data type equals
 * {@code LongType.LONG} or {@code TimestampType.TIMESTAMP}; the stored element must be a {@link
 * Long}.
 */
@Override
public long getLong(int rowId) {
  final boolean longBacked =
      LongType.LONG.equals(dataType) || TimestampType.TIMESTAMP.equals(dataType);
  checkArgument(longBacked);
  final Object element = getValidatedValue(rowId, Long.class);
  return (Long) element;
}

/**
 * Reads the element at {@code rowId} as a float. The vector's data type must equal {@code
 * FloatType.FLOAT} and the stored element must be a {@link Float}.
 */
@Override
public float getFloat(int rowId) {
  checkArgument(FloatType.FLOAT.equals(dataType));
  final Object element = getValidatedValue(rowId, Float.class);
  return (Float) element;
}

/**
 * Reads the element at {@code rowId} as a double. The vector's data type must equal {@code
 * DoubleType.DOUBLE} and the stored element must be a {@link Double}.
 */
@Override
public double getDouble(int rowId) {
  checkArgument(DoubleType.DOUBLE.equals(dataType));
  final Object element = getValidatedValue(rowId, Double.class);
  return (Double) element;
}

/**
 * Reads the element at {@code rowId} as a {@link BigDecimal}. The vector's data type must be a
 * {@code DecimalType} and the stored element must be a {@link BigDecimal}.
 */
@Override
public BigDecimal getDecimal(int rowId) {
  final boolean decimalTyped = dataType instanceof DecimalType;
  checkArgument(decimalTyped);
  final Object element = getValidatedValue(rowId, BigDecimal.class);
  return (BigDecimal) element;
}

/**
 * Reads the element at {@code rowId} as a {@link String}. The vector's data type must equal
 * {@code StringType.STRING} and the stored element must be a {@link String}.
 */
@Override
public String getString(int rowId) {
  checkArgument(StringType.STRING.equals(dataType));
  final Object element = getValidatedValue(rowId, String.class);
  return (String) element;
}

/**
 * Reads the element at {@code rowId} as a byte array. The vector's data type must equal {@code
 * BinaryType.BINARY} and the stored element must be a {@code byte[]}; the stored array itself is
 * returned, not a copy.
 */
@Override
public byte[] getBinary(int rowId) {
  checkArgument(BinaryType.BINARY.equals(dataType));
  final Object element = getValidatedValue(rowId, byte[].class);
  return (byte[]) element;
}

/**
 * Reads the element at {@code rowId} as an {@link ArrayValue}. The vector's data type must be an
 * {@code ArrayType} and the stored element must already be an {@link ArrayValue}.
 */
@Override
public ArrayValue getArray(int rowId) {
  final boolean arrayTyped = dataType instanceof ArrayType;
  checkArgument(arrayTyped);
  final Object element = getValidatedValue(rowId, ArrayValue.class);
  return (ArrayValue) element;
}

/**
 * Reads the element at {@code rowId} as a {@link MapValue}. The vector's data type must be a
 * {@code MapType} and the stored element must already be a {@link MapValue}.
 */
@Override
public MapValue getMap(int rowId) {
  final boolean mapTyped = dataType instanceof MapType;
  checkArgument(mapTyped);
  final Object element = getValidatedValue(rowId, MapValue.class);
  return (MapValue) element;
}

/**
 * Builds a vector over the field at {@code ordinal} of every struct element in this vector.
 *
 * <p>Only valid when this vector's data type is a {@link StructType}. The child values are
 * extracted from the backing elements on every call, and the resulting vector takes the data
 * type of the struct field at {@code ordinal}.
 *
 * @param ordinal zero-based position of the field within the struct type
 * @return a new {@link ColumnVector} holding the field's value for each element
 */
@Override
public ColumnVector getChild(int ordinal) {
  checkArgument(dataType instanceof StructType);
  StructType structType = (StructType) dataType;
  // Check both bounds, mirroring the row id validation, so a negative ordinal is rejected by the
  // precondition with a descriptive message instead of reaching StructType#at.
  checkArgument(ordinal >= 0 && ordinal < structType.length(), "Invalid ordinal: %s", ordinal);

  DataType childDatatype = structType.at(ordinal).getDataType();
  List<?> childValues = extractChildValues(ordinal, childDatatype);

  return buildColumnVector(childValues, childDatatype);
}

/**
 * Checks that {@code rowId} is a valid index into the backing list, i.e. {@code 0 <= rowId <
 * values.size()}.
 *
 * @param rowId the row index to validate
 */
private void validateRowId(int rowId) {
  // This method is void: it only validates. Returning values.get(rowId) here does not compile.
  checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
}

/**
 * Fetches the element at {@code rowId} after checking the index and the element's runtime type.
 *
 * <p>A {@code null} element is never an instance of {@code expectedType}, so it fails the type
 * precondition as well; callers that may hold nulls should consult {@code isNullAt} first.
 *
 * @param rowId index into the backing list
 * @param expectedType class the stored element must be an instance of
 * @return the stored element, guaranteed to be an instance of {@code expectedType}
 */
private Object getValidatedValue(int rowId, Class<?> expectedType) {
  validateRowId(rowId);
  final Object candidate = values.get(rowId);
  final boolean typeMatches = expectedType.isInstance(candidate);
  checkArgument(typeMatches, "Value must be of type %s", expectedType.getSimpleName());
  return candidate;
}

/**
 * Collects, in element order, the value of the struct field at {@code ordinal} from every
 * element of the backing list.
 *
 * @param ordinal position of the field within each struct element
 * @param childDatatype data type of that field, used to pick the typed getter
 * @return a new list with one (possibly {@code null}) entry per element
 */
private List<?> extractChildValues(int ordinal, DataType childDatatype) {
  final List<Object> fieldValues = new ArrayList<>();
  for (Object element : values) {
    fieldValues.add(extractChildValue(element, ordinal, childDatatype));
  }
  return fieldValues;
}

/**
 * Extracts the value of the field at {@code ordinal} from a single struct element.
 *
 * @param element one entry of the backing list; expected to be a {@link Row} or {@code null}
 * @param ordinal position of the field within the row
 * @param childDatatype data type of the field, used to pick the typed getter
 * @return the field value, or {@code null} when the element itself is {@code null} or the row
 *     reports the field as null
 */
private Object extractChildValue(Object element, int ordinal, DataType childDatatype) {
  // A null struct element is legal (isNullAt reports it as null); its children are null too.
  // Without this guard the instanceof precondition below would reject it.
  if (element == null) {
    return null;
  }
  checkArgument(element instanceof Row);
  Row row = (Row) element;

  if (row.isNullAt(ordinal)) {
    return null;
  }

  return extractTypedValue(row, ordinal, childDatatype);
}

/**
 * Reads the field at {@code ordinal} from {@code row} using the getter that matches {@code
 * childDatatype}. Primitive results are boxed because the method returns {@link Object}.
 *
 * @param row the row to read from
 * @param ordinal position of the field within the row
 * @param childDatatype data type of the field; decides which {@link Row} getter is called
 * @return the field value as returned by the matching getter
 * @throws UnsupportedOperationException when {@code childDatatype} is none of the handled types
 */
private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) {
  final Object extracted;

  if (childDatatype instanceof BooleanType) {
    // Fixed-width primitive types (date and timestamp are read as int and long respectively).
    extracted = row.getBoolean(ordinal);
  } else if (childDatatype instanceof ByteType) {
    extracted = row.getByte(ordinal);
  } else if (childDatatype instanceof ShortType) {
    extracted = row.getShort(ordinal);
  } else if (childDatatype instanceof IntegerType || childDatatype instanceof DateType) {
    extracted = row.getInt(ordinal);
  } else if (childDatatype instanceof LongType || childDatatype instanceof TimestampType) {
    extracted = row.getLong(ordinal);
  } else if (childDatatype instanceof FloatType) {
    extracted = row.getFloat(ordinal);
  } else if (childDatatype instanceof DoubleType) {
    extracted = row.getDouble(ordinal);
  } else if (childDatatype instanceof StringType) {
    // Object-valued scalar types.
    extracted = row.getString(ordinal);
  } else if (childDatatype instanceof BinaryType) {
    extracted = row.getBinary(ordinal);
  } else if (childDatatype instanceof DecimalType) {
    extracted = row.getDecimal(ordinal);
  } else if (childDatatype instanceof StructType) {
    // Nested types.
    extracted = row.getStruct(ordinal);
  } else if (childDatatype instanceof ArrayType) {
    extracted = row.getArray(ordinal);
  } else if (childDatatype instanceof MapType) {
    extracted = row.getMap(ordinal);
  } else {
    throw new UnsupportedOperationException(
        String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName()));
  }

  return extracted;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object TransactionSuite extends VectorTestUtils with MockEngineUtils {
new Format(),
DataTypeJsonSerDe.serializeDataType(schema),
schema,
VectorUtils.stringArrayValue(partitionCols.asJava), // partitionColumns
VectorUtils.buildArrayValue(partitionCols.asJava, StringType.STRING), // partitionColumns
Optional.empty(), // createdTime
stringStringMapValue(configurationMap.asJava) // configurationMap
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import io.delta.kernel.internal.checksum.CRCInfo.CRC_FILE_SCHEMA
import io.delta.kernel.internal.data.GenericRow
import io.delta.kernel.internal.fs.Path
import io.delta.kernel.internal.util.VectorUtils
import io.delta.kernel.internal.util.VectorUtils.{stringArrayValue, stringStringMapValue}
import io.delta.kernel.internal.util.VectorUtils.{buildArrayValue, stringStringMapValue}
import io.delta.kernel.test.{BaseMockJsonHandler, MockEngineUtils}
import io.delta.kernel.types.StructType
import io.delta.kernel.types.{StringType, StructType}
import io.delta.kernel.utils.CloseableIterator

import org.scalatest.funsuite.AnyFunSuite
Expand Down Expand Up @@ -162,7 +162,7 @@ class ChecksumWriterSuite extends AnyFunSuite with MockEngineUtils {
new Format("parquet", Collections.emptyMap()),
"schemaString",
new StructType(),
stringArrayValue(util.Arrays.asList("c3")),
buildArrayValue(util.Arrays.asList("c3"), StringType.STRING),
Optional.of(123),
stringStringMapValue(new util.HashMap[String, String]() {
put("delta.appendOnly", "true")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class ColumnMappingSuite extends AnyFunSuite {
new Format,
schema.toJson,
schema,
VectorUtils.stringArrayValue(Collections.emptyList()),
VectorUtils.buildArrayValue(Collections.emptyList(), StringType.STRING),
Optional.empty(),
VectorUtils.stringStringMapValue(Collections.emptyMap()))

Expand Down
Loading
Loading