Skip to content

Commit 5215558

Browse files
authored
[FLINK-38528][table] Introduce async vector search operator (#27126)
1 parent 0caa819 commit 5215558

File tree

6 files changed

+606
-12
lines changed

6 files changed

+606
-12
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecVectorSearchTableFunction.java

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,23 @@
2323
import org.apache.flink.configuration.PipelineOptions;
2424
import org.apache.flink.configuration.ReadableConfig;
2525
import org.apache.flink.streaming.api.functions.ProcessFunction;
26+
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
2627
import org.apache.flink.streaming.api.operators.ProcessOperator;
2728
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
2829
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
30+
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
2931
import org.apache.flink.table.api.TableException;
3032
import org.apache.flink.table.catalog.DataTypeFactory;
3133
import org.apache.flink.table.connector.source.VectorSearchTableSource;
3234
import org.apache.flink.table.connector.source.search.AsyncVectorSearchFunctionProvider;
3335
import org.apache.flink.table.connector.source.search.VectorSearchFunctionProvider;
3436
import org.apache.flink.table.data.RowData;
37+
import org.apache.flink.table.functions.AsyncVectorSearchFunction;
3538
import org.apache.flink.table.functions.UserDefinedFunction;
3639
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
3740
import org.apache.flink.table.functions.VectorSearchFunction;
3841
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
42+
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
3943
import org.apache.flink.table.planner.codegen.VectorSearchCodeGenerator;
4044
import org.apache.flink.table.planner.delegation.PlannerBase;
4145
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -55,9 +59,11 @@
5559
import org.apache.flink.table.runtime.collector.ListenableCollector;
5660
import org.apache.flink.table.runtime.generated.GeneratedCollector;
5761
import org.apache.flink.table.runtime.generated.GeneratedFunction;
62+
import org.apache.flink.table.runtime.operators.search.AsyncVectorSearchRunner;
5863
import org.apache.flink.table.runtime.operators.search.VectorSearchRunner;
5964
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
6065
import org.apache.flink.table.types.logical.RowType;
66+
import org.apache.flink.util.Preconditions;
6167

6268
import org.apache.calcite.plan.RelOptTable;
6369
import org.apache.calcite.rel.core.JoinRelType;
@@ -116,17 +122,27 @@ protected Transformation<RowData> translateToPlanInternal(
116122
// 3. build the operator
117123
RowType inputType = (RowType) inputEdge.getOutputType();
118124
RowType outputType = (RowType) getOutputType();
125+
DataTypeFactory dataTypeFactory =
126+
ShortcutUtils.unwrapContext(planner.getFlinkContext())
127+
.getCatalogManager()
128+
.getDataTypeFactory();
119129
StreamOperatorFactory<RowData> operatorFactory =
120130
isAsyncEnabled
121-
? createAsyncVectorSearchOperator()
131+
? createAsyncVectorSearchOperator(
132+
searchTable,
133+
config,
134+
planner.getFlinkContext().getClassLoader(),
135+
(AsyncVectorSearchFunction) vectorSearchFunction,
136+
dataTypeFactory,
137+
inputType,
138+
vectorSearchSpec.getOutputType(),
139+
outputType)
122140
: createSyncVectorSearchOperator(
123141
searchTable,
124142
config,
125143
planner.getFlinkContext().getClassLoader(),
126144
(VectorSearchFunction) vectorSearchFunction,
127-
ShortcutUtils.unwrapContext(planner.getFlinkContext())
128-
.getCatalogManager()
129-
.getDataTypeFactory(),
145+
dataTypeFactory,
130146
inputType,
131147
vectorSearchSpec.getOutputType(),
132148
outputType);
@@ -225,7 +241,49 @@ private ProcessFunction<RowData, RowData> createSyncVectorSearchFunction(
225241
searchOutputType.getFieldCount());
226242
}
227243

228-
private SimpleOperatorFactory<RowData> createAsyncVectorSearchOperator() {
229-
throw new UnsupportedOperationException("Async vector search is not supported yet.");
244+
@SuppressWarnings("unchecked")
245+
private StreamOperatorFactory<RowData> createAsyncVectorSearchOperator(
246+
RelOptTable searchTable,
247+
ExecNodeConfig config,
248+
ClassLoader jobClassLoader,
249+
AsyncVectorSearchFunction vectorSearchFunction,
250+
DataTypeFactory dataTypeFactory,
251+
RowType inputType,
252+
RowType searchOutputType,
253+
RowType outputType) {
254+
ArrayList<FunctionCallUtil.FunctionParam> parameters =
255+
new ArrayList<>(1 + vectorSearchSpec.getSearchColumns().size());
256+
parameters.add(vectorSearchSpec.getTopK());
257+
parameters.addAll(vectorSearchSpec.getSearchColumns().values());
258+
259+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
260+
generatedFetcher =
261+
VectorSearchCodeGenerator.generateAsyncVectorSearchFunction(
262+
config,
263+
jobClassLoader,
264+
dataTypeFactory,
265+
inputType,
266+
searchOutputType,
267+
outputType,
268+
parameters,
269+
vectorSearchFunction,
270+
((TableSourceTable) searchTable)
271+
.contextResolvedTable()
272+
.getIdentifier()
273+
.asSummaryString());
274+
275+
boolean isLeftOuterJoin = vectorSearchSpec.getJoinType() == JoinRelType.LEFT;
276+
277+
Preconditions.checkNotNull(asyncOptions, "Async Options can not be null.");
278+
279+
return new AsyncWaitOperatorFactory<>(
280+
new AsyncVectorSearchRunner(
281+
(GeneratedFunction) generatedFetcher.tableFunc(),
282+
isLeftOuterJoin,
283+
asyncOptions.asyncBufferCapacity,
284+
searchOutputType.getFieldCount()),
285+
asyncOptions.asyncTimeout,
286+
asyncOptions.asyncBufferCapacity,
287+
asyncOptions.asyncOutputMode);
230288
}
231289
}

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/VectorSearchCodeGenerator.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ package org.apache.flink.table.planner.codegen
1919

2020
import org.apache.flink.api.common.functions.FlatMapFunction
2121
import org.apache.flink.configuration.ReadableConfig
22+
import org.apache.flink.streaming.api.functions.async.AsyncFunction
2223
import org.apache.flink.table.catalog.DataTypeFactory
2324
import org.apache.flink.table.data.RowData
2425
import org.apache.flink.table.functions._
26+
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType
2527
import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil
2628
import org.apache.flink.table.planner.functions.inference.FunctionCallContext
2729
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam
@@ -68,6 +70,32 @@ object VectorSearchCodeGenerator {
6870
.tableFunc
6971
}
7072

73+
/** Generates an async vector search function ([[AsyncTableFunction]]). */
74+
def generateAsyncVectorSearchFunction(
75+
tableConfig: ReadableConfig,
76+
classLoader: ClassLoader,
77+
dataTypeFactory: DataTypeFactory,
78+
inputType: LogicalType,
79+
searchOutputType: LogicalType,
80+
outputType: LogicalType,
81+
searchColumns: util.List[FunctionParam],
82+
asyncVectorSearchFunction: AsyncTableFunction[_],
83+
functionName: String): GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
84+
FunctionCallCodeGenerator.generateAsyncFunctionCall(
85+
tableConfig,
86+
classLoader,
87+
dataTypeFactory,
88+
inputType,
89+
searchOutputType,
90+
outputType,
91+
searchColumns,
92+
asyncVectorSearchFunction,
93+
generateCallWithDataType(functionName, searchOutputType),
94+
functionName,
95+
"AsyncVectorSearchFunction"
96+
)
97+
}
98+
7199
private def generateCallWithDataType(
72100
functionName: String,
73101
searchOutputType: LogicalType

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesRuntimeFunctions.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.apache.flink.table.data.conversion.RowRowConverter;
5555
import org.apache.flink.table.data.utils.JoinedRowData;
5656
import org.apache.flink.table.functions.AsyncLookupFunction;
57+
import org.apache.flink.table.functions.AsyncVectorSearchFunction;
5758
import org.apache.flink.table.functions.FunctionContext;
5859
import org.apache.flink.table.functions.LookupFunction;
5960
import org.apache.flink.table.functions.VectorSearchFunction;
@@ -74,6 +75,8 @@
7475
import org.apache.flink.util.clock.RelativeClock;
7576
import org.apache.flink.util.clock.SystemClock;
7677

78+
import javax.annotation.Nullable;
79+
7780
import java.io.ByteArrayInputStream;
7881
import java.io.ByteArrayOutputStream;
7982
import java.io.IOException;
@@ -1171,4 +1174,51 @@ private double cosineDistance(double[] left, double[] right) {
11711174
return sum;
11721175
}
11731176
}
1177+
1178+
public static class TestValueAsyncVectorSearchFunction extends AsyncVectorSearchFunction {
1179+
1180+
private final TestValueVectorSearchFunction impl;
1181+
private final @Nullable Integer latency;
1182+
private transient ExecutorService executors;
1183+
private transient Random random;
1184+
1185+
public TestValueAsyncVectorSearchFunction(
1186+
List<Row> data,
1187+
int[] searchIndices,
1188+
DataType physicalRowType,
1189+
@Nullable Integer latency) {
1190+
this.impl = new TestValueVectorSearchFunction(data, searchIndices, physicalRowType);
1191+
this.latency = latency;
1192+
}
1193+
1194+
@Override
1195+
public void open(FunctionContext context) throws Exception {
1196+
super.open(context);
1197+
impl.open(context);
1198+
executors = Executors.newCachedThreadPool();
1199+
random = new Random();
1200+
}
1201+
1202+
@Override
1203+
public CompletableFuture<Collection<RowData>> asyncVectorSearch(
1204+
int topK, RowData queryData) {
1205+
return CompletableFuture.supplyAsync(
1206+
() -> {
1207+
try {
1208+
Thread.sleep(latency == null ? random.nextInt(1000) : latency);
1209+
return impl.vectorSearch(topK, queryData);
1210+
} catch (Exception e) {
1211+
throw new RuntimeException(e);
1212+
}
1213+
},
1214+
executors);
1215+
}
1216+
1217+
@Override
1218+
public void close() throws Exception {
1219+
super.close();
1220+
impl.close();
1221+
executors.shutdown();
1222+
}
1223+
}
11741224
}

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import org.apache.flink.table.connector.source.lookup.cache.LookupCache;
8484
import org.apache.flink.table.connector.source.lookup.cache.trigger.CacheReloadTrigger;
8585
import org.apache.flink.table.connector.source.lookup.cache.trigger.PeriodicCacheReloadTrigger;
86+
import org.apache.flink.table.connector.source.search.AsyncVectorSearchFunctionProvider;
8687
import org.apache.flink.table.connector.source.search.VectorSearchFunctionProvider;
8788
import org.apache.flink.table.data.GenericRowData;
8889
import org.apache.flink.table.data.RowData;
@@ -95,9 +96,11 @@
9596
import org.apache.flink.table.factories.FactoryUtil;
9697
import org.apache.flink.table.functions.AsyncLookupFunction;
9798
import org.apache.flink.table.functions.AsyncTableFunction;
99+
import org.apache.flink.table.functions.AsyncVectorSearchFunction;
98100
import org.apache.flink.table.functions.FunctionDefinition;
99101
import org.apache.flink.table.functions.LookupFunction;
100102
import org.apache.flink.table.functions.TableFunction;
103+
import org.apache.flink.table.functions.VectorSearchFunction;
101104
import org.apache.flink.table.legacy.api.TableSchema;
102105
import org.apache.flink.table.legacy.api.WatermarkSpec;
103106
import org.apache.flink.table.legacy.connector.source.AsyncTableFunctionProvider;
@@ -501,6 +504,14 @@ private static RowKind parseRowKind(String rowKindShortString) {
501504
"Option to specify the amount of time to sleep after processing every N elements. "
502505
+ "The default value is 0, which means that no sleep is performed");
503506

507+
public static final ConfigOption<Integer> LATENCY =
508+
ConfigOptions.key("latency")
509+
.intType()
510+
.noDefaultValue()
511+
.withDescription(
512+
"Latency in milliseconds for async vector search call for each row. "
513+
+ "If not set, the default is random between 0ms and 1000ms.");
514+
504515
/**
505516
* Parse partition list from Options with the format as
506517
* "key1:val1,key2:val2;key1:val3,key2:val4".
@@ -654,7 +665,9 @@ public DynamicTableSource createDynamicTableSource(Context context) {
654665
readableMetadata,
655666
null,
656667
parallelism,
657-
enableAggregatePushDown);
668+
enableAggregatePushDown,
669+
isAsync,
670+
helper.getOptions().get(LATENCY));
658671
}
659672

660673
if (disableLookup) {
@@ -888,7 +901,8 @@ public Set<ConfigOption<?>> optionalOptions() {
888901
FULL_CACHE_PERIODIC_RELOAD_INTERVAL,
889902
FULL_CACHE_PERIODIC_RELOAD_SCHEDULE_MODE,
890903
FULL_CACHE_TIMED_RELOAD_ISO_TIME,
891-
FULL_CACHE_TIMED_RELOAD_INTERVAL_IN_DAYS));
904+
FULL_CACHE_TIMED_RELOAD_INTERVAL_IN_DAYS,
905+
LATENCY));
892906
}
893907

894908
private static int validateAndExtractRowtimeIndex(
@@ -1054,7 +1068,7 @@ private static class TestValuesScanTableSourceWithoutProjectionPushDown
10541068
private @Nullable int[] groupingSet;
10551069
private List<AggregateExpression> aggregateExpressions;
10561070
private List<String> acceptedPartitionFilterFields;
1057-
private final Integer parallelism;
1071+
protected final Integer parallelism;
10581072

10591073
private TestValuesScanTableSourceWithoutProjectionPushDown(
10601074
DataType producedDataType,
@@ -2247,6 +2261,9 @@ private static class TestValuesVectorSearchTableSourceWithoutProjectionPushDown
22472261
extends TestValuesScanTableSourceWithoutProjectionPushDown
22482262
implements VectorSearchTableSource {
22492263

2264+
private final boolean isAsync;
2265+
@Nullable private final Integer latency;
2266+
22502267
private TestValuesVectorSearchTableSourceWithoutProjectionPushDown(
22512268
DataType producedDataType,
22522269
ChangelogMode changelogMode,
@@ -2266,7 +2283,9 @@ private TestValuesVectorSearchTableSourceWithoutProjectionPushDown(
22662283
Map<String, DataType> readableMetadata,
22672284
@Nullable int[] projectedMetadataFields,
22682285
@Nullable Integer parallelism,
2269-
boolean enableAggregatePushDown) {
2286+
boolean enableAggregatePushDown,
2287+
boolean isAsync,
2288+
@Nullable Integer latency) {
22702289
super(
22712290
producedDataType,
22722291
changelogMode,
@@ -2287,6 +2306,8 @@ private TestValuesVectorSearchTableSourceWithoutProjectionPushDown(
22872306
projectedMetadataFields,
22882307
parallelism,
22892308
enableAggregatePushDown);
2309+
this.isAsync = isAsync;
2310+
this.latency = latency;
22902311
}
22912312

22922313
@Override
@@ -2295,9 +2316,67 @@ public VectorSearchRuntimeProvider getSearchRuntimeProvider(VectorSearchContext
22952316
Arrays.stream(context.getSearchColumns()).mapToInt(k -> k[0]).toArray();
22962317
Collection<Row> rows =
22972318
data.getOrDefault(Collections.emptyMap(), Collections.emptyList());
2298-
return VectorSearchFunctionProvider.of(
2319+
TestValuesRuntimeFunctions.TestValueVectorSearchFunction searchFunction =
22992320
new TestValuesRuntimeFunctions.TestValueVectorSearchFunction(
2300-
new ArrayList<>(rows), searchColumns, producedDataType));
2321+
new ArrayList<>(rows), searchColumns, producedDataType);
2322+
2323+
if (isAsync) {
2324+
return new VectorFunctionProvider(
2325+
new TestValuesRuntimeFunctions.TestValueAsyncVectorSearchFunction(
2326+
new ArrayList<>(rows), searchColumns, producedDataType, latency),
2327+
searchFunction);
2328+
} else {
2329+
return VectorSearchFunctionProvider.of(searchFunction);
2330+
}
2331+
}
2332+
2333+
@Override
2334+
public DynamicTableSource copy() {
2335+
return new TestValuesVectorSearchTableSourceWithoutProjectionPushDown(
2336+
producedDataType,
2337+
changelogMode,
2338+
boundedness,
2339+
terminating,
2340+
runtimeSource,
2341+
failingSource,
2342+
data,
2343+
nestedProjectionSupported,
2344+
projectedPhysicalFields,
2345+
filterPredicates,
2346+
filterableFields,
2347+
dynamicFilteringFields,
2348+
numElementToSkip,
2349+
limit,
2350+
allPartitions,
2351+
readableMetadata,
2352+
projectedMetadataFields,
2353+
parallelism,
2354+
enableAggregatePushDown,
2355+
isAsync,
2356+
latency);
2357+
}
2358+
2359+
private static class VectorFunctionProvider
2360+
implements AsyncVectorSearchFunctionProvider, VectorSearchFunctionProvider {
2361+
2362+
private final AsyncVectorSearchFunction asyncFunction;
2363+
private final VectorSearchFunction syncFunction;
2364+
2365+
public VectorFunctionProvider(
2366+
AsyncVectorSearchFunction asyncFunction, VectorSearchFunction syncFunction) {
2367+
this.asyncFunction = asyncFunction;
2368+
this.syncFunction = syncFunction;
2369+
}
2370+
2371+
@Override
2372+
public AsyncVectorSearchFunction createAsyncVectorSearchFunction() {
2373+
return asyncFunction;
2374+
}
2375+
2376+
@Override
2377+
public VectorSearchFunction createVectorSearchFunction() {
2378+
return syncFunction;
2379+
}
23012380
}
23022381
}
23032382

0 commit comments

Comments
 (0)