Skip to content

Commit d3680e4

Browse files
committed
[FLINK-38428][table] Support to run vector search with constant value input
1 parent 5215558 commit d3680e4

File tree

6 files changed

+188
-12
lines changed

6 files changed

+188
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.rules.logical;
20+
21+
import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;
22+
23+
import org.apache.calcite.plan.RelOptCluster;
24+
import org.apache.calcite.plan.RelOptRuleCall;
25+
import org.apache.calcite.plan.RelRule;
26+
import org.apache.calcite.rel.core.CorrelationId;
27+
import org.apache.calcite.rel.core.JoinRelType;
28+
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
29+
import org.apache.calcite.rel.logical.LogicalValues;
30+
import org.apache.calcite.rex.RexCall;
31+
import org.apache.calcite.rex.RexNode;
32+
import org.apache.calcite.rex.RexUtil;
33+
import org.apache.calcite.tools.RelBuilder;
34+
import org.immutables.value.Value;
35+
36+
import java.util.ArrayList;
37+
import java.util.Collections;
38+
39+
/** Rule to convert VECTOR_SEARCH call with literal value to a correlated VECTOR_SEARCH call. */
40+
public class ConstantVectorSearchCallToCorrelateRule
41+
extends RelRule<
42+
ConstantVectorSearchCallToCorrelateRule
43+
.ConstantVectorSearchCallToCorrelateRuleConfig> {
44+
45+
public static final ConstantVectorSearchCallToCorrelateRule INSTANCE =
46+
ConstantVectorSearchCallToCorrelateRuleConfig.DEFAULT.toRule();
47+
48+
private ConstantVectorSearchCallToCorrelateRule(
49+
ConstantVectorSearchCallToCorrelateRuleConfig config) {
50+
super(config);
51+
}
52+
53+
@Override
54+
public boolean matches(RelOptRuleCall call) {
55+
LogicalTableFunctionScan scan = call.rel(0);
56+
RexNode rexNode = scan.getCall();
57+
if (!(rexNode instanceof RexCall)) {
58+
return false;
59+
}
60+
RexCall rexCall = (RexCall) rexNode;
61+
return rexCall.getOperator() instanceof SqlVectorSearchTableFunction
62+
&& RexUtil.isConstant(rexCall.getOperands().get(2));
63+
}
64+
65+
@Override
66+
public void onMatch(RelOptRuleCall call) {
67+
LogicalTableFunctionScan scan = call.rel(0);
68+
RexCall functionCall = (RexCall) scan.getCall();
69+
RexNode constantCall = functionCall.getOperands().get(2);
70+
RelOptCluster cluster = scan.getCluster();
71+
RelBuilder builder = call.builder();
72+
73+
// left side
74+
LogicalValues values = LogicalValues.createOneRow(cluster);
75+
builder.push(values);
76+
builder.project(constantCall);
77+
78+
// right side
79+
CorrelationId correlId = cluster.createCorrel();
80+
RexNode correlRex =
81+
cluster.getRexBuilder().makeCorrel(builder.peek().getRowType(), correlId);
82+
RexNode correlatedConstant = cluster.getRexBuilder().makeFieldAccess(correlRex, 0);
83+
builder.push(scan.getInput(0));
84+
ArrayList<RexNode> operands = new ArrayList<>(functionCall.operands);
85+
operands.set(2, correlatedConstant);
86+
builder.functionScan(functionCall.getOperator(), 1, operands);
87+
88+
// add correlate node
89+
builder.join(
90+
JoinRelType.INNER,
91+
cluster.getRexBuilder().makeLiteral(true),
92+
Collections.singleton(correlId));
93+
94+
// prune useless value input
95+
builder.projectExcept(builder.field(0));
96+
call.transformTo(builder.build());
97+
}
98+
99+
@Value.Immutable
100+
public interface ConstantVectorSearchCallToCorrelateRuleConfig extends RelRule.Config {
101+
102+
ConstantVectorSearchCallToCorrelateRuleConfig DEFAULT =
103+
ImmutableConstantVectorSearchCallToCorrelateRuleConfig.builder()
104+
.build()
105+
.withOperandSupplier(
106+
b0 -> b0.operand(LogicalTableFunctionScan.class).anyInputs())
107+
.withDescription("ConstantVectorSearchCallToCorrelateRule");
108+
109+
@Override
110+
default ConstantVectorSearchCallToCorrelateRule toRule() {
111+
return new ConstantVectorSearchCallToCorrelateRule(this);
112+
}
113+
}
114+
}

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ object FlinkStreamRuleSets {
128128
// unnest rule
129129
LogicalUnnestRule.INSTANCE,
130130
UncollectToTableFunctionScanRule.INSTANCE,
131+
// rewrite VECTOR_SEARCH call with constant value input to correlate
132+
ConstantVectorSearchCallToCorrelateRule.INSTANCE,
131133
// rewrite constant table function scan to correlate
132134
JoinTableFunctionScanToCorrelateRule.INSTANCE,
133135
// Wrap arguments for JSON aggregate functions

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,14 @@ void testSimple() {
115115
void testLiteralValue() {
116116
String sql =
117117
"SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
118-
assertThatThrownBy(() -> util.verifyRelPlan(sql))
119-
.satisfies(
120-
FlinkAssertions.anyCauseMatches(
121-
TableException.class,
122-
"FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
123-
+ "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])"));
118+
util.verifyRelPlan(sql);
124119
}
125120

126121
@Test
127122
void testLiteralValueWithoutLateralKeyword() {
128123
String sql =
129124
"SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
130-
assertThatThrownBy(() -> util.verifyRelPlan(sql))
131-
.satisfies(
132-
FlinkAssertions.anyCauseMatches(
133-
TableException.class,
134-
"FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n"
135-
+ "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])"));
125+
util.verifyRelPlan(sql);
136126
}
137127

138128
@Test

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncVectorSearchITCase.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.util.List;
3838
import java.util.concurrent.TimeoutException;
3939

40+
import static org.assertj.core.api.Assertions.assertThat;
4041
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4142
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
4243

@@ -153,6 +154,19 @@ void testTimeout() {
153154
TimeoutException.class, "Async function call has timed out."));
154155
}
155156

157+
@TestTemplate
158+
void testConstantValue() {
159+
List<Row> actual =
160+
CollectionUtil.iteratorToList(
161+
tEnv().executeSql(
162+
"SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))")
163+
.collect());
164+
assertThat(actual)
165+
.containsExactlyInAnyOrder(
166+
Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0),
167+
Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862));
168+
}
169+
156170
@TestTemplate
157171
void testVectorSearchWithCalc() {
158172
assertThatThrownBy(

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/VectorSearchITCase.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.Arrays;
3131
import java.util.List;
3232

33+
import static org.assertj.core.api.Assertions.assertThat;
3334
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3435
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
3536

@@ -123,6 +124,19 @@ void testLeftLateralJoin() {
123124
Row.of(4L, null, null, null, null));
124125
}
125126

127+
@Test
128+
void testConstantValue() {
129+
List<Row> actual =
130+
CollectionUtil.iteratorToList(
131+
tEnv().executeSql(
132+
"SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))")
133+
.collect());
134+
assertThat(actual)
135+
.containsExactlyInAnyOrder(
136+
Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0),
137+
Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862));
138+
}
139+
126140
@Test
127141
void testVectorSearchWithCalc() {
128142
assertThatThrownBy(

flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,48 @@ See the License for the specific language governing permissions and
1616
limitations under the License.
1717
-->
1818
<Root>
19+
<TestCase name="testLiteralValue">
20+
<Resource name="sql">
21+
<![CDATA[SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
22+
</Resource>
23+
<Resource name="ast">
24+
<![CDATA[
25+
LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
26+
+- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
27+
+- LogicalProject(e=[$0], f=[$1], g=[$2])
28+
+- LogicalTableScan(table=[[default_catalog, default_database, VectorTable]])
29+
]]>
30+
</Resource>
31+
<Resource name="optimized rel plan">
32+
<![CDATA[
33+
Calc(select=[e, f, g, score])
34+
+- VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10], select=[$f0, e, f, g, score])
35+
+- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
36+
+- Values(tuples=[[{ 0 }]])
37+
]]>
38+
</Resource>
39+
</TestCase>
40+
<TestCase name="testLiteralValueWithoutLateralKeyword">
41+
<Resource name="sql">
42+
<![CDATA[SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
43+
</Resource>
44+
<Resource name="ast">
45+
<![CDATA[
46+
LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
47+
+- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
48+
+- LogicalProject(e=[$0], f=[$1], g=[$2])
49+
+- LogicalTableScan(table=[[default_catalog, default_database, VectorTable]])
50+
]]>
51+
</Resource>
52+
<Resource name="optimized rel plan">
53+
<![CDATA[
54+
Calc(select=[e, f, g, score])
55+
+- VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10], select=[$f0, e, f, g, score])
56+
+- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
57+
+- Values(tuples=[[{ 0 }]])
58+
]]>
59+
</Resource>
60+
</TestCase>
1961
<TestCase name="testNameConflicts">
2062
<Resource name="sql">
2163
<![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(

0 commit comments

Comments
 (0)