Skip to content

Commit 7a7e4ed

Browse files
committed
GH-891: Add ExtensionTypeWriterFactory to TransferPair
1 parent aee8a10 commit 7a7e4ed

File tree

6 files changed

+210
-3
lines changed

6 files changed

+210
-3
lines changed

vector/src/main/codegen/templates/ComplexCopier.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension
6363
writer.startList();
6464
while (reader.next()) {
6565
FieldReader childReader = reader.reader();
66-
FieldWriter childWriter = getListWriterForReader(childReader, writer);
66+
FieldWriter childWriter = getListWriterForReader(childReader, writer, extensionTypeWriterFactory);
6767
if (childReader.isSet()) {
6868
writeValue(childReader, childWriter, extensionTypeWriterFactory);
6969
} else {
@@ -189,6 +189,10 @@ private static FieldWriter getStructWriterForReader(FieldReader reader, StructWr
189189
}
190190
191191
private static FieldWriter getListWriterForReader(FieldReader reader, ListWriter writer) {
192+
return getListWriterForReader(reader, writer, null);
193+
}
194+
195+
private static FieldWriter getListWriterForReader(FieldReader reader, ListWriter writer, ExtensionTypeWriterFactory extensionTypeWriterFactory) {
192196
switch (reader.getMinorType()) {
193197
<#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first />
194198
<#assign fields = minor.fields!type.fields />
@@ -209,6 +213,9 @@ private static FieldWriter getListWriterForReader(FieldReader reader, ListWriter
209213
return (FieldWriter) writer.listView();
210214
case EXTENSIONTYPE:
211215
ExtensionWriter extensionWriter = writer.extension(reader.getField().getType());
216+
if (extensionTypeWriterFactory != null) {
217+
extensionWriter.addExtensionTypeWriterFactory(extensionTypeWriterFactory);
218+
}
212219
return (FieldWriter) extensionWriter;
213220
default:
214221
throw new UnsupportedOperationException(reader.getMinorType().toString());

vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import org.apache.arrow.vector.util.OversizedAllocationException;
6464
import org.apache.arrow.vector.util.SchemaChangeRuntimeException;
6565
import org.apache.arrow.vector.util.TransferPair;
66+
import org.apache.arrow.vector.util.TransferPairWithExtendedType;
6667

6768
/**
6869
* A list vector contains lists of a specific type of elements. Its structure contains 3 elements.
@@ -648,7 +649,7 @@ public UnionVector promoteToUnion() {
648649
return vector;
649650
}
650651

651-
private class TransferImpl implements TransferPair {
652+
private class TransferImpl implements TransferPairWithExtendedType {
652653

653654
LargeListVector to;
654655
TransferPair dataTransferPair;
@@ -731,6 +732,12 @@ public ValueVector getTo() {
731732
public void copyValueSafe(int from, int to) {
732733
this.to.copyFrom(from, to, LargeListVector.this);
733734
}
735+
736+
@Override
737+
public void copyValueSafe(
738+
int from, int to, ExtensionTypeWriterFactory extensionTypeWriterFactory) {
739+
this.to.copyFrom(from, to, LargeListVector.this, extensionTypeWriterFactory);
740+
}
734741
}
735742

736743
@Override

vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.apache.arrow.vector.util.JsonStringArrayList;
5757
import org.apache.arrow.vector.util.OversizedAllocationException;
5858
import org.apache.arrow.vector.util.TransferPair;
59+
import org.apache.arrow.vector.util.TransferPairWithExtendedType;
5960

6061
/**
6162
* A list vector contains lists of a specific type of elements. Its structure contains 3 elements.
@@ -528,7 +529,7 @@ public <OUT, IN> OUT accept(VectorVisitor<OUT, IN> visitor, IN value) {
528529
return visitor.visit(this, value);
529530
}
530531

531-
private class TransferImpl implements TransferPair {
532+
private class TransferImpl implements TransferPairWithExtendedType {
532533

533534
ListVector to;
534535
TransferPair dataTransferPair;
@@ -612,6 +613,12 @@ public ValueVector getTo() {
612613
public void copyValueSafe(int from, int to) {
613614
this.to.copyFrom(from, to, ListVector.this);
614615
}
616+
617+
@Override
618+
public void copyValueSafe(
619+
int from, int to, ExtensionTypeWriterFactory extensionTypeWriterFactory) {
620+
this.to.copyFrom(from, to, ListVector.this, extensionTypeWriterFactory);
621+
}
615622
}
616623

617624
@Override
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.arrow.vector.util;
18+
19+
import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory;
20+
21+
public interface TransferPairWithExtendedType extends TransferPair {
22+
void copyValueSafe(int from, int to, ExtensionTypeWriterFactory extensionTypeWriterFactory);
23+
}

vector/src/test/java/org/apache/arrow/vector/TestLargeListVector.java

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,27 @@
2323
import static org.junit.jupiter.api.Assertions.assertSame;
2424
import static org.junit.jupiter.api.Assertions.assertTrue;
2525

26+
import java.nio.ByteBuffer;
2627
import java.util.ArrayList;
2728
import java.util.Arrays;
2829
import java.util.List;
30+
import java.util.UUID;
2931
import org.apache.arrow.memory.ArrowBuf;
3032
import org.apache.arrow.memory.BufferAllocator;
3133
import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
3234
import org.apache.arrow.vector.complex.LargeListVector;
3335
import org.apache.arrow.vector.complex.ListVector;
36+
import org.apache.arrow.vector.complex.impl.UnionLargeListReader;
3437
import org.apache.arrow.vector.complex.impl.UnionLargeListWriter;
38+
import org.apache.arrow.vector.complex.impl.UuidWriterFactory;
3539
import org.apache.arrow.vector.complex.reader.FieldReader;
40+
import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter;
41+
import org.apache.arrow.vector.holder.UuidHolder;
3642
import org.apache.arrow.vector.types.Types.MinorType;
3743
import org.apache.arrow.vector.types.pojo.ArrowType;
3844
import org.apache.arrow.vector.types.pojo.Field;
3945
import org.apache.arrow.vector.types.pojo.FieldType;
46+
import org.apache.arrow.vector.types.pojo.UuidType;
4047
import org.apache.arrow.vector.util.TransferPair;
4148
import org.junit.jupiter.api.AfterEach;
4249
import org.junit.jupiter.api.BeforeEach;
@@ -1021,6 +1028,84 @@ public void testGetTransferPairWithField() throws Exception {
10211028
}
10221029
}
10231030

1031+
@Test
1032+
public void testCopyValueSafeForExtensionType() throws Exception {
1033+
try (LargeListVector inVector = LargeListVector.empty("input", allocator);
1034+
LargeListVector outVector = LargeListVector.empty("output", allocator)) {
1035+
UnionLargeListWriter writer = inVector.getWriter();
1036+
writer.allocate();
1037+
1038+
// Create first list with UUIDs
1039+
writer.setPosition(0);
1040+
UUID u1 = UUID.randomUUID();
1041+
UUID u2 = UUID.randomUUID();
1042+
writer.startList();
1043+
ExtensionWriter extensionWriter = writer.extension(new UuidType());
1044+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
1045+
extensionWriter.writeExtension(u1);
1046+
extensionWriter.writeExtension(u2);
1047+
writer.endList();
1048+
1049+
// Create second list with UUIDs
1050+
writer.setPosition(1);
1051+
UUID u3 = UUID.randomUUID();
1052+
UUID u4 = UUID.randomUUID();
1053+
writer.startList();
1054+
extensionWriter = writer.extension(new UuidType());
1055+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
1056+
extensionWriter.writeExtension(u3);
1057+
extensionWriter.writeExtension(u4);
1058+
extensionWriter.writeNull();
1059+
1060+
writer.endList();
1061+
writer.setValueCount(2);
1062+
1063+
// Use copyFromSafe with ExtensionTypeWriterFactory
1064+
// This internally calls TransferImpl.copyValueSafe with ExtensionTypeWriterFactory
1065+
outVector.allocateNew();
1066+
outVector.copyFromSafe(0, 0, inVector, new UuidWriterFactory());
1067+
outVector.copyFromSafe(1, 1, inVector, new UuidWriterFactory());
1068+
outVector.setValueCount(2);
1069+
1070+
// Verify first list
1071+
UnionLargeListReader reader = outVector.getReader();
1072+
reader.setPosition(0);
1073+
assertTrue(reader.isSet(), "first list shouldn't be null");
1074+
reader.next();
1075+
FieldReader uuidReader = reader.reader();
1076+
UuidHolder holder = new UuidHolder();
1077+
uuidReader.read(holder);
1078+
ByteBuffer bb = ByteBuffer.wrap(holder.value);
1079+
UUID actualUuid = new UUID(bb.getLong(), bb.getLong());
1080+
assertEquals(u1, actualUuid);
1081+
reader.next();
1082+
uuidReader = reader.reader();
1083+
uuidReader.read(holder);
1084+
bb = ByteBuffer.wrap(holder.value);
1085+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1086+
assertEquals(u2, actualUuid);
1087+
1088+
// Verify second list
1089+
reader.setPosition(1);
1090+
assertTrue(reader.isSet(), "second list shouldn't be null");
1091+
reader.next();
1092+
uuidReader = reader.reader();
1093+
uuidReader.read(holder);
1094+
bb = ByteBuffer.wrap(holder.value);
1095+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1096+
assertEquals(u3, actualUuid);
1097+
reader.next();
1098+
uuidReader = reader.reader();
1099+
uuidReader.read(holder);
1100+
bb = ByteBuffer.wrap(holder.value);
1101+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1102+
assertEquals(u4, actualUuid);
1103+
reader.next();
1104+
uuidReader = reader.reader();
1105+
assertFalse(uuidReader.isSet(), "third element should be null");
1106+
}
1107+
}
1108+
10241109
private void writeIntValues(UnionLargeListWriter writer, int[] values) {
10251110
writer.startList();
10261111
for (int v : values) {

vector/src/test/java/org/apache/arrow/vector/TestListVector.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,84 @@ public void testCopyFromForExtensionType() throws Exception {
13141314
}
13151315
}
13161316

1317+
@Test
1318+
public void testCopyValueSafeForExtensionType() throws Exception {
1319+
try (ListVector inVector = ListVector.empty("input", allocator);
1320+
ListVector outVector = ListVector.empty("output", allocator)) {
1321+
UnionListWriter writer = inVector.getWriter();
1322+
writer.allocate();
1323+
1324+
// Create first list with UUIDs
1325+
writer.setPosition(0);
1326+
UUID u1 = UUID.randomUUID();
1327+
UUID u2 = UUID.randomUUID();
1328+
writer.startList();
1329+
ExtensionWriter extensionWriter = writer.extension(new UuidType());
1330+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
1331+
extensionWriter.writeExtension(u1);
1332+
extensionWriter.writeExtension(u2);
1333+
writer.endList();
1334+
1335+
// Create second list with UUIDs
1336+
writer.setPosition(1);
1337+
UUID u3 = UUID.randomUUID();
1338+
UUID u4 = UUID.randomUUID();
1339+
writer.startList();
1340+
extensionWriter = writer.extension(new UuidType());
1341+
extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory());
1342+
extensionWriter.writeExtension(u3);
1343+
extensionWriter.writeExtension(u4);
1344+
extensionWriter.writeNull();
1345+
1346+
writer.endList();
1347+
writer.setValueCount(2);
1348+
1349+
// Use copyFromSafe with ExtensionTypeWriterFactory
1350+
// This internally calls TransferImpl.copyValueSafe with ExtensionTypeWriterFactory
1351+
outVector.allocateNew();
1352+
outVector.copyFromSafe(0, 0, inVector, new UuidWriterFactory());
1353+
outVector.copyFromSafe(1, 1, inVector, new UuidWriterFactory());
1354+
outVector.setValueCount(2);
1355+
1356+
// Verify first list
1357+
UnionListReader reader = outVector.getReader();
1358+
reader.setPosition(0);
1359+
assertTrue(reader.isSet(), "first list shouldn't be null");
1360+
reader.next();
1361+
FieldReader uuidReader = reader.reader();
1362+
UuidHolder holder = new UuidHolder();
1363+
uuidReader.read(holder);
1364+
ByteBuffer bb = ByteBuffer.wrap(holder.value);
1365+
UUID actualUuid = new UUID(bb.getLong(), bb.getLong());
1366+
assertEquals(u1, actualUuid);
1367+
reader.next();
1368+
uuidReader = reader.reader();
1369+
uuidReader.read(holder);
1370+
bb = ByteBuffer.wrap(holder.value);
1371+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1372+
assertEquals(u2, actualUuid);
1373+
1374+
// Verify second list
1375+
reader.setPosition(1);
1376+
assertTrue(reader.isSet(), "second list shouldn't be null");
1377+
reader.next();
1378+
uuidReader = reader.reader();
1379+
uuidReader.read(holder);
1380+
bb = ByteBuffer.wrap(holder.value);
1381+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1382+
assertEquals(u3, actualUuid);
1383+
reader.next();
1384+
uuidReader = reader.reader();
1385+
uuidReader.read(holder);
1386+
bb = ByteBuffer.wrap(holder.value);
1387+
actualUuid = new UUID(bb.getLong(), bb.getLong());
1388+
assertEquals(u4, actualUuid);
1389+
reader.next();
1390+
uuidReader = reader.reader();
1391+
assertFalse(uuidReader.isSet(), "third element should be null");
1392+
}
1393+
}
1394+
13171395
private void writeIntValues(UnionListWriter writer, int[] values) {
13181396
writer.startList();
13191397
for (int v : values) {

0 commit comments

Comments
 (0)