From beeb6616cdcf75a5e7528ef49d3526e0cc00af00 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Tue, 22 Apr 2025 10:01:41 +0200 Subject: [PATCH 01/10] fix: sql fn params --- crates/pgt_typecheck/src/lib.rs | 16 +++- crates/pgt_typecheck/src/typed_identifier.rs | 18 ++++ .../src/workspace/server/document.rs | 7 ++ .../src/workspace/server/parsed_document.rs | 26 ++++- .../src/workspace/server/sql_function.rs | 96 ++++++++++++++++++- .../workspace/server/statement_identifier.rs | 15 +++ 6 files changed, 172 insertions(+), 6 deletions(-) create mode 100644 crates/pgt_typecheck/src/typed_identifier.rs diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index f741c0e6..0dcb199b 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -1,11 +1,13 @@ mod diagnostics; +mod typed_identifier; -pub use diagnostics::TypecheckDiagnostic; use diagnostics::create_type_error; +pub use diagnostics::TypecheckDiagnostic; use pgt_text_size::TextRange; use sqlx::postgres::PgDatabaseError; pub use sqlx::postgres::PgSeverity; use sqlx::{Executor, PgPool}; +use typed_identifier::{apply_identifiers, TypedIdentifier}; #[derive(Debug)] pub struct TypecheckParams<'a> { @@ -13,6 +15,9 @@ pub struct TypecheckParams<'a> { pub sql: &'a str, pub ast: &'a pgt_query_ext::NodeEnum, pub tree: &'a tree_sitter::Tree, + pub schema_cache: &'a pgt_schema_cache::SchemaCache, + pub cst: &'a tree_sitter::Node<'a>, + pub identifiers: Vec, } #[derive(Debug, Clone)] @@ -51,7 +56,14 @@ pub async fn check_sql( // each typecheck operation. conn.close_on_drop(); - let res = conn.prepare(params.sql).await; + let prepared = apply_identifiers( + params.identifiers, + params.schema_cache, + params.cst, + params.sql, + ); + + let res = conn.prepare(prepared).await; match res { Ok(_) => Ok(None), diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs new file mode 100644 index 00000000..4c21d982 --- /dev/null +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -0,0 +1,18 @@ +#[derive(Debug)] +pub struct TypedIdentifier { + pub schema: Option, + pub relation: String, + pub name: String, + pub type_: String, +} + +/// Applies the identifiers to the SQL string by replacing them with their default values. +pub fn apply_identifiers<'a>( + identifiers: Vec, + schema_cache: &'a pgt_schema_cache::SchemaCache, + cst: &'a tree_sitter::Node<'a>, + sql: &'a str, +) -> &'a str { + // TODO + sql +} diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index f2c500cc..4fe8b208 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -34,6 +34,13 @@ impl Document { } } + pub fn statement_content(&self, id: &StatementId) -> Option<&str> { + self.positions + .iter() + .find(|(statement_id, _)| statement_id == id) + .map(|(_, range)| &self.content[*range]) + } + /// Returns true if there is at least one fatal error in the diagnostics /// /// A fatal error is a scan error that prevents the document from being used diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs index 01f18d3c..01531b2c 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -12,7 +12,7 @@ use super::{ change::StatementChange, document::{Document, StatementIterator}, pg_query::PgQueryStore, - sql_function::SQLFunctionBodyStore, + sql_function::{SQLFunctionBodyStore, SQLFunctionSignature}, statement_identifier::StatementId, tree_sitter::TreeSitterStore, }; @@ -274,6 +274,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { String, Option, Arc, + Option>, ); fn map( @@ -293,7 +294,26 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { let cst_result = parser.cst_db.get_or_cache_tree(&id, &content_owned); - (id, range, content_owned, ast_option, cst_result) + let sql_fn_sig = id + .parent() + .and_then(|root| { + let c = parser.doc.statement_content(&root)?; + Some((root, c)) + }) + .and_then(|(root, c)| { + let ast_option = parser + .ast_db + .get_or_cache_ast(&root, c) + .as_ref() + .clone() + .ok(); + + let ast_option = ast_option.as_ref()?; + + parser.sql_fn_db.get_function_signature(&root, ast_option) + }); + + (id, range, content_owned, ast_option, cst_result, sql_fn_sig) } } @@ -413,7 +433,7 @@ mod tests { #[test] fn sql_function_body() { - let input = "CREATE FUNCTION add(integer, integer) RETURNS integer + let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 777210d5..fef73121 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -5,6 +5,18 @@ use pgt_text_size::TextRange; use super::statement_identifier::StatementId; +#[derive(Debug, Clone)] +pub struct SQLFunctionArgs { + pub name: Option, + pub type_: (Option, String), +} + +#[derive(Debug, Clone)] +pub struct SQLFunctionSignature { + pub name: (Option, String), + pub args: Vec, +} + #[derive(Debug, Clone)] pub struct SQLFunctionBody { pub range: TextRange, @@ -13,11 +25,33 @@ pub struct SQLFunctionBody { pub struct SQLFunctionBodyStore { db: DashMap>>, + sig_db: DashMap>>, } impl SQLFunctionBodyStore { pub fn new() -> SQLFunctionBodyStore { - SQLFunctionBodyStore { db: DashMap::new() } + SQLFunctionBodyStore { + db: DashMap::new(), + sig_db: DashMap::new(), + } + } + + pub fn get_function_signature( + &self, + statement: &StatementId, + ast: &pgt_query_ext::NodeEnum, + ) -> Option> { + // First check if we already have this statement cached + if let Some(existing) = self.sig_db.get(statement).map(|x| x.clone()) { + return existing; + } + + // If not cached, try to extract it from the AST + let fn_sig = get_sql_fn_signature(ast).map(Arc::new); + + // Cache the result and return it + self.sig_db.insert(statement.clone(), fn_sig.clone()); + fn_sig } pub fn get_function_body( @@ -48,6 +82,48 @@ impl SQLFunctionBodyStore { } } +/// Extracts SQL function signature from a CreateFunctionStmt node. +fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option { + let create_fn = match ast { + pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, + _ => return None, + }; + + println!("create_fn: {:?}", create_fn); + + // Extract language from function options + let language = find_option_value(create_fn, "language")?; + + // Only process SQL functions + if language != "sql" { + return None; + } + + let fn_name = parse_name(&create_fn.funcname)?; + + // we return None if anything is not expected + let mut fn_args = Vec::new(); + for arg in &create_fn.parameters { + if let Some(pgt_query_ext::NodeEnum::FunctionParameter(node)) = &arg.node { + let arg_name = (!node.name.is_empty()).then_some(node.name.clone()); + + let type_name = parse_name(&node.arg_type.as_ref().unwrap().names)?; + + fn_args.push(SQLFunctionArgs { + name: arg_name, + type_: type_name, + }); + } else { + return None; + } + } + + Some(SQLFunctionSignature { + name: fn_name, + args: fn_args, + }) +} + /// Extracts SQL function body and its text range from a CreateFunctionStmt node. /// Returns None if the function is not an SQL function or if the body can't be found. fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option { @@ -56,6 +132,8 @@ fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option return None, }; + println!("create_fn: {:?}", create_fn); + // Extract language from function options let language = find_option_value(create_fn, "language")?; @@ -120,3 +198,19 @@ fn find_option_value( } }) } + +fn parse_name(nodes: &Vec) -> Option<(Option, String)> { + let names = nodes + .iter() + .map(|n| match &n.node { + Some(pgt_query_ext::NodeEnum::String(s)) => Some(s.sval.clone()), + _ => None, + }) + .collect::>(); + + match names.as_slice() { + [Some(schema), Some(name)] => Some((Some(schema.clone()), name.clone())), + [Some(name)] => Some((None, name.clone())), + _ => None, + } +} diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 8c02814d..7c7d76f0 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -57,6 +57,21 @@ impl StatementId { StatementId::Child(s) => s.inner, } } + + pub fn is_root(&self) -> bool { + matches!(self, StatementId::Root(_)) + } + + pub fn is_child(&self) -> bool { + matches!(self, StatementId::Child(_)) + } + + pub fn parent(&self) -> Option { + match self { + StatementId::Root(_) => None, + StatementId::Child(id) => Some(StatementId::Root(id.clone())), + } + } } /// Helper struct to generate unique statement ids From 9b6c7aa7722a76e78e62651d67c334e02dc3eb52 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 25 Apr 2025 08:45:52 +0200 Subject: [PATCH 02/10] save progress --- crates/pgt_typecheck/src/lib.rs | 6 +- crates/pgt_typecheck/src/typed_identifier.rs | 36 ++++++++- crates/pgt_typecheck/tests/diagnostics.rs | 82 ++++++++++++++------ crates/pgt_workspace/src/workspace/server.rs | 14 ++-- 4 files changed, 104 insertions(+), 34 deletions(-) diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index 0dcb199b..8b05ca13 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -3,11 +3,12 @@ mod typed_identifier; use diagnostics::create_type_error; pub use diagnostics::TypecheckDiagnostic; +pub use typed_identifier::TypedIdentifier; use pgt_text_size::TextRange; use sqlx::postgres::PgDatabaseError; pub use sqlx::postgres::PgSeverity; use sqlx::{Executor, PgPool}; -use typed_identifier::{apply_identifiers, TypedIdentifier}; +use typed_identifier::apply_identifiers; #[derive(Debug)] pub struct TypecheckParams<'a> { @@ -16,7 +17,6 @@ pub struct TypecheckParams<'a> { pub ast: &'a pgt_query_ext::NodeEnum, pub tree: &'a tree_sitter::Tree, pub schema_cache: &'a pgt_schema_cache::SchemaCache, - pub cst: &'a tree_sitter::Node<'a>, pub identifiers: Vec, } @@ -59,7 +59,7 @@ pub async fn check_sql( let prepared = apply_identifiers( params.identifiers, params.schema_cache, - params.cst, + params.tree, params.sql, ); diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 4c21d982..4e526b15 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,18 +1,48 @@ +#[derive(Debug)] +pub struct Type { + pub schema: Option, + pub name: String, + pub oid: i32, +} + #[derive(Debug)] pub struct TypedIdentifier { pub schema: Option, - pub relation: String, + pub relation: Option, pub name: String, - pub type_: String, + pub type_: Type, +} + +impl TypedIdentifier { + pub fn new( + schema: Option, + relation: Option, + name: String, + type_: Type, + ) -> Self { + TypedIdentifier { + schema, + relation, + name, + type_, + } + } + + pub fn default_value(&self, schema_cache: &pgt_schema_cache::SchemaCache) -> String { + "NULL".to_string() + } } /// Applies the identifiers to the SQL string by replacing them with their default values. pub fn apply_identifiers<'a>( identifiers: Vec, schema_cache: &'a pgt_schema_cache::SchemaCache, - cst: &'a tree_sitter::Node<'a>, + cst: &'a tree_sitter::Tree, sql: &'a str, ) -> &'a str { // TODO + println!("Applying identifiers to SQL: {}", sql); + println!("Identifiers: {:?}", identifiers); + println!("CST: {:#?}", cst); sql } diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 4c780d74..0c31a701 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -4,22 +4,28 @@ use pgt_console::{ }; use pgt_diagnostics::PrintDiagnostic; use pgt_test_utils::test_database::get_new_test_db; -use pgt_typecheck::{TypecheckParams, check_sql}; +use pgt_typecheck::{check_sql, TypecheckParams}; use sqlx::Executor; -async fn test(name: &str, query: &str, setup: &str) { +async fn test(name: &str, query: &str, setup: Option<&str>) { let test_db = get_new_test_db().await; - test_db - .execute(setup) - .await - .expect("Failed to setup test database"); + if let Some(setup) = setup { + test_db + .execute(setup) + .await + .expect("Failed to setup test database"); + } let mut parser = tree_sitter::Parser::new(); parser .set_language(tree_sitter_sql::language()) .expect("Error loading sql language"); + let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db) + .await + .expect("Failed to load Schema Cache"); + let root = pgt_query_ext::parse(query).unwrap(); let tree = parser.parse(query, None).unwrap(); @@ -29,25 +35,27 @@ async fn test(name: &str, query: &str, setup: &str) { sql: query, ast: &root, tree: &tree, + schema_cache: &schema_cache, + identifiers: vec![], }) .await; - let mut content = vec![]; - let mut writer = HTML::new(&mut content); - - Formatter::new(&mut writer) - .write_markup(markup! { - {PrintDiagnostic::simple(&result.unwrap().unwrap())} - }) - .unwrap(); - - let content = String::from_utf8(content).unwrap(); + // let mut content = vec![]; + // let mut writer = HTML::new(&mut content); - insta::with_settings!({ - prepend_module_to_snapshot => false, - }, { - insta::assert_snapshot!(name, content); - }); + // Formatter::new(&mut writer) + // .write_markup(markup! { + // {PrintDiagnostic::simple(&result.unwrap().unwrap())} + // }) + // .unwrap(); + // + // let content = String::from_utf8(content).unwrap(); + // + // insta::with_settings!({ + // prepend_module_to_snapshot => false, + // }, { + // insta::assert_snapshot!(name, content); + // }); } #[tokio::test] @@ -55,7 +63,8 @@ async fn invalid_column() { test( "invalid_column", "select id, unknown from contacts;", - r#" + Some( + r#" create table public.contacts ( id serial primary key, name varchar(255) not null, @@ -63,6 +72,35 @@ async fn invalid_column() { middle_name varchar(255) ); "#, + ), + ) + .await; +} + +#[tokio::test] +async fn sql_fn() { + test( + "sql_fn", + "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;", + Some(""), + ) + .await; +} + +#[tokio::test] +async fn sql_fn_named() { + test( + "sql_fn", + "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer + AS 'select test0 + test1;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;", + Some(""), ) .await; } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 3bf540cc..f8fee4d2 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -5,7 +5,7 @@ use async_helper::run_async; use dashmap::DashMap; use db_connection::DbConnection; use document::Document; -use futures::{StreamExt, stream}; +use futures::{stream, StreamExt}; use parsed_document::{ AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, ParsedDocument, SyncDiagnosticsMapper, @@ -13,26 +13,26 @@ use parsed_document::{ use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pgt_diagnostics::{ - Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, + serde::Diagnostic as SDiagnostic, Diagnostic, DiagnosticExt, Error, Severity, }; use pgt_fs::{ConfigName, PgTPath}; -use pgt_typecheck::TypecheckParams; +use pgt_typecheck::{TypecheckParams, TypedIdentifier}; use schema_cache_manager::SchemaCacheManager; use sqlx::Executor; use tracing::info; use crate::{ - WorkspaceError, configuration::to_analyser_rules, features::{ code_actions::{ self, CodeAction, CodeActionKind, CodeActionsResult, CommandAction, CommandActionCategory, ExecuteStatementParams, ExecuteStatementResult, }, - completions::{CompletionsResult, GetCompletionsParams, get_statement_for_completions}, + completions::{get_statement_for_completions, CompletionsResult, GetCompletionsParams}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, }, settings::{Settings, SettingsHandle, SettingsHandleMut}, + WorkspaceError, }; use super::{ @@ -370,7 +370,7 @@ impl Workspace for WorkspaceServer { let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); let async_results = run_async(async move { stream::iter(input) - .map(|(_id, range, content, ast, cst)| { + .map(|(_id, range, content, ast, cst, sign)| { let pool = pool.clone(); let path = path_clone.clone(); async move { @@ -380,6 +380,8 @@ impl Workspace for WorkspaceServer { sql: &content, ast: &ast, tree: &cst, + schema_cache: &self.schema_cache, + identifiers: sign.map(|s| TypedIdentifier {}).unwrap_or_default(), }) .await .map(|d| { From ba832021150f963943de2f055999359ee087fcc9 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 25 Apr 2025 22:44:28 +0200 Subject: [PATCH 03/10] add ts query --- crates/pgt_treesitter_queries/src/lib.rs | 46 ++++++- .../pgt_treesitter_queries/src/queries/mod.rs | 13 ++ .../src/queries/parameters.rs | 124 ++++++++++++++++++ crates/pgt_typecheck/src/lib.rs | 4 +- crates/pgt_typecheck/src/typed_identifier.rs | 113 ++++++++++++++-- crates/pgt_typecheck/tests/diagnostics.rs | 60 +++------ crates/pgt_workspace/src/workspace/server.rs | 108 +++++++-------- .../src/workspace/server/parsed_document.rs | 17 +-- .../src/workspace/server/sql_function.rs | 99 ++++---------- 9 files changed, 387 insertions(+), 197 deletions(-) create mode 100644 crates/pgt_treesitter_queries/src/queries/parameters.rs diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index 7d2ba61b..4bd7c01d 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -68,7 +68,10 @@ impl<'a> Iterator for QueryResultIter<'a> { #[cfg(test)] mod tests { - use crate::{TreeSitterQueriesExecutor, queries::RelationMatch}; + use crate::{ + TreeSitterQueriesExecutor, + queries::{ParameterMatch, RelationMatch}, + }; #[test] fn finds_all_relations_and_ignores_functions() { @@ -137,11 +140,11 @@ where select * from ( - select * + select * from ( select * from private.something - ) as sq2 + ) as sq2 join private.tableau pt1 on sq2.id = pt1.id ) as sq1 @@ -185,4 +188,41 @@ on sq1.id = pt.id; assert_eq!(results[0].get_schema(sql), Some("private".into())); assert_eq!(results[0].get_table(sql), "something"); } + + #[test] + fn extracts_parameters() { + let sql = r#"select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&ParameterMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 4); + + assert_eq!(results[0].get_root(sql), None); + assert_eq!(results[0].get_path(sql), None); + assert_eq!(results[0].get_field(sql), "v_test"); + + assert_eq!(results[1].get_root(sql), Some("fn_name".into())); + assert_eq!(results[1].get_path(sql), Some("custom_type".into())); + assert_eq!(results[1].get_field(sql), "v_test2"); + + assert_eq!(results[2].get_root(sql), None); + assert_eq!(results[2].get_path(sql), None); + assert_eq!(results[2].get_field(sql), "$3"); + + assert_eq!(results[3].get_root(sql), None); + assert_eq!(results[3].get_path(sql), Some("custom_type".into())); + assert_eq!(results[3].get_field(sql), "v_test3"); + } } diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 98b55e03..62924e00 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -1,10 +1,13 @@ +mod parameters; mod relations; +pub use parameters::*; pub use relations::*; #[derive(Debug)] pub enum QueryResult<'a> { Relation(RelationMatch<'a>), + Parameter(ParameterMatch<'a>), } impl QueryResult<'_> { @@ -18,6 +21,16 @@ impl QueryResult<'_> { let end = rm.table.end_position(); + start >= range.start_point && end <= range.end_point + } + Self::Parameter(pm) => { + let start = match pm.root { + Some(s) => s.start_position(), + None => pm.path.as_ref().unwrap().start_position(), + }; + + let end = pm.field.end_position(); + start >= range.start_point && end <= range.end_point } } diff --git a/crates/pgt_treesitter_queries/src/queries/parameters.rs b/crates/pgt_treesitter_queries/src/queries/parameters.rs new file mode 100644 index 00000000..b0ffaee2 --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/parameters.rs @@ -0,0 +1,124 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" +[ + (field + (identifier)) @reference + (field + (object_reference) + "." (identifier)) @reference + (parameter) @parameter +] +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct ParameterMatch<'a> { + pub(crate) root: Option>, + pub(crate) path: Option>, + + pub(crate) field: tree_sitter::Node<'a>, +} + +impl ParameterMatch<'_> { + pub fn get_root(&self, sql: &str) -> Option { + let str = self + .root + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get schema from RelationMatch"); + + Some(str.to_string()) + } + + pub fn get_path(&self, sql: &str) -> Option { + let str = self + .path + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get table from RelationMatch"); + + Some(str.to_string()) + } + + pub fn get_field(&self, sql: &str) -> String { + self.field + .utf8_text(sql.as_bytes()) + .expect("Failed to get table from RelationMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ParameterMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::Parameter(r) => Ok(r), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for ParameterMatch<'a> { + type Ref = &'a ParameterMatch<'a>; +} + +impl<'a> Query<'a> for ParameterMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + matches + .filter_map(|m| { + let captures = m.captures; + + // We expect exactly one capture for a parameter + if captures.len() != 1 { + return None; + } + + let field = captures[0].node; + let text = match field.utf8_text(stmt.as_bytes()) { + Ok(t) => t, + Err(_) => return None, + }; + let parts: Vec<&str> = text.split('.').collect(); + + let param_match = match parts.len() { + // Simple field: field_name + 1 => ParameterMatch { + root: None, + path: None, + field, + }, + // Table qualified: table.field_name + 2 => ParameterMatch { + root: None, + path: field.named_child(0), + field: field.named_child(1)?, + }, + // Fully qualified: schema.table.field_name + 3 => ParameterMatch { + root: field.named_child(0).and_then(|n| n.named_child(0)), + path: field.named_child(0).and_then(|n| n.named_child(1)), + field: field.named_child(1)?, + }, + // Unexpected number of parts + _ => return None, + }; + + Some(QueryResult::Parameter(param_match)) + }) + .collect() + } +} diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index 8b05ca13..8158e82c 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -1,13 +1,13 @@ mod diagnostics; mod typed_identifier; -use diagnostics::create_type_error; pub use diagnostics::TypecheckDiagnostic; -pub use typed_identifier::TypedIdentifier; +use diagnostics::create_type_error; use pgt_text_size::TextRange; use sqlx::postgres::PgDatabaseError; pub use sqlx::postgres::PgSeverity; use sqlx::{Executor, PgPool}; +pub use typed_identifier::TypedIdentifier; use typed_identifier::apply_identifiers; #[derive(Debug)] diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 4e526b15..bbba6e79 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,16 +1,9 @@ -#[derive(Debug)] -pub struct Type { - pub schema: Option, - pub name: String, - pub oid: i32, -} - #[derive(Debug)] pub struct TypedIdentifier { pub schema: Option, pub relation: Option, pub name: String, - pub type_: Type, + pub type_: (Option, String), } impl TypedIdentifier { @@ -18,7 +11,7 @@ impl TypedIdentifier { schema: Option, relation: Option, name: String, - type_: Type, + type_: (Option, String), ) -> Self { TypedIdentifier { schema, @@ -44,5 +37,107 @@ pub fn apply_identifiers<'a>( println!("Applying identifiers to SQL: {}", sql); println!("Identifiers: {:?}", identifiers); println!("CST: {:#?}", cst); + sql } + +#[cfg(test)] +mod tests { + use pgt_test_utils::test_database::get_new_test_db; + use sqlx::Executor; + + #[tokio::test] + async fn test_apply_identifiers() { + let input = "select v_test + fn_name.custom_type.v_test2 + $3 + test.field;"; + + let test_db = get_new_test_db().await; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db) + .await + .expect("Failed to load Schema Cache"); + + let root = pgt_query_ext::parse(input).unwrap(); + let tree = parser.parse(input, None).unwrap(); + + println!("Parsed SQL: {:?}", root); + println!("Parsed CST: {:?}", tree); + + // let mut parameters = Vec::new(); + + enum Parameter { + Identifier { + range: (usize, usize), + name: (Option, String), + }, + Parameter { + range: (usize, usize), + idx: usize, + }, + } + + let mut c = tree.walk(); + + 'outer: loop { + // 0. Add the current node to the map. + println!("Current node: {:?}", c.node()); + match c.node().kind() { + "identifier" => { + println!( + "Found identifier: {:?}", + c.node().utf8_text(input.as_bytes()).unwrap() + ); + } + "parameter" => { + println!( + "Found parameter: {:?}", + c.node().utf8_text(input.as_bytes()).unwrap() + ); + } + "object_reference" => { + println!( + "Found object reference: {:?}", + c.node().utf8_text(input.as_bytes()).unwrap() + ); + + // let source = self.text; + // ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { + // if SanitizedCompletionParams::is_sanitized_token(txt) { + // NodeText::Replaced + // } else { + // NodeText::Original(txt) + // } + // }) + } + _ => {} + } + + // 1. Go to its child and continue. + if c.goto_first_child() { + continue 'outer; + } + + // 2. We've reached a leaf (node without a child). We will go to a sibling. + if c.goto_next_sibling() { + continue 'outer; + } + + // 3. If there are no more siblings, we need to go back up. + 'inner: loop { + // 4. Check if we've reached the root node. If so, we're done. + if !c.goto_parent() { + break 'outer; + } + // 5. Go to the previous node's sibling. + if c.goto_next_sibling() { + // And break out of the inner loop. + break 'inner; + } + } + } + } +} diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 0c31a701..9628962d 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -4,7 +4,7 @@ use pgt_console::{ }; use pgt_diagnostics::PrintDiagnostic; use pgt_test_utils::test_database::get_new_test_db; -use pgt_typecheck::{check_sql, TypecheckParams}; +use pgt_typecheck::{TypecheckParams, check_sql}; use sqlx::Executor; async fn test(name: &str, query: &str, setup: Option<&str>) { @@ -40,22 +40,22 @@ async fn test(name: &str, query: &str, setup: Option<&str>) { }) .await; - // let mut content = vec![]; - // let mut writer = HTML::new(&mut content); + let mut content = vec![]; + let mut writer = HTML::new(&mut content); - // Formatter::new(&mut writer) - // .write_markup(markup! { - // {PrintDiagnostic::simple(&result.unwrap().unwrap())} - // }) - // .unwrap(); - // - // let content = String::from_utf8(content).unwrap(); - // - // insta::with_settings!({ - // prepend_module_to_snapshot => false, - // }, { - // insta::assert_snapshot!(name, content); - // }); + Formatter::new(&mut writer) + .write_markup(markup! { + {PrintDiagnostic::simple(&result.unwrap().unwrap())} + }) + .unwrap(); + + let content = String::from_utf8(content).unwrap(); + + insta::with_settings!({ + prepend_module_to_snapshot => false, + }, { + insta::assert_snapshot!(name, content); + }); } #[tokio::test] @@ -76,31 +76,3 @@ async fn invalid_column() { ) .await; } - -#[tokio::test] -async fn sql_fn() { - test( - "sql_fn", - "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer - AS 'select $1 + $2;' - LANGUAGE SQL - IMMUTABLE - RETURNS NULL ON NULL INPUT;", - Some(""), - ) - .await; -} - -#[tokio::test] -async fn sql_fn_named() { - test( - "sql_fn", - "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer - AS 'select test0 + test1;' - LANGUAGE SQL - IMMUTABLE - RETURNS NULL ON NULL INPUT;", - Some(""), - ) - .await; -} diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index f8fee4d2..61ad9363 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -5,7 +5,7 @@ use async_helper::run_async; use dashmap::DashMap; use db_connection::DbConnection; use document::Document; -use futures::{stream, StreamExt}; +use futures::{StreamExt, stream}; use parsed_document::{ AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, ParsedDocument, SyncDiagnosticsMapper, @@ -13,7 +13,7 @@ use parsed_document::{ use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pgt_diagnostics::{ - serde::Diagnostic as SDiagnostic, Diagnostic, DiagnosticExt, Error, Severity, + Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, }; use pgt_fs::{ConfigName, PgTPath}; use pgt_typecheck::{TypecheckParams, TypedIdentifier}; @@ -22,17 +22,17 @@ use sqlx::Executor; use tracing::info; use crate::{ + WorkspaceError, configuration::to_analyser_rules, features::{ code_actions::{ self, CodeAction, CodeActionKind, CodeActionsResult, CommandAction, CommandActionCategory, ExecuteStatementParams, ExecuteStatementResult, }, - completions::{get_statement_for_completions, CompletionsResult, GetCompletionsParams}, + completions::{CompletionsResult, GetCompletionsParams, get_statement_for_completions}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, }, settings::{Settings, SettingsHandle, SettingsHandleMut}, - WorkspaceError, }; use super::{ @@ -360,55 +360,57 @@ impl Workspace for WorkspaceServer { let mut diagnostics: Vec = parser.document_diagnostics().to_vec(); - if let Some(pool) = self - .connection - .read() - .expect("DbConnection RwLock panicked") - .get_pool() - { - let path_clone = params.path.clone(); - let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); - let async_results = run_async(async move { - stream::iter(input) - .map(|(_id, range, content, ast, cst, sign)| { - let pool = pool.clone(); - let path = path_clone.clone(); - async move { - if let Some(ast) = ast { - pgt_typecheck::check_sql(TypecheckParams { - conn: &pool, - sql: &content, - ast: &ast, - tree: &cst, - schema_cache: &self.schema_cache, - identifiers: sign.map(|s| TypedIdentifier {}).unwrap_or_default(), - }) - .await - .map(|d| { - d.map(|d| { - let r = d.location().span.map(|span| span + range.start()); - - d.with_file_path(path.as_path().display().to_string()) - .with_file_span(r.unwrap_or(range)) - }) - }) - } else { - Ok(None) - } - } - }) - .buffer_unordered(10) - .collect::>() - .await - })?; - - for result in async_results.into_iter() { - let result = result?; - if let Some(diag) = result { - diagnostics.push(SDiagnostic::new(diag)); - } - } - } + // if let Some(pool) = self + // .connection + // .read() + // .expect("DbConnection RwLock panicked") + // .get_pool() + // { + // let path_clone = params.path.clone(); + // let schema_cache = self.schema_cache.load(pool)?; + // let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); + // let async_results = run_async(async move { + // stream::iter(input) + // .map(|(_id, range, content, ast, cst, sign)| { + // let pool = pool.clone(); + // let path = path_clone.clone(); + // async move { + // if let Some(ast) = ast { + // // pgt_typecheck::check_sql(TypecheckParams { + // // conn: &pool, + // // sql: &content, + // // ast: &ast, + // // tree: &cst, + // // schema_cache, + // // identifiers: vec![], + // // }) + // // .await + // // .map(|d| { + // // d.map(|d| { + // // let r = d.location().span.map(|span| span + range.start()); + // // + // // d.with_file_path(path.as_path().display().to_string()) + // // .with_file_span(r.unwrap_or(range)) + // // }) + // // }) + // Ok(None) + // } else { + // Ok(None) + // } + // } + // }) + // .buffer_unordered(10) + // .collect::>() + // .await + // })?; + // + // for result in async_results.into_iter() { + // let result = result?; + // if let Some(diag) = result { + // diagnostics.push(SDiagnostic::new(diag)); + // } + // } + // } diagnostics.extend(parser.iter(SyncDiagnosticsMapper).flat_map( |(_id, range, ast, diag)| { diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs index 01531b2c..d18806ed 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -12,7 +12,7 @@ use super::{ change::StatementChange, document::{Document, StatementIterator}, pg_query::PgQueryStore, - sql_function::{SQLFunctionBodyStore, SQLFunctionSignature}, + sql_function::{SQLFunctionSignature, get_sql_fn_body, get_sql_fn_signature}, statement_identifier::StatementId, tree_sitter::TreeSitterStore, }; @@ -24,7 +24,6 @@ pub struct ParsedDocument { doc: Document, ast_db: PgQueryStore, cst_db: TreeSitterStore, - sql_fn_db: SQLFunctionBodyStore, annotation_db: AnnotationStore, } @@ -34,7 +33,6 @@ impl ParsedDocument { let cst_db = TreeSitterStore::new(); let ast_db = PgQueryStore::new(); - let sql_fn_db = SQLFunctionBodyStore::new(); let annotation_db = AnnotationStore::new(); doc.iter().for_each(|(stmt, _, content)| { @@ -46,7 +44,6 @@ impl ParsedDocument { doc, ast_db, cst_db, - sql_fn_db, annotation_db, } } @@ -72,7 +69,6 @@ impl ParsedDocument { tracing::debug!("Deleting statement: id {:?}", s,); self.cst_db.remove_statement(s); self.ast_db.clear_statement(s); - self.sql_fn_db.clear_statement(s); self.annotation_db.clear_statement(s); } StatementChange::Modified(s) => { @@ -88,7 +84,6 @@ impl ParsedDocument { self.cst_db.modify_statement(s); self.ast_db.clear_statement(&s.old_stmt); - self.sql_fn_db.clear_statement(&s.old_stmt); self.annotation_db.clear_statement(&s.old_stmt); } } @@ -197,11 +192,7 @@ where .as_ref() { // Check if this is a SQL function definition with a body - if let Some(sub_statement) = - self.parser - .sql_fn_db - .get_function_body(&root_id, ast, &content_owned) - { + if let Some(sub_statement) = get_sql_fn_body(ast, &content_owned) { // Add sub-statements to our pending queue self.pending_sub_statements.push(( root_id.create_child(), @@ -274,7 +265,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { String, Option, Arc, - Option>, + Option, ); fn map( @@ -310,7 +301,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { let ast_option = ast_option.as_ref()?; - parser.sql_fn_db.get_function_signature(&root, ast_option) + get_sql_fn_signature(ast_option) }); (id, range, content_owned, ast_option, cst_result, sql_fn_sig) diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index fef73121..a490f7e9 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -1,10 +1,5 @@ -use std::sync::Arc; - -use dashmap::DashMap; use pgt_text_size::TextRange; -use super::statement_identifier::StatementId; - #[derive(Debug, Clone)] pub struct SQLFunctionArgs { pub name: Option, @@ -23,74 +18,13 @@ pub struct SQLFunctionBody { pub body: String, } -pub struct SQLFunctionBodyStore { - db: DashMap>>, - sig_db: DashMap>>, -} - -impl SQLFunctionBodyStore { - pub fn new() -> SQLFunctionBodyStore { - SQLFunctionBodyStore { - db: DashMap::new(), - sig_db: DashMap::new(), - } - } - - pub fn get_function_signature( - &self, - statement: &StatementId, - ast: &pgt_query_ext::NodeEnum, - ) -> Option> { - // First check if we already have this statement cached - if let Some(existing) = self.sig_db.get(statement).map(|x| x.clone()) { - return existing; - } - - // If not cached, try to extract it from the AST - let fn_sig = get_sql_fn_signature(ast).map(Arc::new); - - // Cache the result and return it - self.sig_db.insert(statement.clone(), fn_sig.clone()); - fn_sig - } - - pub fn get_function_body( - &self, - statement: &StatementId, - ast: &pgt_query_ext::NodeEnum, - content: &str, - ) -> Option> { - // First check if we already have this statement cached - if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { - return existing; - } - - // If not cached, try to extract it from the AST - let fn_body = get_sql_fn(ast, content).map(Arc::new); - - // Cache the result and return it - self.db.insert(statement.clone(), fn_body.clone()); - fn_body - } - - pub fn clear_statement(&self, id: &StatementId) { - self.db.remove(id); - - if let Some(child_id) = id.get_child_id() { - self.db.remove(&child_id); - } - } -} - -/// Extracts SQL function signature from a CreateFunctionStmt node. -fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option { +/// Extracts the function signature from a SQL function definition +pub fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option { let create_fn = match ast { pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, _ => return None, }; - println!("create_fn: {:?}", create_fn); - // Extract language from function options let language = find_option_value(create_fn, "language")?; @@ -124,16 +58,13 @@ fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option Option { +/// Extracts the SQL body from a function definition +pub fn get_sql_fn_body(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option { let create_fn = match ast { pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, _ => return None, }; - println!("create_fn: {:?}", create_fn); - // Extract language from function options let language = find_option_value(create_fn, "language")?; @@ -214,3 +145,25 @@ fn parse_name(nodes: &Vec) -> Option<(Option None, } } + +#[cfg(test)] +mod tests { + use super::*; + + use pgt_fs::PgTPath; + + #[test] + fn sql_function_signature() { + let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; + + let ast = pgt_query_ext::parse(input).unwrap(); + + let sig = get_sql_fn_signature(&ast); + + println!("Function signature: {:?}", sig); + } +} From e032c3f12425f94d47280eb730b44565625c7cb3 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Fri, 25 Apr 2025 23:18:01 +0200 Subject: [PATCH 04/10] progress --- Cargo.lock | 1 + crates/pgt_treesitter_queries/src/lib.rs | 11 ++-- .../src/queries/parameters.rs | 33 ++++++++-- crates/pgt_typecheck/Cargo.toml | 1 + crates/pgt_typecheck/src/typed_identifier.rs | 60 ++++++++++++------- 5 files changed, 74 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 72ba810f..ffdbff92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2865,6 +2865,7 @@ dependencies = [ "pgt_schema_cache", "pgt_test_utils", "pgt_text_size", + "pgt_treesitter_queries", "sqlx", "tokio", "tree-sitter", diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index 4bd7c01d..ef5b73fc 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -69,8 +69,7 @@ impl<'a> Iterator for QueryResultIter<'a> { mod tests { use crate::{ - TreeSitterQueriesExecutor, - queries::{ParameterMatch, RelationMatch}, + queries::{Field, ParameterMatch, RelationMatch}, TreeSitterQueriesExecutor }; #[test] @@ -211,18 +210,18 @@ on sq1.id = pt.id; assert_eq!(results[0].get_root(sql), None); assert_eq!(results[0].get_path(sql), None); - assert_eq!(results[0].get_field(sql), "v_test"); + assert_eq!(results[0].get_field(sql), Field::Text("v_test".to_string())); assert_eq!(results[1].get_root(sql), Some("fn_name".into())); assert_eq!(results[1].get_path(sql), Some("custom_type".into())); - assert_eq!(results[1].get_field(sql), "v_test2"); + assert_eq!(results[1].get_field(sql), Field::Text("v_test2".to_string())); assert_eq!(results[2].get_root(sql), None); assert_eq!(results[2].get_path(sql), None); - assert_eq!(results[2].get_field(sql), "$3"); + assert_eq!(results[2].get_field(sql), Field::Parameter(3)); assert_eq!(results[3].get_root(sql), None); assert_eq!(results[3].get_path(sql), Some("custom_type".into())); - assert_eq!(results[3].get_field(sql), "v_test3"); + assert_eq!(results[3].get_field(sql), Field::Text("v_test3".to_string())); } } diff --git a/crates/pgt_treesitter_queries/src/queries/parameters.rs b/crates/pgt_treesitter_queries/src/queries/parameters.rs index b0ffaee2..6bc95734 100644 --- a/crates/pgt_treesitter_queries/src/queries/parameters.rs +++ b/crates/pgt_treesitter_queries/src/queries/parameters.rs @@ -26,6 +26,12 @@ pub struct ParameterMatch<'a> { pub(crate) field: tree_sitter::Node<'a>, } +#[derive(Debug, PartialEq)] +pub enum Field { + Text(String), + Parameter(usize), +} + impl ParameterMatch<'_> { pub fn get_root(&self, sql: &str) -> Option { let str = self @@ -47,11 +53,30 @@ impl ParameterMatch<'_> { Some(str.to_string()) } - pub fn get_field(&self, sql: &str) -> String { - self.field + pub fn get_field(&self, sql: &str) -> Field { + let text = self + .field .utf8_text(sql.as_bytes()) - .expect("Failed to get table from RelationMatch") - .to_string() + .expect("Failed to get field from RelationMatch"); + + if let Some(stripped) = text.strip_prefix('$') { + return Field::Parameter( + stripped + .parse::() + .expect("Failed to parse parameter"), + ); + } + + Field::Text(text.to_string()) + } + + pub fn get_range(&self) -> tree_sitter::Range { + self.field.range() + } + + pub fn get_byte_range(&self) -> std::ops::Range { + let range = self.field.range(); + range.start_byte..range.end_byte } } diff --git a/crates/pgt_typecheck/Cargo.toml b/crates/pgt_typecheck/Cargo.toml index a097fa56..9a2d7022 100644 --- a/crates/pgt_typecheck/Cargo.toml +++ b/crates/pgt_typecheck/Cargo.toml @@ -17,6 +17,7 @@ pgt_diagnostics.workspace = true pgt_query_ext.workspace = true pgt_schema_cache.workspace = true pgt_text_size.workspace = true +pgt_treesitter_queries.workspace = true sqlx.workspace = true tokio.workspace = true tree-sitter.workspace = true diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index bbba6e79..41d6a29e 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,30 +1,17 @@ +use pgt_treesitter_queries::{ + queries::{Field, ParameterMatch}, + TreeSitterQueriesExecutor, +}; + #[derive(Debug)] pub struct TypedIdentifier { - pub schema: Option, - pub relation: Option, + pub root: Option, + pub field: Option, pub name: String, - pub type_: (Option, String), + pub type_: Identifier, } -impl TypedIdentifier { - pub fn new( - schema: Option, - relation: Option, - name: String, - type_: (Option, String), - ) -> Self { - TypedIdentifier { - schema, - relation, - name, - type_, - } - } - - pub fn default_value(&self, schema_cache: &pgt_schema_cache::SchemaCache) -> String { - "NULL".to_string() - } -} +type Identifier = (Option, String); /// Applies the identifiers to the SQL string by replacing them with their default values. pub fn apply_identifiers<'a>( @@ -38,6 +25,35 @@ pub fn apply_identifiers<'a>( println!("Identifiers: {:?}", identifiers); println!("CST: {:#?}", cst); + let mut executor = TreeSitterQueriesExecutor::new(cst.root_node(), sql); + + executor.add_query_results::(); + + // we need the range and type of each field + + let results: Vec<(std::ops::Range, &Identifier)> = executor + .get_iter(None) + .filter_map(|q| { + let m: &ParameterMatch = q.try_into().ok()?; + + let ident = match m.get_field(sql) { + Field::Parameter(idx) => identifiers.get(idx)?, + Field::Text(field) => { + let r = m.get_root(sql); + let p = m.get_path(sql); + + identifiers.iter().find(|i| { + // TODO: this is not correct, we need to check if the identifier is a prefix of the field + + })? + } + }; + + Some((m.get_byte_range(), &ident.type_)) + }) + // TODO resolve composite types or table types to plain types + .collect(); + sql } From 896bfb1516a0f49ff3e9dbaf6f4cdccd41e5dedf Mon Sep 17 00:00:00 2001 From: psteinroe Date: Tue, 29 Apr 2025 09:39:00 +0200 Subject: [PATCH 05/10] progress --- crates/pgt_schema_cache/src/types.rs | 61 ++++- crates/pgt_treesitter_queries/src/lib.rs | 19 +- .../pgt_treesitter_queries/src/queries/mod.rs | 10 +- .../src/queries/parameters.rs | 87 +------ crates/pgt_typecheck/Cargo.toml | 18 +- crates/pgt_typecheck/src/typed_identifier.rs | 239 ++++++++++-------- 6 files changed, 223 insertions(+), 211 deletions(-) diff --git a/crates/pgt_schema_cache/src/types.rs b/crates/pgt_schema_cache/src/types.rs index 8b2d04bb..5d7f1b43 100644 --- a/crates/pgt_schema_cache/src/types.rs +++ b/crates/pgt_schema_cache/src/types.rs @@ -6,13 +6,13 @@ use crate::schema_cache::SchemaCacheItem; #[derive(Debug, Clone, Default)] pub struct TypeAttributes { - attrs: Vec, + pub attrs: Vec, } #[derive(Debug, Clone, Default, Deserialize)] pub struct PostgresTypeAttribute { - name: String, - type_id: i64, + pub name: String, + pub type_id: i64, } impl From> for TypeAttributes { @@ -56,3 +56,58 @@ impl SchemaCacheItem for PostgresType { .await } } + +#[cfg(test)] +mod tests { + use pgt_test_utils::test_database::get_new_test_db; + use sqlx::Executor; + + use crate::{schema_cache::SchemaCacheItem, types::PostgresType}; + + #[tokio::test] + async fn test_types() { + let setup = r#" + CREATE TYPE "public"."priority" AS ENUM ( + 'critical', + 'high', + 'default', + 'low', + 'very_low' + ); + + CREATE TYPE complex AS ( + r double precision, + i double precision + ); + "#; + + let test_db = get_new_test_db().await; + + test_db + .execute(setup) + .await + .expect("Failed to setup test database"); + + let types = PostgresType::load(&test_db).await.unwrap(); + + let enum_type = types.iter().find(|t| t.name == "priority"); + let comp_type = types.iter().find(|t| t.name == "complex"); + + println!("{:?}", enum_type); + // search for type id + println!("{:?}", comp_type); + + comp_type.and_then(|t| { + t.attributes.attrs.iter().for_each(|a| { + let typ = types.iter().find(|t| t.id == a.type_id); + println!( + "{}: {} - {:?}", + a.name, + a.type_id, + typ.as_ref().map(|t| t.name.clone()) + ); + }); + Some(()) + }); + } +} diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index ef5b73fc..03c47256 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -69,7 +69,8 @@ impl<'a> Iterator for QueryResultIter<'a> { mod tests { use crate::{ - queries::{Field, ParameterMatch, RelationMatch}, TreeSitterQueriesExecutor + TreeSitterQueriesExecutor, + queries::{Field, ParameterMatch, RelationMatch}, }; #[test] @@ -208,20 +209,12 @@ on sq1.id = pt.id; assert_eq!(results.len(), 4); - assert_eq!(results[0].get_root(sql), None); - assert_eq!(results[0].get_path(sql), None); - assert_eq!(results[0].get_field(sql), Field::Text("v_test".to_string())); + assert_eq!(results[0].get_path(sql), "v_test"); - assert_eq!(results[1].get_root(sql), Some("fn_name".into())); - assert_eq!(results[1].get_path(sql), Some("custom_type".into())); - assert_eq!(results[1].get_field(sql), Field::Text("v_test2".to_string())); + assert_eq!(results[1].get_path(sql), "fn_name.custom_type._test2"); - assert_eq!(results[2].get_root(sql), None); - assert_eq!(results[2].get_path(sql), None); - assert_eq!(results[2].get_field(sql), Field::Parameter(3)); + assert_eq!(results[2].get_path(sql), "$3"); - assert_eq!(results[3].get_root(sql), None); - assert_eq!(results[3].get_path(sql), Some("custom_type".into())); - assert_eq!(results[3].get_field(sql), Field::Text("v_test3".to_string())); + assert_eq!(results[3].get_path(sql), "custom_type.v_test3"); } } diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 62924e00..ecc924e5 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -24,14 +24,10 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } Self::Parameter(pm) => { - let start = match pm.root { - Some(s) => s.start_position(), - None => pm.path.as_ref().unwrap().start_position(), - }; + let node_range = pm.node.range(); - let end = pm.field.end_position(); - - start >= range.start_point && end <= range.end_point + node_range.start_point >= range.start_point + && node_range.end_point <= range.end_point } } } diff --git a/crates/pgt_treesitter_queries/src/queries/parameters.rs b/crates/pgt_treesitter_queries/src/queries/parameters.rs index 6bc95734..85ea9ad2 100644 --- a/crates/pgt_treesitter_queries/src/queries/parameters.rs +++ b/crates/pgt_treesitter_queries/src/queries/parameters.rs @@ -20,62 +20,23 @@ static TS_QUERY: LazyLock = LazyLock::new(|| { #[derive(Debug)] pub struct ParameterMatch<'a> { - pub(crate) root: Option>, - pub(crate) path: Option>, - - pub(crate) field: tree_sitter::Node<'a>, -} - -#[derive(Debug, PartialEq)] -pub enum Field { - Text(String), - Parameter(usize), + pub(crate) node: tree_sitter::Node<'a>, } impl ParameterMatch<'_> { - pub fn get_root(&self, sql: &str) -> Option { - let str = self - .root - .as_ref()? - .utf8_text(sql.as_bytes()) - .expect("Failed to get schema from RelationMatch"); - - Some(str.to_string()) - } - - pub fn get_path(&self, sql: &str) -> Option { - let str = self - .path - .as_ref()? + pub fn get_path(&self, sql: &str) -> String { + self.node .utf8_text(sql.as_bytes()) - .expect("Failed to get table from RelationMatch"); - - Some(str.to_string()) - } - - pub fn get_field(&self, sql: &str) -> Field { - let text = self - .field - .utf8_text(sql.as_bytes()) - .expect("Failed to get field from RelationMatch"); - - if let Some(stripped) = text.strip_prefix('$') { - return Field::Parameter( - stripped - .parse::() - .expect("Failed to parse parameter"), - ); - } - - Field::Text(text.to_string()) + .expect("Failed to get path from ParameterMatch") + .to_string() } pub fn get_range(&self) -> tree_sitter::Range { - self.field.range() + self.node.range() } pub fn get_byte_range(&self) -> std::ops::Range { - let range = self.field.range(); + let range = self.node.range(); range.start_byte..range.end_byte } } @@ -112,37 +73,9 @@ impl<'a> Query<'a> for ParameterMatch<'a> { return None; } - let field = captures[0].node; - let text = match field.utf8_text(stmt.as_bytes()) { - Ok(t) => t, - Err(_) => return None, - }; - let parts: Vec<&str> = text.split('.').collect(); - - let param_match = match parts.len() { - // Simple field: field_name - 1 => ParameterMatch { - root: None, - path: None, - field, - }, - // Table qualified: table.field_name - 2 => ParameterMatch { - root: None, - path: field.named_child(0), - field: field.named_child(1)?, - }, - // Fully qualified: schema.table.field_name - 3 => ParameterMatch { - root: field.named_child(0).and_then(|n| n.named_child(0)), - path: field.named_child(0).and_then(|n| n.named_child(1)), - field: field.named_child(1)?, - }, - // Unexpected number of parts - _ => return None, - }; - - Some(QueryResult::Parameter(param_match)) + Some(QueryResult::Parameter(ParameterMatch { + node: captures[0].node, + })) }) .collect() } diff --git a/crates/pgt_typecheck/Cargo.toml b/crates/pgt_typecheck/Cargo.toml index 9a2d7022..caacc6d1 100644 --- a/crates/pgt_typecheck/Cargo.toml +++ b/crates/pgt_typecheck/Cargo.toml @@ -12,16 +12,16 @@ version = "0.0.0" [dependencies] -pgt_console.workspace = true -pgt_diagnostics.workspace = true -pgt_query_ext.workspace = true -pgt_schema_cache.workspace = true -pgt_text_size.workspace = true +pgt_console.workspace = true +pgt_diagnostics.workspace = true +pgt_query_ext.workspace = true +pgt_schema_cache.workspace = true +pgt_text_size.workspace = true pgt_treesitter_queries.workspace = true -sqlx.workspace = true -tokio.workspace = true -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +sqlx.workspace = true +tokio.workspace = true +tree-sitter.workspace = true +tree_sitter_sql.workspace = true [dev-dependencies] insta.workspace = true diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 41d6a29e..90efcf9d 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,13 +1,15 @@ -use pgt_treesitter_queries::{ - queries::{Field, ParameterMatch}, - TreeSitterQueriesExecutor, -}; +use pgt_treesitter_queries::{TreeSitterQueriesExecutor, queries::ParameterMatch}; +/// A typed identifier is a parameter that has a type associated with it. +/// It is used to replace parameters within the SQL string. #[derive(Debug)] pub struct TypedIdentifier { - pub root: Option, - pub field: Option, + /// The path of the parameter, usually the name of the function. + /// This is because `fn_name.arg_name` is a valid reference within a SQL function. + pub path: String, + /// The name of the argument pub name: String, + /// The type of the argument with schema and name pub type_: Identifier, } @@ -20,39 +22,94 @@ pub fn apply_identifiers<'a>( cst: &'a tree_sitter::Tree, sql: &'a str, ) -> &'a str { - // TODO - println!("Applying identifiers to SQL: {}", sql); - println!("Identifiers: {:?}", identifiers); - println!("CST: {:#?}", cst); - let mut executor = TreeSitterQueriesExecutor::new(cst.root_node(), sql); executor.add_query_results::(); // we need the range and type of each field - - let results: Vec<(std::ops::Range, &Identifier)> = executor + let results = executor .get_iter(None) .filter_map(|q| { let m: &ParameterMatch = q.try_into().ok()?; - let ident = match m.get_field(sql) { - Field::Parameter(idx) => identifiers.get(idx)?, - Field::Text(field) => { - let r = m.get_root(sql); - let p = m.get_path(sql); + let path = m.get_path(sql); + let parts = path.split(".").collect::>(); + + // find the identifier and its index + // if it starts with $ it is a parameter, e.g. `$2` targets the second parameter + let (ident, idx) = if parts.len() == 1 && parts[0].starts_with("$") { + let idx = parts[0][1..].parse::().ok()?; + + let ident = identifiers.get(idx - 1)?; - identifiers.iter().find(|i| { - // TODO: this is not correct, we need to check if the identifier is a prefix of the field + (ident, idx) + } else { + // If it is not a parameter, its the path to the identifier + // e.g. `fn_name.custom_type.v_test2` or `custom_type.v_test3` or just `v_test4` + // Note that we cannot know if its `fn_name.arg_name` or `arg_name.field_name` (for + // composite types). + identifiers.iter().find_map(|i| { + let (idx, _part) = parts.iter().enumerate().find(|(_idx, p)| **p == i.name)?; - })? - } + Some((i, idx)) + })? }; - Some((m.get_byte_range(), &ident.type_)) + println!("Found identifier: {:?}", ident); + + // now resolve its type + let type_ = if idx < parts.len() - 1 { + // special case: composite types + let (schema, name) = &ident.type_; + + let schema_type = schema_cache + .types + .iter() + .find(|t| schema.as_ref().is_none_or(|s| t.schema == *s) && t.name == *name)?; + + let field_name = parts.last().unwrap(); + + let field = schema_type + .attributes + .attrs + .iter() + .find(|a| a.name == *field_name)?; + + let field_type = schema_cache.types.iter().find(|t| t.id == field.type_id)?; + + (Some(field_type.schema.as_str()), field_type.name.as_str()) + } else { + // find schema of the type + let schema = ident.type_.0.as_deref().or_else(|| { + schema_cache + .find_type(&ident.type_.1, None) + .map(|t| t.schema.as_str()) + }); + + (schema, ident.type_.1.as_str()) + }; + + Some((m.get_byte_range(), type_)) }) - // TODO resolve composite types or table types to plain types - .collect(); + .collect::>(); + + println!("Results: {:?}", results); + + // now resolve the default values + // for enums we need to fetch the values + // for everything else we implement a default value generator + // we then replace the identifier with the default value + // we will have an issue with enum values that are longer than the original identifier, e.g. $1 + // but for the rest we can simply fill up the space with spaces. + // we might be able to use NULL for some types or as a fallback. + // for now, we can simply not expose the location if the default is larger than the identifier + + results.iter().for_each(|(r, type_)| { + let (schema, name) = type_; + + // if the type not in pg_catalog, its probably an enum and we want to fetch one of its + // values + }); sql } @@ -64,10 +121,63 @@ mod tests { #[tokio::test] async fn test_apply_identifiers() { - let input = "select v_test + fn_name.custom_type.v_test2 + $3 + test.field;"; + let input = "select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3 + fn_name.v_test2 + enum_type"; + + let identifiers = vec![ + super::TypedIdentifier { + path: "fn_name".to_string(), + name: "v_test".to_string(), + type_: (None, "int4".to_string()), + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: "custom_type".to_string(), + type_: (Some("public".to_string()), "custom_type".to_string()), + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: "another".to_string(), + type_: (None, "numeric".to_string()), + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: "custom_type".to_string(), + type_: (Some("public".to_string()), "custom_type".to_string()), + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: "v_test2".to_string(), + type_: (None, "int4".to_string()), + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: "enum_type".to_string(), + type_: (Some("public".to_string()), "enum_type".to_string()), + }, + ]; let test_db = get_new_test_db().await; + let setup = r#" + CREATE TYPE "public"."custom_type" AS ( + v_test2 integer, + v_test3 integer + ); + + CREATE TYPE "public"."enum_type" AS ENUM ( + 'critical', + 'high', + 'default', + 'low', + 'very_low' + ); + "#; + + test_db + .execute(setup) + .await + .expect("Failed to setup test database"); + let mut parser = tree_sitter::Parser::new(); parser .set_language(tree_sitter_sql::language()) @@ -77,83 +187,8 @@ mod tests { .await .expect("Failed to load Schema Cache"); - let root = pgt_query_ext::parse(input).unwrap(); let tree = parser.parse(input, None).unwrap(); - println!("Parsed SQL: {:?}", root); - println!("Parsed CST: {:?}", tree); - - // let mut parameters = Vec::new(); - - enum Parameter { - Identifier { - range: (usize, usize), - name: (Option, String), - }, - Parameter { - range: (usize, usize), - idx: usize, - }, - } - - let mut c = tree.walk(); - - 'outer: loop { - // 0. Add the current node to the map. - println!("Current node: {:?}", c.node()); - match c.node().kind() { - "identifier" => { - println!( - "Found identifier: {:?}", - c.node().utf8_text(input.as_bytes()).unwrap() - ); - } - "parameter" => { - println!( - "Found parameter: {:?}", - c.node().utf8_text(input.as_bytes()).unwrap() - ); - } - "object_reference" => { - println!( - "Found object reference: {:?}", - c.node().utf8_text(input.as_bytes()).unwrap() - ); - - // let source = self.text; - // ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { - // if SanitizedCompletionParams::is_sanitized_token(txt) { - // NodeText::Replaced - // } else { - // NodeText::Original(txt) - // } - // }) - } - _ => {} - } - - // 1. Go to its child and continue. - if c.goto_first_child() { - continue 'outer; - } - - // 2. We've reached a leaf (node without a child). We will go to a sibling. - if c.goto_next_sibling() { - continue 'outer; - } - - // 3. If there are no more siblings, we need to go back up. - 'inner: loop { - // 4. Check if we've reached the root node. If so, we're done. - if !c.goto_parent() { - break 'outer; - } - // 5. Go to the previous node's sibling. - if c.goto_next_sibling() { - // And break out of the inner loop. - break 'inner; - } - } - } + super::apply_identifiers(identifiers, &schema_cache, &tree, input); } } From 9f9cf9bed2518787ba5637e4e2f60e248ecf9650 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Wed, 7 May 2025 09:43:57 +0200 Subject: [PATCH 06/10] just tests missing now --- crates/pgt_completions/src/context.rs | 32 +- crates/pgt_schema_cache/src/lib.rs | 1 + crates/pgt_schema_cache/src/types.rs | 55 --- crates/pgt_typecheck/src/diagnostics.rs | 21 +- crates/pgt_typecheck/src/lib.rs | 12 +- crates/pgt_typecheck/src/typed_identifier.rs | 320 +++++++++++++----- crates/pgt_workspace/src/workspace/server.rs | 128 ++++--- .../src/workspace/server/sql_function.rs | 64 +++- 8 files changed, 403 insertions(+), 230 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index a4578df8..60fcf9be 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -100,25 +100,23 @@ impl<'a> CompletionContext<'a> { executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { - match relation_match { - QueryResult::Relation(r) => { - let schema_name = r.get_schema(sql); - let table_name = r.get_table(sql); + if let QueryResult::Relation(r) = relation_match { + let schema_name = r.get_schema(sql); + let table_name = r.get_table(sql); - let current = self.mentioned_relations.get_mut(&schema_name); + let current = self.mentioned_relations.get_mut(&schema_name); - match current { - Some(c) => { - c.insert(table_name); - } - None => { - let mut new = HashSet::new(); - new.insert(table_name); - self.mentioned_relations.insert(schema_name, new); - } - }; - } - }; + match current { + Some(c) => { + c.insert(table_name); + } + None => { + let mut new = HashSet::new(); + new.insert(table_name); + self.mentioned_relations.insert(schema_name, new); + } + }; + } } } diff --git a/crates/pgt_schema_cache/src/lib.rs b/crates/pgt_schema_cache/src/lib.rs index 28c5b641..3e2f58ec 100644 --- a/crates/pgt_schema_cache/src/lib.rs +++ b/crates/pgt_schema_cache/src/lib.rs @@ -15,3 +15,4 @@ pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; pub use schema_cache::SchemaCache; pub use schemas::Schema; pub use tables::{ReplicaIdentity, Table}; +pub use types::{PostgresType, PostgresTypeAttribute}; diff --git a/crates/pgt_schema_cache/src/types.rs b/crates/pgt_schema_cache/src/types.rs index 5d7f1b43..dd67e439 100644 --- a/crates/pgt_schema_cache/src/types.rs +++ b/crates/pgt_schema_cache/src/types.rs @@ -56,58 +56,3 @@ impl SchemaCacheItem for PostgresType { .await } } - -#[cfg(test)] -mod tests { - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; - - use crate::{schema_cache::SchemaCacheItem, types::PostgresType}; - - #[tokio::test] - async fn test_types() { - let setup = r#" - CREATE TYPE "public"."priority" AS ENUM ( - 'critical', - 'high', - 'default', - 'low', - 'very_low' - ); - - CREATE TYPE complex AS ( - r double precision, - i double precision - ); - "#; - - let test_db = get_new_test_db().await; - - test_db - .execute(setup) - .await - .expect("Failed to setup test database"); - - let types = PostgresType::load(&test_db).await.unwrap(); - - let enum_type = types.iter().find(|t| t.name == "priority"); - let comp_type = types.iter().find(|t| t.name == "complex"); - - println!("{:?}", enum_type); - // search for type id - println!("{:?}", comp_type); - - comp_type.and_then(|t| { - t.attributes.attrs.iter().for_each(|a| { - let typ = types.iter().find(|t| t.id == a.type_id); - println!( - "{}: {} - {:?}", - a.name, - a.type_id, - typ.as_ref().map(|t| t.name.clone()) - ); - }); - Some(()) - }); - } -} diff --git a/crates/pgt_typecheck/src/diagnostics.rs b/crates/pgt_typecheck/src/diagnostics.rs index 8fd92da2..2117adbe 100644 --- a/crates/pgt_typecheck/src/diagnostics.rs +++ b/crates/pgt_typecheck/src/diagnostics.rs @@ -97,6 +97,7 @@ impl Advices for TypecheckAdvices { pub(crate) fn create_type_error( pg_err: &PgDatabaseError, ts: &tree_sitter::Tree, + positions_valid: bool, ) -> TypecheckDiagnostic { let position = pg_err.position().and_then(|pos| match pos { sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1), @@ -104,14 +105,18 @@ pub(crate) fn create_type_error( }); let range = position.and_then(|pos| { - ts.root_node() - .named_descendant_for_byte_range(pos, pos) - .map(|node| { - TextRange::new( - node.start_byte().try_into().unwrap(), - node.end_byte().try_into().unwrap(), - ) - }) + if positions_valid { + ts.root_node() + .named_descendant_for_byte_range(pos, pos) + .map(|node| { + TextRange::new( + node.start_byte().try_into().unwrap(), + node.end_byte().try_into().unwrap(), + ) + }) + } else { + None + } }); let severity = match pg_err.severity() { diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index 8158e82c..e1dcd259 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -7,8 +7,8 @@ use pgt_text_size::TextRange; use sqlx::postgres::PgDatabaseError; pub use sqlx::postgres::PgSeverity; use sqlx::{Executor, PgPool}; -pub use typed_identifier::TypedIdentifier; use typed_identifier::apply_identifiers; +pub use typed_identifier::{IdentifierType, TypedIdentifier}; #[derive(Debug)] pub struct TypecheckParams<'a> { @@ -56,20 +56,24 @@ pub async fn check_sql( // each typecheck operation. conn.close_on_drop(); - let prepared = apply_identifiers( + let (prepared, positions_valid) = apply_identifiers( params.identifiers, params.schema_cache, params.tree, params.sql, ); - let res = conn.prepare(prepared).await; + let res = conn.prepare(&prepared).await; match res { Ok(_) => Ok(None), Err(sqlx::Error::Database(err)) => { let pg_err = err.downcast_ref::(); - Ok(Some(create_type_error(pg_err, params.tree))) + Ok(Some(create_type_error( + pg_err, + params.tree, + positions_valid, + ))) } Err(err) => Err(err), } diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 90efcf9d..ef06ddd3 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,3 +1,4 @@ +use pgt_schema_cache::PostgresType; use pgt_treesitter_queries::{TreeSitterQueriesExecutor, queries::ParameterMatch}; /// A typed identifier is a parameter that has a type associated with it. @@ -8,12 +9,17 @@ pub struct TypedIdentifier { /// This is because `fn_name.arg_name` is a valid reference within a SQL function. pub path: String, /// The name of the argument - pub name: String, + pub name: Option, /// The type of the argument with schema and name - pub type_: Identifier, + pub type_: IdentifierType, } -type Identifier = (Option, String); +#[derive(Debug, Clone)] +pub struct IdentifierType { + pub schema: Option, + pub name: String, + pub is_array: bool, +} /// Applies the identifiers to the SQL string by replacing them with their default values. pub fn apply_identifiers<'a>( @@ -21,97 +27,206 @@ pub fn apply_identifiers<'a>( schema_cache: &'a pgt_schema_cache::SchemaCache, cst: &'a tree_sitter::Tree, sql: &'a str, -) -> &'a str { +) -> (String, bool) { let mut executor = TreeSitterQueriesExecutor::new(cst.root_node(), sql); executor.add_query_results::(); - // we need the range and type of each field - let results = executor + // Collect all replacements first to avoid modifying the string while iterating + let replacements: Vec<_> = executor .get_iter(None) .filter_map(|q| { let m: &ParameterMatch = q.try_into().ok()?; - let path = m.get_path(sql); - let parts = path.split(".").collect::>(); - - // find the identifier and its index - // if it starts with $ it is a parameter, e.g. `$2` targets the second parameter - let (ident, idx) = if parts.len() == 1 && parts[0].starts_with("$") { - let idx = parts[0][1..].parse::().ok()?; - - let ident = identifiers.get(idx - 1)?; - - (ident, idx) - } else { - // If it is not a parameter, its the path to the identifier - // e.g. `fn_name.custom_type.v_test2` or `custom_type.v_test3` or just `v_test4` - // Note that we cannot know if its `fn_name.arg_name` or `arg_name.field_name` (for - // composite types). - identifiers.iter().find_map(|i| { - let (idx, _part) = parts.iter().enumerate().find(|(_idx, p)| **p == i.name)?; - - Some((i, idx)) - })? - }; - - println!("Found identifier: {:?}", ident); - - // now resolve its type - let type_ = if idx < parts.len() - 1 { - // special case: composite types - let (schema, name) = &ident.type_; - - let schema_type = schema_cache - .types - .iter() - .find(|t| schema.as_ref().is_none_or(|s| t.schema == *s) && t.name == *name)?; - - let field_name = parts.last().unwrap(); - - let field = schema_type - .attributes - .attrs - .iter() - .find(|a| a.name == *field_name)?; - - let field_type = schema_cache.types.iter().find(|t| t.id == field.type_id)?; - - (Some(field_type.schema.as_str()), field_type.name.as_str()) - } else { - // find schema of the type - let schema = ident.type_.0.as_deref().or_else(|| { - schema_cache - .find_type(&ident.type_.1, None) - .map(|t| t.schema.as_str()) - }); - - (schema, ident.type_.1.as_str()) - }; - - Some((m.get_byte_range(), type_)) + let parts: Vec<_> = path.split('.').collect(); + + // Find the matching identifier and its position in the path + let (identifier, position) = find_matching_identifier(&parts, &identifiers)?; + + // Resolve the type based on whether we're accessing a field of a composite type + let type_ = resolve_type(identifier, position, &parts, schema_cache)?; + + Some((m.get_byte_range(), type_, identifier.type_.is_array)) }) - .collect::>(); + .collect(); + + let mut result = sql.to_string(); + + let mut valid_positions = true; + + // Apply replacements in reverse order to maintain correct byte offsets + for (range, type_, is_array) in replacements.into_iter().rev() { + let default_value = get_formatted_default_value(type_, is_array); + + // if the default_value is shorter than "range", fill it up with spaces + let default_value = if default_value.len() < range.end - range.start { + format!("{: range.end - range.start { + valid_positions = false; + } + + result.replace_range(range, &default_value); + } + + (result, valid_positions) +} + +/// Format the default value based on the type and whether it's an array +fn get_formatted_default_value(pg_type: &PostgresType, is_array: bool) -> String { + // Get the base default value for this type + let default = resolve_default_value(pg_type); + + let default = if default.len() > "NULL".len() { + // If the default value is longer than "NULL", use "NULL" instead + "NULL".to_string() + } else { + // Otherwise, use the default value + default + }; + + // For arrays, wrap the default in array syntax + if is_array { + format!("'{{{}}}'", default) + } else { + default + } +} + +/// Resolve the default value for a given Postgres type. +/// +/// * `pg_type`: The type to return the default value for. +pub fn resolve_default_value(pg_type: &PostgresType) -> String { + // Handle ENUM types by returning the first variant + if !pg_type.enums.values.is_empty() { + return format!("'{}'", pg_type.enums.values[0]); + } - println!("Results: {:?}", results); + match pg_type.name.as_str() { + // Numeric types + "smallint" | "int2" | "integer" | "int" | "int4" | "bigint" | "int8" | "decimal" + | "numeric" | "real" | "float4" | "double precision" | "float8" | "smallserial" + | "serial2" | "serial" | "serial4" | "bigserial" | "serial8" => "0".to_string(), - // now resolve the default values - // for enums we need to fetch the values - // for everything else we implement a default value generator - // we then replace the identifier with the default value - // we will have an issue with enum values that are longer than the original identifier, e.g. $1 - // but for the rest we can simply fill up the space with spaces. - // we might be able to use NULL for some types or as a fallback. - // for now, we can simply not expose the location if the default is larger than the identifier + // Boolean type + "boolean" | "bool" => "false".to_string(), - results.iter().for_each(|(r, type_)| { - let (schema, name) = type_; + // Character types + "character" | "char" | "character varying" | "varchar" | "text" => "''".to_string(), - // if the type not in pg_catalog, its probably an enum and we want to fetch one of its - // values - }); + // Date/time types + "date" => "'1970-01-01'".to_string(), + "time" | "time without time zone" => "'00:00:00'".to_string(), + "time with time zone" | "timetz" => "'00:00:00+00'".to_string(), + "timestamp" | "timestamp without time zone" => "'1970-01-01 00:00:00'".to_string(), + "timestamp with time zone" | "timestamptz" => "'1970-01-01 00:00:00+00'".to_string(), + "interval" => "'0'".to_string(), - sql + // JSON types + "json" | "jsonb" => "'null'".to_string(), + + // UUID + "uuid" => "'00000000-0000-0000-0000-000000000000'".to_string(), + + // Byte array + "bytea" => "'\\x'".to_string(), + + // Network types + "inet" => "'0.0.0.0'".to_string(), + "cidr" => "'0.0.0.0/0'".to_string(), + "macaddr" => "'00:00:00:00:00:00'".to_string(), + "macaddr8" => "'00:00:00:00:00:00:00:00'".to_string(), + + // Monetary type + "money" => "'0.00'".to_string(), + + // Geometric types + "point" => "'(0,0)'".to_string(), + "line" => "'{0,0,0}'".to_string(), + "lseg" => "'[(0,0),(0,0)]'".to_string(), + "box" => "'((0,0),(0,0))'".to_string(), + "path" => "'((0,0),(0,0))'".to_string(), + "polygon" => "'((0,0),(0,0),(0,0))'".to_string(), + "circle" => "'<(0,0),0>'".to_string(), + + // Text search types + "tsvector" => "''".to_string(), + "tsquery" => "''".to_string(), + + // XML + "xml" => "''".to_string(), + + // Log sequence number + "pg_lsn" => "'0/0'".to_string(), + + // Snapshot types + "txid_snapshot" | "pg_snapshot" => "NULL".to_string(), + + // Fallback for unrecognized types + _ => "NULL".to_string(), + } +} + +// Helper function to find the matching identifier and its position in the path +fn find_matching_identifier<'a>( + parts: &[&str], + identifiers: &'a [TypedIdentifier], +) -> Option<(&'a TypedIdentifier, usize)> { + // Case 1: Parameter reference (e.g., $2) + if parts.len() == 1 && parts[0].starts_with('$') { + let idx = parts[0][1..].parse::().ok()?; + let identifier = identifiers.get(idx - 1)?; + return Some((identifier, idx)); + } + + // Case 2: Named reference (e.g., fn_name.custom_type.v_test2) + identifiers.iter().find_map(|identifier| { + let name = identifier.name.as_ref()?; + + parts + .iter() + .enumerate() + .find(|(_idx, part)| **part == name) + .map(|(idx, _)| (identifier, idx)) + }) +} + +// Helper function to resolve the type based on the identifier and path +fn resolve_type<'a>( + identifier: &TypedIdentifier, + position: usize, + parts: &[&str], + schema_cache: &'a pgt_schema_cache::SchemaCache, +) -> Option<&'a PostgresType> { + if position < parts.len() - 1 { + // Find the composite type + let schema_type = schema_cache.types.iter().find(|t| { + identifier + .type_ + .schema + .as_ref() + .is_none_or(|s| t.schema == *s) + && t.name == *identifier.type_.name + })?; + + // Find the field within the composite type + let field_name = parts.last().unwrap(); + let field = schema_type + .attributes + .attrs + .iter() + .find(|a| a.name == *field_name)?; + + // Find the field's type + schema_cache.types.iter().find(|t| t.id == field.type_id) + } else { + // Direct type reference + schema_cache.find_type(&identifier.type_.name, identifier.type_.schema.as_deref()) + } } #[cfg(test)] @@ -127,32 +242,56 @@ mod tests { super::TypedIdentifier { path: "fn_name".to_string(), name: "v_test".to_string(), - type_: (None, "int4".to_string()), + type_: super::IdentifierType { + schema: None, + name: "int4".to_string(), + is_array: false, + }, }, super::TypedIdentifier { path: "fn_name".to_string(), name: "custom_type".to_string(), - type_: (Some("public".to_string()), "custom_type".to_string()), + type_: super::IdentifierType { + schema: Some("public".to_string()), + name: "custom_type".to_string(), + is_array: false, + }, }, super::TypedIdentifier { path: "fn_name".to_string(), name: "another".to_string(), - type_: (None, "numeric".to_string()), + type_: super::IdentifierType { + schema: None, + name: "numeric".to_string(), + is_array: false, + }, }, super::TypedIdentifier { path: "fn_name".to_string(), name: "custom_type".to_string(), - type_: (Some("public".to_string()), "custom_type".to_string()), + type_: super::IdentifierType { + schema: Some("public".to_string()), + name: "custom_type".to_string(), + is_array: false, + }, }, super::TypedIdentifier { path: "fn_name".to_string(), name: "v_test2".to_string(), - type_: (None, "int4".to_string()), + type_: super::IdentifierType { + schema: None, + name: "int4".to_string(), + is_array: false, + }, }, super::TypedIdentifier { path: "fn_name".to_string(), name: "enum_type".to_string(), - type_: (Some("public".to_string()), "enum_type".to_string()), + type_: super::IdentifierType { + schema: Some("public".to_string()), + name: "enum_type".to_string(), + is_array: false, + }, }, ]; @@ -189,6 +328,13 @@ mod tests { let tree = parser.parse(input, None).unwrap(); - super::apply_identifiers(identifiers, &schema_cache, &tree, input); + let (sql_out, valid_pos) = + super::apply_identifiers(identifiers, &schema_cache, &tree, input); + + assert!(valid_pos); + assert_eq!( + sql_out, + "select 0 + 0 + 0 + 0 + 0 + NULL " + ); } } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 61ad9363..45675cc1 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -1,4 +1,9 @@ -use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock}; +use std::{ + fs, + panic::RefUnwindSafe, + path::Path, + sync::{Arc, RwLock}, +}; use analyser::AnalyserVisitorBuilder; use async_helper::run_async; @@ -16,7 +21,7 @@ use pgt_diagnostics::{ Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, }; use pgt_fs::{ConfigName, PgTPath}; -use pgt_typecheck::{TypecheckParams, TypedIdentifier}; +use pgt_typecheck::{IdentifierType, TypecheckParams, TypedIdentifier}; use schema_cache_manager::SchemaCacheManager; use sqlx::Executor; use tracing::info; @@ -360,57 +365,74 @@ impl Workspace for WorkspaceServer { let mut diagnostics: Vec = parser.document_diagnostics().to_vec(); - // if let Some(pool) = self - // .connection - // .read() - // .expect("DbConnection RwLock panicked") - // .get_pool() - // { - // let path_clone = params.path.clone(); - // let schema_cache = self.schema_cache.load(pool)?; - // let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); - // let async_results = run_async(async move { - // stream::iter(input) - // .map(|(_id, range, content, ast, cst, sign)| { - // let pool = pool.clone(); - // let path = path_clone.clone(); - // async move { - // if let Some(ast) = ast { - // // pgt_typecheck::check_sql(TypecheckParams { - // // conn: &pool, - // // sql: &content, - // // ast: &ast, - // // tree: &cst, - // // schema_cache, - // // identifiers: vec![], - // // }) - // // .await - // // .map(|d| { - // // d.map(|d| { - // // let r = d.location().span.map(|span| span + range.start()); - // // - // // d.with_file_path(path.as_path().display().to_string()) - // // .with_file_span(r.unwrap_or(range)) - // // }) - // // }) - // Ok(None) - // } else { - // Ok(None) - // } - // } - // }) - // .buffer_unordered(10) - // .collect::>() - // .await - // })?; - // - // for result in async_results.into_iter() { - // let result = result?; - // if let Some(diag) = result { - // diagnostics.push(SDiagnostic::new(diag)); - // } - // } - // } + if let Some(pool) = self + .connection + .read() + .expect("DbConnection RwLock panicked") + .get_pool() + { + let path_clone = params.path.clone(); + let schema_cache = self.schema_cache.load(pool.clone())?; + let schema_cache_arc = Arc::new(schema_cache.as_ref().clone()); + let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); + // sorry for the ugly code :( + let async_results = run_async(async move { + stream::iter(input) + .map(|(_id, range, content, ast, cst, sign)| { + let pool = pool.clone(); + let path = path_clone.clone(); + let schema_cache = Arc::clone(&schema_cache_arc); + async move { + if let Some(ast) = ast { + pgt_typecheck::check_sql(TypecheckParams { + conn: &pool, + sql: &content, + ast: &ast, + tree: &cst, + schema_cache: schema_cache.as_ref(), + identifiers: sign + .map(|s| { + s.args + .iter() + .map(|a| TypedIdentifier { + path: s.name.1.clone(), + name: a.name.clone(), + type_: IdentifierType { + schema: a.type_.schema.clone(), + name: a.type_.name.clone(), + is_array: a.type_.is_array, + }, + }) + .collect::>() + }) + .unwrap_or_default(), + }) + .await + .map(|d| { + d.map(|d| { + let r = d.location().span.map(|span| span + range.start()); + + d.with_file_path(path.as_path().display().to_string()) + .with_file_span(r.unwrap_or(range)) + }) + }) + } else { + Ok(None) + } + } + }) + .buffer_unordered(10) + .collect::>() + .await + })?; + + for result in async_results.into_iter() { + let result = result?; + if let Some(diag) = result { + diagnostics.push(SDiagnostic::new(diag)); + } + } + } diagnostics.extend(parser.iter(SyncDiagnosticsMapper).flat_map( |(_id, range, ast, diag)| { diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index a490f7e9..338c28fc 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -1,9 +1,16 @@ use pgt_text_size::TextRange; +#[derive(Debug, Clone)] +pub struct ArgType { + pub schema: Option, + pub name: String, + pub is_array: bool, +} + #[derive(Debug, Clone)] pub struct SQLFunctionArgs { pub name: Option, - pub type_: (Option, String), + pub type_: ArgType, } #[derive(Debug, Clone)] @@ -45,7 +52,15 @@ pub fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option) -> Option<(Option, String)> { +fn parse_name(nodes: &[pgt_query_ext::protobuf::Node]) -> Option<(Option, String)> { let names = nodes .iter() .map(|n| match &n.node { @@ -150,8 +165,6 @@ fn parse_name(nodes: &Vec) -> Option<(Option Date: Wed, 7 May 2025 10:12:53 +0200 Subject: [PATCH 07/10] Update crates/pgt_workspace/src/workspace/server/sql_function.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- crates/pgt_workspace/src/workspace/server/sql_function.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 338c28fc..671e2a8b 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -48,8 +48,8 @@ pub fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option Date: Wed, 7 May 2025 10:13:06 +0200 Subject: [PATCH 08/10] Update crates/pgt_treesitter_queries/src/lib.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- crates/pgt_treesitter_queries/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index 1d42a925..51cf9fcc 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -282,7 +282,7 @@ on sq1.id = pt.id; assert_eq!(results[0].get_path(sql), "v_test"); - assert_eq!(results[1].get_path(sql), "fn_name.custom_type._test2"); + assert_eq!(results[1].get_path(sql), "fn_name.custom_type.v_test2"); assert_eq!(results[2].get_path(sql), "$3"); From 971fd7f6e709164b1db78bccbcd3c7bcfa043ae5 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Wed, 7 May 2025 10:19:51 +0200 Subject: [PATCH 09/10] fix: test --- crates/pgt_typecheck/src/typed_identifier.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index ef06ddd3..e63d3821 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -241,7 +241,7 @@ mod tests { let identifiers = vec![ super::TypedIdentifier { path: "fn_name".to_string(), - name: "v_test".to_string(), + name: Some("v_test".to_string()), type_: super::IdentifierType { schema: None, name: "int4".to_string(), @@ -250,7 +250,7 @@ mod tests { }, super::TypedIdentifier { path: "fn_name".to_string(), - name: "custom_type".to_string(), + name: Some("custom_type".to_string()), type_: super::IdentifierType { schema: Some("public".to_string()), name: "custom_type".to_string(), @@ -259,7 +259,7 @@ mod tests { }, super::TypedIdentifier { path: "fn_name".to_string(), - name: "another".to_string(), + name: Some("another".to_string()), type_: super::IdentifierType { schema: None, name: "numeric".to_string(), @@ -268,7 +268,7 @@ mod tests { }, super::TypedIdentifier { path: "fn_name".to_string(), - name: "custom_type".to_string(), + name: Some("custom_type".to_string()), type_: super::IdentifierType { schema: Some("public".to_string()), name: "custom_type".to_string(), @@ -277,7 +277,7 @@ mod tests { }, super::TypedIdentifier { path: "fn_name".to_string(), - name: "v_test2".to_string(), + name: Some("v_test2".to_string()), type_: super::IdentifierType { schema: None, name: "int4".to_string(), @@ -286,7 +286,7 @@ mod tests { }, super::TypedIdentifier { path: "fn_name".to_string(), - name: "enum_type".to_string(), + name: Some("enum_type".to_string()), type_: super::IdentifierType { schema: Some("public".to_string()), name: "enum_type".to_string(), From 8c4145a50a71824eccf0a541ee64091e45a8ac23 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Wed, 7 May 2025 10:40:35 +0200 Subject: [PATCH 10/10] fix: test --- crates/pgt_treesitter_queries/src/lib.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index 51cf9fcc..4bf71e74 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -70,11 +70,7 @@ mod tests { use crate::{ TreeSitterQueriesExecutor, - queries::{Field, ParameterMatch, RelationMatch}, - }; - use crate::{ - TreeSitterQueriesExecutor, - queries::{RelationMatch, TableAliasMatch}, + queries::{ParameterMatch, RelationMatch, TableAliasMatch}, }; #[test]