Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Optimize not to call getNullCount as much as possible #820

Closed
wants to merge 9 commits into from
2 changes: 1 addition & 1 deletion .github/actions/java-test/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ runs:
overwrite: true

- name: Upload coverage results
if: ${{ inputs.upload-test-reports == 'true' }}
if: ${{ inputs.upload-test-reports == 'true' && ( runner.os != 'macOS' || runner.arch != 'X64' || inputs.maven_opts != '-Pspark-4.0' ) }}
uses: codecov/codecov-action@v3 # uses v3 as it allows tokenless uploading
128 changes: 128 additions & 0 deletions common/src/main/java/org/apache/arrow/c/CometArrayExporter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.arrow.c;

import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.c.jni.JniWrapper;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

import static org.apache.arrow.c.Data.exportField;
import static org.apache.arrow.c.NativeUtil.NULL;
import static org.apache.arrow.c.NativeUtil.addressOrNull;
import static org.apache.arrow.util.Preconditions.checkNotNull;

/**
 * Exports Arrow {@link FieldVector}s through the Arrow C Data Interface, taking the null count of
 * the top-level vector (and of its dictionary values) from the caller instead of computing it.
 *
 * <p>Child vectors still have their null counts computed with {@code getNullCount()}.
 */
public final class CometArrayExporter {
  /**
   * Exports {@code vector} into {@code out} and its field into {@code outSchema}.
   *
   * <p>Copied from Data.exportVector and changed to take nullCount from outside.
   *
   * @param allocator allocator used for the pointer buffers and the child/dictionary structs
   * @param vector vector to export
   * @param provider dictionary provider used to resolve dictionary-encoded fields
   * @param out C array struct that receives the exported vector
   * @param outSchema C schema struct that receives the exported field
   * @param nullCount value written to {@code null_count} of the exported top-level array
   * @param dictValNullCount value written to {@code null_count} of the exported dictionary array;
   *     only used when the vector's field is dictionary-encoded
   */
  public static void exportVector(
      BufferAllocator allocator,
      FieldVector vector,
      DictionaryProvider provider,
      ArrowArray out,
      ArrowSchema outSchema,
      long nullCount,
      long dictValNullCount) {
    exportField(allocator, vector.getField(), provider, outSchema);
    export(allocator, out, vector, provider, nullCount, dictValNullCount);
  }

  /**
   * Looks up the dictionary for {@code encoding} and fails with a descriptive message when the
   * provider does not know it.
   */
  private static Dictionary lookupDictionary(
      DictionaryProvider dictionaryProvider, DictionaryEncoding encoding) {
    Dictionary dictionary = dictionaryProvider.lookup(encoding.getId());
    checkNotNull(dictionary, "Dictionary lookup failed on export of dictionary encoded array");
    return dictionary;
  }

  /**
   * Fills {@code array} from {@code vector}, then recursively exports the vector's dictionary (if
   * any) and its children.
   *
   * @param nullCount value stored as the array's {@code null_count}
   * @param dictValNullCount value stored as the dictionary array's {@code null_count}; the
   *     dictionary's own dictionary, if any, is exported with 0
   */
  private static void export(
      BufferAllocator allocator,
      ArrowArray array,
      FieldVector vector,
      DictionaryProvider dictionaryProvider,
      long nullCount,
      long dictValNullCount) {
    List<FieldVector> children = vector.getChildrenFromFields();
    List<ArrowBuf> buffers = vector.getFieldBuffers();
    int valueCount = vector.getValueCount();
    DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();

    ArrayExporter.ExportedArrayPrivateData data = new ArrayExporter.ExportedArrayPrivateData();
    try {
      if (children != null) {
        // Allocate one C struct per child and record its address in the children pointer buffer.
        data.children = new ArrayList<>(children.size());
        data.children_ptrs = allocator.buffer((long) children.size() * Long.BYTES);
        for (int i = 0; i < children.size(); i++) {
          ArrowArray child = ArrowArray.allocateNew(allocator);
          data.children.add(child);
          data.children_ptrs.writeLong(child.memoryAddress());
        }
      }

      if (buffers != null) {
        data.buffers = new ArrayList<>(buffers.size());
        data.buffers_ptrs = allocator.buffer((long) buffers.size() * Long.BYTES);
        vector.exportCDataBuffers(data.buffers, data.buffers_ptrs, NULL);
      }

      if (dictionaryEncoding != null) {
        Dictionary dictionary = lookupDictionary(dictionaryProvider, dictionaryEncoding);

        data.dictionary = ArrowArray.allocateNew(allocator);
        FieldVector dictionaryVector = dictionary.getVector();
        export(
            allocator, data.dictionary, dictionaryVector, dictionaryProvider, dictValNullCount, 0);
      }

      ArrowArray.Snapshot snapshot = new ArrowArray.Snapshot();
      snapshot.length = valueCount;
      // The caller-supplied count is used as-is; the vector is not scanned here.
      snapshot.null_count = nullCount;
      snapshot.offset = 0;
      snapshot.n_buffers = (data.buffers != null) ? data.buffers.size() : 0;
      snapshot.n_children = (data.children != null) ? data.children.size() : 0;
      snapshot.buffers = addressOrNull(data.buffers_ptrs);
      snapshot.children = addressOrNull(data.children_ptrs);
      snapshot.dictionary = addressOrNull(data.dictionary);
      snapshot.release = NULL;
      array.save(snapshot);

      // sets release and private data
      JniWrapper.get().exportArray(array.memoryAddress(), data);
    } catch (Exception e) {
      // Close the private data before propagating the failure.
      data.close();
      throw e;
    }

    // Export children
    if (children != null) {
      for (int i = 0; i < children.size(); i++) {
        FieldVector childVector = children.get(i);
        ArrowArray child = data.children.get(i);
        // TODO: getNullCount is slow, avoid calling it if possible
        long cNullCount = childVector.getNullCount();
        DictionaryEncoding cDictionaryEncoding = childVector.getField().getDictionary();
        long cDictValNullCount = 0;
        if (cDictionaryEncoding != null) {
          // Checked lookup: a missing child dictionary must fail with the same descriptive
          // message as the top-level path rather than a bare NullPointerException.
          Dictionary dictionary = lookupDictionary(dictionaryProvider, cDictionaryEncoding);
          cDictValNullCount = dictionary.getVector().getNullCount();
        }
        export(allocator, child, childVector, dictionaryProvider, cNullCount, cDictValNullCount);
      }
    }
  }
}
15 changes: 9 additions & 6 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
import org.apache.comet.vector.CometPlainVector;
import org.apache.comet.vector.CometVector;

import static org.apache.comet.parquet.Utils.getDictValNullCount;
import static org.apache.comet.parquet.Utils.getNullCount;

public class ColumnReader extends AbstractColumnReader {
protected static final Logger LOG = LoggerFactory.getLogger(ColumnReader.class);

Expand Down Expand Up @@ -205,11 +208,13 @@ public CometDecodedVector loadVector() {

try (ArrowArray array = ArrowArray.wrap(addresses[0]);
ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
int nullCount = getNullCount(array);
int dictValNullCount = getDictValNullCount(array);
FieldVector vector = importer.importVector(array, schema);

DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();

CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, false, nullCount);

// Update whether the current vector contains any null values. This is used in the following
// batch(s) to determine whether we can skip loading the native vector.
Expand All @@ -233,7 +238,8 @@ public CometDecodedVector loadVector() {
// release the previous dictionary vector and create a new one.
Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId());
CometPlainVector dictionaryVector =
new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid);
new CometPlainVector(
arrowDictionary.getVector(), useDecimal128, isUuid, dictValNullCount);
if (dictionary != null) {
dictionary.setDictionaryVector(dictionaryVector);
} else {
Expand All @@ -243,9 +249,6 @@ public CometDecodedVector loadVector() {
currentVector =
new CometDictionaryVector(
cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid);

currentVector =
new CometDictionaryVector(cometVector, dictionary, importer.getProvider(), useDecimal128);
return currentVector;
}
}
Expand All @@ -255,7 +258,7 @@ protected void readPage() {
if (page == null) {
throw new RuntimeException("overreading: returned DataPage is null");
}
;

int pageValueCount = page.getValueCount();
page.accept(
new DataPage.Visitor<Void>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,14 @@ public CometVector currentBatch() {
}

/** Read all rows up to the `batchSize`. Expects no rows are skipped so far. */
public void readAllBatch() {
public boolean readAllBatch() {
// All rows should be read without any skips so far
assert (lastSkippedRowId == -1);

if (batchSize <= currentNumValues) return false;

readBatch(batchSize - 1, 0);
return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.apache.comet.vector.CometPlainVector;
import org.apache.comet.vector.CometVector;

import static org.apache.comet.parquet.Utils.getNullCount;

/** A metadata column reader that can be extended by {@link RowIndexColumnReader} etc. */
public class MetadataColumnReader extends AbstractColumnReader {
private final BufferAllocator allocator = new RootAllocator();
Expand All @@ -53,8 +55,9 @@ public void readBatch(int total) {
long[] addresses = Native.currentBatch(nativeHandle);
try (ArrowArray array = ArrowArray.wrap(addresses[0]);
ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
int nullCount = getNullCount(array);
FieldVector fieldVector = Data.importVector(allocator, array, schema, null);
vector = new CometPlainVector(fieldVector, useDecimal128);
vector = new CometPlainVector(fieldVector, useDecimal128, false, nullCount);
}
}
vector.setNumValues(total);
Expand Down
20 changes: 20 additions & 0 deletions common/src/main/java/org/apache/comet/parquet/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

package org.apache.comet.parquet;

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;

import static org.apache.arrow.c.NativeUtil.NULL;

public class Utils {
public static ColumnReader getColumnReader(
Expand Down Expand Up @@ -257,4 +261,20 @@ static int getTimeUnitId(LogicalTypeAnnotation.TimeUnit tu) {
throw new UnsupportedOperationException("Unsupported TimeUnit " + tu);
}
}

/**
 * Reads the {@code null_count} field of a C Data Interface array straight from native memory.
 *
 * @param array wrapper whose {@code memoryAddress()} points at the C {@code ArrowArray} struct
 * @return the stored null count, narrowed from 64 bits to {@code int}
 */
public static int getNullCount(ArrowArray array) {
  // null_count is the second 64-bit field of the struct, i.e. one long past its start.
  // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowArray.null_count
  final long nullCountAddress = array.memoryAddress() + Long.BYTES;
  final long nullCount = Platform.getLong(null, nullCountAddress);
  return (int) nullCount;
}

/**
 * Reads the {@code null_count} of the dictionary array attached to a C Data Interface array.
 *
 * @param array wrapper whose {@code memoryAddress()} points at the C {@code ArrowArray} struct
 * @return the dictionary array's null count narrowed to {@code int}, or 0 when the dictionary
 *     pointer is NULL
 */
public static int getDictValNullCount(ArrowArray array) {
  // The dictionary pointer is the 8th 64-bit field of the struct (seven longs past its start).
  // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowArray.dictionary
  final long dictionaryFieldAddress = array.memoryAddress() + 7L * Long.BYTES;
  final long dictionaryAddress = Platform.getLong(null, dictionaryFieldAddress);
  if (dictionaryAddress != NULL) {
    // The dictionary is itself an ArrowArray struct, so its null_count sits one long in.
    final long dictionaryNullCount = Platform.getLong(null, dictionaryAddress + Long.BYTES);
    return (int) dictionaryNullCount;
  }
  return 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,27 @@ public abstract class CometDecodedVector extends CometVector {
private int validityByteCacheIndex = -1;
private byte validityByteCache;
protected boolean isUuid;
private final int dictValNumNulls;

protected CometDecodedVector(ValueVector vector, Field valueField, boolean useDecimal128) {
this(vector, valueField, useDecimal128, false);
// TODO: getNullCount is slow, avoid calling it if possible
this(vector, valueField, useDecimal128, false, vector.getNullCount(), 0);
}

protected CometDecodedVector(
ValueVector vector, Field valueField, boolean useDecimal128, boolean isUuid) {
ValueVector vector,
Field valueField,
boolean useDecimal128,
boolean isUuid,
int nullCount,
int dictValNullCount) {
super(Utils.fromArrowField(valueField), useDecimal128);
this.valueVector = vector;
this.numNulls = valueVector.getNullCount();
this.numNulls = nullCount;
this.numValues = valueVector.getValueCount();
this.hasNull = numNulls != 0;
this.isUuid = isUuid;
this.dictValNumNulls = dictValNullCount;
}

@Override
Expand Down Expand Up @@ -98,6 +106,11 @@ public int numNulls() {
return numNulls;
}

/** Returns the dictionary-values null count stored in this vector's {@code dictValNumNulls}. */
@Override
public int dictValNumNulls() {
  return this.dictValNumNulls;
}

@Override
public boolean isNullAt(int rowId) {
if (!hasNull) return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public int numNulls() {
return delegate.numNulls();
}

/** Forwards to {@code delegate.dictValNumNulls()} and returns its result unchanged. */
@Override
public int dictValNumNulls() {
  final int delegateCount = delegate.dictValNumNulls();
  return delegateCount;
}

@Override
public boolean isNullAt(int rowId) {
return delegate.isNullAt(rowId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ public UTF8String decodeToUTF8String(int index) {
return values.getUTF8String(index);
}

/** Returns the null count reported by {@code values.numNulls()}. */
public int numNulls() {
  final int nullsInValues = values.numNulls();
  return nullsInValues;
}

@Override
public void close() {
values.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ public CometDictionaryVector(
boolean useDecimal128,
boolean isAlias,
boolean isUuid) {
super(indices.valueVector, values.getValueVector().getField(), useDecimal128, isUuid);
super(
indices.valueVector,
values.getValueVector().getField(),
useDecimal128,
isUuid,
indices.numNulls(),
values.numNulls());
Preconditions.checkArgument(
indices.valueVector instanceof IntVector, "'indices' should be a IntVector");
this.values = values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ public CometDecodedVector getDecodedVector() {

@Override
public ValueVector getValueVector() {
columnReader.readAllBatch();
setDelegate(columnReader.loadVector());
if (columnReader.readAllBatch()) {
setDelegate(columnReader.loadVector());
}
return super.getValueVector();
}

Expand All @@ -60,15 +61,17 @@ public void close() {

@Override
public boolean hasNull() {
columnReader.readAllBatch();
setDelegate(columnReader.loadVector());
if (columnReader.readAllBatch()) {
setDelegate(columnReader.loadVector());
}
return super.hasNull();
}

@Override
public int numNulls() {
columnReader.readAllBatch();
setDelegate(columnReader.loadVector());
if (columnReader.readAllBatch()) {
setDelegate(columnReader.loadVector());
}
return super.numNulls();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public class CometListVector extends CometDecodedVector {
final ValueVector dataVector;
final ColumnVector dataColumnVector;

public CometListVector(ValueVector vector, boolean useDecimal128) {
super(vector, vector.getField(), useDecimal128);
public CometListVector(ValueVector vector, boolean useDecimal128, int nullCount) {
super(vector, vector.getField(), false, useDecimal128, nullCount, 0);

this.listVector = ((ListVector) vector);
this.dataVector = listVector.getDataVector();
Expand All @@ -51,7 +51,9 @@ public ColumnarArray getArray(int i) {
public CometVector slice(int offset, int length) {
TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);
ValueVector vector = tp.getTo();

return new CometListVector(tp.getTo(), useDecimal128);
// TODO: getNullCount is slow, avoid calling it if possible
return new CometListVector(vector, useDecimal128, vector.getNullCount());
}
}
Loading
Loading