Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.Collections;

/**
 * Rule to convert VECTOR_SEARCH call with literal value to a correlated VECTOR_SEARCH call.
 *
 * <p>The rule matches a {@link LogicalTableFunctionScan} whose call uses a {@link
 * SqlVectorSearchTableFunction} operator and whose operand at {@link #QUERY_OPERAND_INDEX} is
 * constant according to {@link RexUtil#isConstant(RexNode)}. The rewritten plan consists of:
 *
 * <ul>
 *   <li>left side: a one-row {@link LogicalValues} projected onto the constant expression;
 *   <li>right side: the original function scan over its first input, with the constant operand
 *       replaced by a field access on a correlation variable that refers to the left side;
 *   <li>an inner join carrying the correlation id, followed by a projection that drops the
 *       helper column contributed by the left side.
 * </ul>
 */
public class ConstantVectorSearchCallToCorrelateRule
        extends RelRule<
                ConstantVectorSearchCallToCorrelateRule
                        .ConstantVectorSearchCallToCorrelateRuleConfig> {

    public static final ConstantVectorSearchCallToCorrelateRule INSTANCE =
            ConstantVectorSearchCallToCorrelateRuleConfig.DEFAULT.toRule();

    /** Position of the VECTOR_SEARCH operand that this rule inspects and replaces. */
    private static final int QUERY_OPERAND_INDEX = 2;

    private ConstantVectorSearchCallToCorrelateRule(
            ConstantVectorSearchCallToCorrelateRuleConfig config) {
        super(config);
    }

    /**
     * Accepts only VECTOR_SEARCH function scans whose operand at {@link #QUERY_OPERAND_INDEX} is
     * constant.
     *
     * <p>Scans without an input, or calls with too few operands, are rejected instead of failing
     * with an {@link IndexOutOfBoundsException}, because {@link #onMatch(RelOptRuleCall)} reads
     * both the first input and that operand.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        LogicalTableFunctionScan scan = call.rel(0);
        RexNode rexNode = scan.getCall();
        if (!(rexNode instanceof RexCall)) {
            return false;
        }
        RexCall rexCall = (RexCall) rexNode;
        if (!(rexCall.getOperator() instanceof SqlVectorSearchTableFunction)) {
            return false;
        }
        if (scan.getInputs().isEmpty() || rexCall.getOperands().size() <= QUERY_OPERAND_INDEX) {
            return false;
        }
        return RexUtil.isConstant(rexCall.getOperands().get(QUERY_OPERAND_INDEX));
    }

    /**
     * Replaces the matched scan with the correlated form described in the class documentation.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        LogicalTableFunctionScan scan = call.rel(0);
        RexCall functionCall = (RexCall) scan.getCall();
        RexNode constantCall = functionCall.getOperands().get(QUERY_OPERAND_INDEX);
        RelOptCluster cluster = scan.getCluster();
        RelBuilder builder = call.builder();

        // left side: a single row whose only column is the constant expression
        LogicalValues values = LogicalValues.createOneRow(cluster);
        builder.push(values);
        builder.project(constantCall);

        // right side: the same function call, reading the constant through a correlation
        // variable typed with the row type of the projection that is on top of the builder stack
        CorrelationId correlId = cluster.createCorrel();
        RexNode correlRex =
                cluster.getRexBuilder().makeCorrel(builder.peek().getRowType(), correlId);
        RexNode correlatedConstant = cluster.getRexBuilder().makeFieldAccess(correlRex, 0);
        builder.push(scan.getInput(0));
        // copy the operands so the original call is left untouched
        ArrayList<RexNode> operands = new ArrayList<>(functionCall.getOperands());
        operands.set(QUERY_OPERAND_INDEX, correlatedConstant);
        builder.functionScan(functionCall.getOperator(), 1, operands);

        // add correlate node
        builder.join(
                JoinRelType.INNER,
                cluster.getRexBuilder().makeLiteral(true),
                Collections.singleton(correlId));

        // prune useless value input
        builder.projectExcept(builder.field(0));
        call.transformTo(builder.build());
    }

    /** Configuration of {@link ConstantVectorSearchCallToCorrelateRule}. */
    @Value.Immutable
    public interface ConstantVectorSearchCallToCorrelateRuleConfig extends RelRule.Config {

        /** Default configuration: matches any {@link LogicalTableFunctionScan}. */
        ConstantVectorSearchCallToCorrelateRuleConfig DEFAULT =
                ImmutableConstantVectorSearchCallToCorrelateRuleConfig.builder()
                        .build()
                        .withOperandSupplier(
                                b0 -> b0.operand(LogicalTableFunctionScan.class).anyInputs())
                        .withDescription("ConstantVectorSearchCallToCorrelateRule");

        @Override
        default ConstantVectorSearchCallToCorrelateRule toRule() {
            return new ConstantVectorSearchCallToCorrelateRule(this);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ object FlinkStreamRuleSets {
// unnest rule
LogicalUnnestRule.INSTANCE,
UncollectToTableFunctionScanRule.INSTANCE,
// convert VECTOR_SEARCH with a constant query value to a correlated call
ConstantVectorSearchCallToCorrelateRule.INSTANCE,
// rewrite constant table function scan to correlate
JoinTableFunctionScanToCorrelateRule.INSTANCE,
// Wrap arguments for JSON aggregate functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,14 @@ void testSimple() {
void testLiteralValue() {
    // VECTOR_SEARCH with a literal query value is expected to plan successfully, so the
    // resulting plan is verified instead of asserting a planner failure.
    String sql =
            "SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
    util.verifyRelPlan(sql);
}

@Test
void testLiteralValueWithoutLateralKeyword() {
    // Same query as the LATERAL variant but without the LATERAL keyword; it is expected to
    // plan successfully, so the resulting plan is verified instead of asserting a failure.
    String sql =
            "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))";
    util.verifyRelPlan(sql);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.List;
import java.util.concurrent.TimeoutException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;

Expand Down Expand Up @@ -153,6 +154,19 @@ void testTimeout() {
TimeoutException.class, "Async function call has timed out."));
}

@TestTemplate
void testConstantValue() {
    // Runs VECTOR_SEARCH with a constant array as the query value and a top-k of 2.
    final String query =
            "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))";
    final List<Row> rows = CollectionUtil.iteratorToList(tEnv().executeSql(query).collect());

    // Each expected row is (id, vector, score); result order is not significant.
    final Row expectedFirst = Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0);
    final Row expectedSecond = Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862);
    assertThat(rows).containsExactlyInAnyOrder(expectedFirst, expectedSecond);
}

@TestTemplate
void testVectorSearchWithCalc() {
assertThatThrownBy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Arrays;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;

Expand Down Expand Up @@ -123,6 +124,19 @@ void testLeftLateralJoin() {
Row.of(4L, null, null, null, null));
}

@Test
void testConstantValue() {
    // Runs VECTOR_SEARCH with a constant array as the query value and a top-k of 2.
    final String query =
            "SELECT * FROM TABLE(VECTOR_SEARCH(TABLE vector, DESCRIPTOR(`vector`), ARRAY[5, 12, 13], 2))";
    final List<Row> rows = CollectionUtil.iteratorToList(tEnv().executeSql(query).collect());

    // Each expected row is (id, vector, score); result order is not significant.
    final Row expectedFirst = Row.of(1L, new Float[] {5.0f, 12.0f, 13.0f}, 1.0);
    final Row expectedSecond = Row.of(3L, new Float[] {8f, 15f, 17f}, 0.9977375565610862);
    assertThat(rows).containsExactlyInAnyOrder(expectedFirst, expectedSecond);
}

@Test
void testVectorSearchWithCalc() {
assertThatThrownBy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,48 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
<Root>
<TestCase name="testLiteralValue">
<Resource name="sql">
<![CDATA[SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
+- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
+- LogicalProject(e=[$0], f=[$1], g=[$2])
+- LogicalTableScan(table=[[default_catalog, default_database, VectorTable]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
Calc(select=[e, f, g, score])
+- VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10], select=[$f0, e, f, g, score])
+- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
+- Values(tuples=[[{ 0 }]])
]]>
</Resource>
</TestCase>
<TestCase name="testLiteralValueWithoutLateralKeyword">
<Resource name="sql">
<![CDATA[SELECT * FROM TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(e=[$0], f=[$1], g=[$2], score=[$3])
+- LogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])
+- LogicalProject(e=[$0], f=[$1], g=[$2])
+- LogicalTableScan(table=[[default_catalog, default_database, VectorTable]])
]]>
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
Calc(select=[e, f, g, score])
+- VectorSearchTableFunction(table=[default_catalog.default_database.VectorTable], joinType=[InnerJoin], columnToSearch=[g], columnToQuery=[$f0], topK=[10], select=[$f0, e, f, g, score])
+- Calc(select=[ARRAY(1.5, 2.0) AS $f0])
+- Values(tuples=[[{ 0 }]])
]]>
</Resource>
</TestCase>
<TestCase name="testNameConflicts">
<Resource name="sql">
<![CDATA[SELECT * FROM QueryTable, LATERAL TABLE(
Expand Down