Skip to content

feat(sql-udf): support recursive sql udf #844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ prost = "0.12"
pyo3 = { version = "0.20", features = ["extension-module"], optional = true }
ref-cast = "1.0"
regex = "1"
fancy-regex = "0.13"
risinglight_proto = "0.2"
rust_decimal = "1"
rustyline = "13"
Expand Down
34 changes: 34 additions & 0 deletions src/binder/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::fmt;
use std::str::FromStr;

use fancy_regex::Regex;
use pretty_xmlish::helper::delegate_fmt;
use pretty_xmlish::Pretty;
use serde::{Deserialize, Serialize};
Expand All @@ -18,6 +19,7 @@ pub struct CreateFunction {
pub return_type: crate::types::DataType,
pub language: String,
pub body: String,
pub is_recursive: bool,
}

impl fmt::Display for CreateFunction {
Expand Down Expand Up @@ -45,6 +47,35 @@ impl CreateFunction {
}
}

/// Find the pattern for recursive sql udf
/// return the exact index where the pattern first appears
/// Source: <https://github.com/risingwavelabs/risingwave/blob/a16a230a297aa620fdf6d04d7cd3f9e236f73fdd/src/frontend/src/handler/create_sql_function.rs#L89>
fn find_target(input: &str, target: &str) -> Option<usize> {
// Regex pattern to find `target` not preceded or followed by an ASCII letter
// The pattern uses negative lookbehind (?<!...) and lookahead (?!...) to ensure
// the target is not surrounded by ASCII alphabetic characters
let pattern = format!(r"(?<![A-Za-z]){0}(?![A-Za-z])", fancy_regex::escape(target));
let Ok(re) = Regex::new(&pattern) else {
return None;
};

let Ok(Some(ma)) = re.find(input) else {
return None;
};

Some(ma.start())
}

/// TODO: the current implementation is a "bit" hacky
/// I will try bring a more general & robust solution in the future
fn is_recursive(body: &str, func_name: &str) -> bool {
if let Some(_) = find_target(body, func_name) {
true
} else {
false
}
}

impl Binder {
pub(super) fn bind_create_function(
&mut self,
Expand Down Expand Up @@ -102,6 +133,8 @@ impl Binder {
arg_names.push(arg.name.map_or("".to_string(), |n| n.to_string()));
}

let is_recursive = is_recursive(&body, &name);

let f = self.egraph.add(Node::CreateFunction(CreateFunction {
schema_name,
name,
Expand All @@ -110,6 +143,7 @@ impl Binder {
return_type,
language,
body,
is_recursive,
}));

Ok(f)
Expand Down
16 changes: 16 additions & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ impl Binder {
fn bind_function(&mut self, func: Function) -> Result {
let mut args = vec![];
for arg in func.args.clone() {
println!("arg: {:#?}", arg);
// ignore argument name
let arg = match arg {
FunctionArg::Named { arg, .. } => arg,
Expand Down Expand Up @@ -332,6 +333,21 @@ impl Binder {
// See if the input function is sql udf
if let Some(ref function_catalog) = catalog.get_function_by_name(schema_name, function_name)
{
// For recursive sql udf, we will postpone its execution
// until reaching backend.
// a.k.a. this will not be *inlined* during binding phase
if function_catalog.is_recursive {
return Ok(self.egraph.add(Node::Udf(Udf {
// TODO: presumably there could be multiple arguments
// but for simplicity reason, currently only
// a single argument is supported
id: args[0],
name: function_catalog.name.clone(),
body: function_catalog.body.clone(),
return_type: function_catalog.return_type.clone(),
})));
}

// Create the brand new `udf_context`
let Ok(context) =
UdfContext::create_udf_context(func.args.as_slice(), function_catalog)
Expand Down
2 changes: 2 additions & 0 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ mod expr;
mod insert;
mod select;
mod table;
mod udf;

pub use self::create_function::*;
pub use self::create_table::*;
pub use self::udf::*;

pub type Result<T = Id> = std::result::Result<T, BindError>;

Expand Down
42 changes: 42 additions & 0 deletions src/binder/udf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0.

use egg::Id;

use crate::types::DataType;
use pretty_xmlish::helper::delegate_fmt;
use pretty_xmlish::Pretty;
use std::str::FromStr;
use std::fmt;

/// currently represents recursive sql udf
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub struct Udf {
pub id: Id,
pub name: String,
pub body: String,
pub return_type: DataType,
}

impl fmt::Display for Udf {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let explainer = Pretty::childless_record("Udf", self.pretty_function());
delegate_fmt(&explainer, f, String::with_capacity(1000))
}
}

impl FromStr for Udf {
type Err = ();

fn from_str(_s: &str) -> std::result::Result<Self, Self::Err> {
Err(())
}
}

impl Udf {
pub fn pretty_function<'a>(&self) -> Vec<(&'a str, Pretty<'a>)> {
vec![
("name", Pretty::display(&self.name)),
("body", Pretty::display(&self.body)),
]
}
}
6 changes: 6 additions & 0 deletions src/catalog/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub struct FunctionCatalog {
pub return_type: DataType,
pub language: String,
pub body: String,
pub is_recursive: bool,
}

impl FunctionCatalog {
Expand All @@ -20,6 +21,7 @@ impl FunctionCatalog {
return_type: DataType,
language: String,
body: String,
is_recursive: bool,
) -> Self {
Self {
name,
Expand All @@ -28,17 +30,21 @@ impl FunctionCatalog {
return_type,
language,
body,
is_recursive,
}
}

#[inline]
pub fn body(&self) -> String {
self.body.clone()
}

#[inline]
pub fn name(&self) -> String {
self.name.clone()
}

#[inline]
pub fn language(&self) -> String {
self.language.clone()
}
Expand Down
11 changes: 10 additions & 1 deletion src/catalog/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,20 @@ impl RootCatalog {
return_type: DataType,
language: String,
body: String,
is_recursive: bool,
) {
let schema_idx = self.get_schema_id_by_name(&schema_name).unwrap();
let mut inner = self.inner.lock().unwrap();
let schema = inner.schemas.get_mut(&schema_idx).unwrap();
schema.create_function(name, arg_types, arg_names, return_type, language, body);
schema.create_function(
name,
arg_types,
arg_names,
return_type,
language,
body,
is_recursive,
);
}

pub const DEFAULT_SCHEMA_NAME: &'static str = "postgres";
Expand Down
2 changes: 2 additions & 0 deletions src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ impl SchemaCatalog {
return_type: DataType,
language: String,
body: String,
is_recursive: bool,
) {
self.functions.insert(
name.clone(),
Expand All @@ -129,6 +130,7 @@ impl SchemaCatalog {
return_type,
language,
body,
is_recursive,
}),
);
}
Expand Down
2 changes: 2 additions & 0 deletions src/executor/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl CreateFunctionExecutor {
return_type,
language,
body,
is_recursive,
} = self.f;

self.catalog.create_function(
Expand All @@ -31,6 +32,7 @@ impl CreateFunctionExecutor {
return_type,
language,
body,
is_recursive,
);
}
}
6 changes: 6 additions & 0 deletions src/executor/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::fmt;
use egg::{Id, Language};

use crate::array::*;
use crate::executor::udf::UdfExecutor;
use crate::planner::{Expr, RecExpr};
use crate::types::{ConvertError, DataValue};

Expand Down Expand Up @@ -129,6 +130,11 @@ impl<'a> Evaluator<'a> {
};
a.replace(from, to)
}
// recursive sql udf's actual backend logic
Udf(udf) => UdfExecutor {
udf: udf.clone(),
}
.execute(chunk),
e => {
if let Some((op, a, b)) = e.binary_op() {
let left = self.next(a).eval(chunk)?;
Expand Down
3 changes: 2 additions & 1 deletion src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ mod table_scan;
mod top_n;
mod values;
mod window;
mod udf;

/// The maximum chunk length produced by executor at a time.
const PROCESSING_WINDOW_SIZE: usize = 1024;
Expand Down Expand Up @@ -352,7 +353,7 @@ impl<S: Storage> Builder<S> {

CreateFunction(f) => CreateFunctionExecutor {
f,
catalog: self.optimizer.catalog().clone(),
catalog: self.catalog().clone(),
}
.execute(),

Expand Down
2 changes: 2 additions & 0 deletions src/executor/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ impl ProjectionExecutor {
pub async fn execute(self, child: BoxedExecutor) {
#[for_await]
for batch in child {
println!("[project]\n{}", batch.clone().unwrap());
println!("projs: {:#?}", self.projs);
yield Evaluator::new(&self.projs).eval_list(&batch?)?;
}
}
Expand Down
16 changes: 16 additions & 0 deletions src/executor/udf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0.

use super::*;
use crate::{array::ArrayImpl, binder::Udf, types::ConvertError};

/// The executor of (recursive) sql udf
pub struct UdfExecutor {
pub udf: Udf,
}

impl UdfExecutor {
pub fn execute(&self, chunk: &DataChunk) -> std::result::Result<ArrayImpl, ConvertError> {
println!("udf\n{}", chunk);
Ok(ArrayImpl::new_null((0..1).map(|_| ()).collect()))
}
}
4 changes: 4 additions & 0 deletions src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ impl<'a> Explain<'a> {
vec![].with(cost, rows),
vec![self.child(child).pretty()],
),
Udf(udf) => {
let v = udf.pretty_function();
Pretty::childless_record("Udf", v)
}
Empty(_) => Pretty::childless_record("Empty", vec![].with(cost, rows)),
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use egg::{define_language, Id, Symbol};

use crate::binder::copy::ExtSource;
use crate::binder::{CreateFunction, CreateTable};
use crate::binder::{CreateFunction, CreateTable, Udf};
use crate::catalog::{ColumnRefId, TableRefId};
use crate::parser::{BinaryOperator, UnaryOperator};
use crate::types::{ColumnIndex, DataType, DataValue, DateTimeField};
Expand Down Expand Up @@ -131,6 +131,9 @@ define_language! {
// with the same schema as `child`

Symbol(Symbol),

// currently only used by recursive sql udf
Udf(Udf),
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/planner/rules/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ pub fn analyze_type(enode: &Expr, x: impl Fn(&Id) -> Type, catalog: &RootCatalog
Ok(DataType::Struct(types))
}

// currently for recursive sql udf's type inference
Udf(udf) => Ok(udf.return_type.clone()),

// other plan nodes
_ => Err(TypeError::Unavailable(enode.to_string())),
}
Expand Down