diff --git a/Cargo.lock b/Cargo.lock index 72ba810f..ffdbff92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2865,6 +2865,7 @@ dependencies = [ "pgt_schema_cache", "pgt_test_utils", "pgt_text_size", + "pgt_treesitter_queries", "sqlx", "tokio", "tree-sitter", diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index d96d0d53..aea79992 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -171,7 +171,9 @@ impl<'a> CompletionContext<'a> { table_alias_match.get_table(sql), ); } - }; + + _ => {} + } } } diff --git a/crates/pgt_schema_cache/src/lib.rs b/crates/pgt_schema_cache/src/lib.rs index fc717fbe..0fc54145 100644 --- a/crates/pgt_schema_cache/src/lib.rs +++ b/crates/pgt_schema_cache/src/lib.rs @@ -16,3 +16,4 @@ pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; pub use schema_cache::SchemaCache; pub use schemas::Schema; pub use tables::{ReplicaIdentity, Table}; +pub use types::{PostgresType, PostgresTypeAttribute}; diff --git a/crates/pgt_schema_cache/src/types.rs b/crates/pgt_schema_cache/src/types.rs index 8b2d04bb..dd67e439 100644 --- a/crates/pgt_schema_cache/src/types.rs +++ b/crates/pgt_schema_cache/src/types.rs @@ -6,13 +6,13 @@ use crate::schema_cache::SchemaCacheItem; #[derive(Debug, Clone, Default)] pub struct TypeAttributes { - attrs: Vec, + pub attrs: Vec, } #[derive(Debug, Clone, Default, Deserialize)] pub struct PostgresTypeAttribute { - name: String, - type_id: i64, + pub name: String, + pub type_id: i64, } impl From> for TypeAttributes { diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index 8d1719b0..4bf71e74 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -70,7 +70,7 @@ mod tests { use crate::{ TreeSitterQueriesExecutor, - queries::{RelationMatch, TableAliasMatch}, + queries::{ParameterMatch, RelationMatch, TableAliasMatch}, }; #[test] @@ -207,11 +207,11 @@ where select * from ( - select * + select * from ( select * from private.something - ) as sq2 + ) as sq2 join private.tableau pt1 on sq2.id = pt1.id ) as sq1 @@ -255,4 +255,33 @@ on sq1.id = pt.id; assert_eq!(results[0].get_schema(sql), Some("private".into())); assert_eq!(results[0].get_table(sql), "something"); } + + #[test] + fn extracts_parameters() { + let sql = r#"select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&ParameterMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 4); + + assert_eq!(results[0].get_path(sql), "v_test"); + + assert_eq!(results[1].get_path(sql), "fn_name.custom_type.v_test2"); + + assert_eq!(results[2].get_path(sql), "$3"); + + assert_eq!(results[3].get_path(sql), "custom_type.v_test3"); + } } diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 4e10ed60..c295db29 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -1,12 +1,15 @@ +mod parameters; mod relations; mod table_aliases; +pub use parameters::*; pub use relations::*; pub use table_aliases::*; #[derive(Debug)] pub enum QueryResult<'a> { Relation(RelationMatch<'a>), + Parameter(ParameterMatch<'a>), TableAliases(TableAliasMatch<'a>), } @@ -23,6 +26,12 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } + Self::Parameter(pm) => { + let node_range = pm.node.range(); + + node_range.start_point >= range.start_point + && node_range.end_point <= range.end_point + } QueryResult::TableAliases(m) => { let start = m.table.start_position(); let end = m.alias.end_position(); diff --git a/crates/pgt_treesitter_queries/src/queries/parameters.rs b/crates/pgt_treesitter_queries/src/queries/parameters.rs new file mode 100644 index 00000000..85ea9ad2 --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/parameters.rs @@ -0,0 +1,82 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" +[ + (field + (identifier)) @reference + (field + (object_reference) + "." (identifier)) @reference + (parameter) @parameter +] +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct ParameterMatch<'a> { + pub(crate) node: tree_sitter::Node<'a>, +} + +impl ParameterMatch<'_> { + pub fn get_path(&self, sql: &str) -> String { + self.node + .utf8_text(sql.as_bytes()) + .expect("Failed to get path from ParameterMatch") + .to_string() + } + + pub fn get_range(&self) -> tree_sitter::Range { + self.node.range() + } + + pub fn get_byte_range(&self) -> std::ops::Range { + let range = self.node.range(); + range.start_byte..range.end_byte + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ParameterMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::Parameter(r) => Ok(r), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for ParameterMatch<'a> { + type Ref = &'a ParameterMatch<'a>; +} + +impl<'a> Query<'a> for ParameterMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + matches + .filter_map(|m| { + let captures = m.captures; + + // We expect exactly one capture for a parameter + if captures.len() != 1 { + return None; + } + + Some(QueryResult::Parameter(ParameterMatch { + node: captures[0].node, + })) + }) + .collect() + } +} diff --git a/crates/pgt_typecheck/Cargo.toml b/crates/pgt_typecheck/Cargo.toml index a097fa56..caacc6d1 100644 --- a/crates/pgt_typecheck/Cargo.toml +++ b/crates/pgt_typecheck/Cargo.toml @@ -12,15 +12,16 @@ version = "0.0.0" [dependencies] -pgt_console.workspace = true -pgt_diagnostics.workspace = true -pgt_query_ext.workspace = true -pgt_schema_cache.workspace = true -pgt_text_size.workspace = true -sqlx.workspace = true -tokio.workspace = true -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +pgt_console.workspace = true +pgt_diagnostics.workspace = true +pgt_query_ext.workspace = true +pgt_schema_cache.workspace = true +pgt_text_size.workspace = true +pgt_treesitter_queries.workspace = true +sqlx.workspace = true +tokio.workspace = true +tree-sitter.workspace = true +tree_sitter_sql.workspace = true [dev-dependencies] insta.workspace = true diff --git a/crates/pgt_typecheck/src/diagnostics.rs b/crates/pgt_typecheck/src/diagnostics.rs index 8fd92da2..2117adbe 100644 --- a/crates/pgt_typecheck/src/diagnostics.rs +++ b/crates/pgt_typecheck/src/diagnostics.rs @@ -97,6 +97,7 @@ impl Advices for TypecheckAdvices { pub(crate) fn create_type_error( pg_err: &PgDatabaseError, ts: &tree_sitter::Tree, + positions_valid: bool, ) -> TypecheckDiagnostic { let position = pg_err.position().and_then(|pos| match pos { sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1), @@ -104,14 +105,18 @@ pub(crate) fn create_type_error( }); let range = position.and_then(|pos| { - ts.root_node() - .named_descendant_for_byte_range(pos, pos) - .map(|node| { - TextRange::new( - node.start_byte().try_into().unwrap(), - node.end_byte().try_into().unwrap(), - ) - }) + if positions_valid { + ts.root_node() + .named_descendant_for_byte_range(pos, pos) + .map(|node| { + TextRange::new( + node.start_byte().try_into().unwrap(), + node.end_byte().try_into().unwrap(), + ) + }) + } else { + None + } }); let severity = match pg_err.severity() { diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index f741c0e6..e1dcd259 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -1,4 +1,5 @@ mod diagnostics; +mod typed_identifier; pub use diagnostics::TypecheckDiagnostic; use diagnostics::create_type_error; @@ -6,6 +7,8 @@ use pgt_text_size::TextRange; use sqlx::postgres::PgDatabaseError; pub use sqlx::postgres::PgSeverity; use sqlx::{Executor, PgPool}; +use typed_identifier::apply_identifiers; +pub use typed_identifier::{IdentifierType, TypedIdentifier}; #[derive(Debug)] pub struct TypecheckParams<'a> { @@ -13,6 +16,8 @@ pub struct TypecheckParams<'a> { pub sql: &'a str, pub ast: &'a pgt_query_ext::NodeEnum, pub tree: &'a tree_sitter::Tree, + pub schema_cache: &'a pgt_schema_cache::SchemaCache, + pub identifiers: Vec, } #[derive(Debug, Clone)] @@ -51,13 +56,24 @@ pub async fn check_sql( // each typecheck operation. conn.close_on_drop(); - let res = conn.prepare(params.sql).await; + let (prepared, positions_valid) = apply_identifiers( + params.identifiers, + params.schema_cache, + params.tree, + params.sql, + ); + + let res = conn.prepare(&prepared).await; match res { Ok(_) => Ok(None), Err(sqlx::Error::Database(err)) => { let pg_err = err.downcast_ref::(); - Ok(Some(create_type_error(pg_err, params.tree))) + Ok(Some(create_type_error( + pg_err, + params.tree, + positions_valid, + ))) } Err(err) => Err(err), } diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs new file mode 100644 index 00000000..e63d3821 --- /dev/null +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -0,0 +1,340 @@ +use pgt_schema_cache::PostgresType; +use pgt_treesitter_queries::{TreeSitterQueriesExecutor, queries::ParameterMatch}; + +/// A typed identifier is a parameter that has a type associated with it. +/// It is used to replace parameters within the SQL string. +#[derive(Debug)] +pub struct TypedIdentifier { + /// The path of the parameter, usually the name of the function. + /// This is because `fn_name.arg_name` is a valid reference within a SQL function. + pub path: String, + /// The name of the argument + pub name: Option, + /// The type of the argument with schema and name + pub type_: IdentifierType, +} + +#[derive(Debug, Clone)] +pub struct IdentifierType { + pub schema: Option, + pub name: String, + pub is_array: bool, +} + +/// Applies the identifiers to the SQL string by replacing them with their default values. +pub fn apply_identifiers<'a>( + identifiers: Vec, + schema_cache: &'a pgt_schema_cache::SchemaCache, + cst: &'a tree_sitter::Tree, + sql: &'a str, +) -> (String, bool) { + let mut executor = TreeSitterQueriesExecutor::new(cst.root_node(), sql); + + executor.add_query_results::(); + + // Collect all replacements first to avoid modifying the string while iterating + let replacements: Vec<_> = executor + .get_iter(None) + .filter_map(|q| { + let m: &ParameterMatch = q.try_into().ok()?; + let path = m.get_path(sql); + let parts: Vec<_> = path.split('.').collect(); + + // Find the matching identifier and its position in the path + let (identifier, position) = find_matching_identifier(&parts, &identifiers)?; + + // Resolve the type based on whether we're accessing a field of a composite type + let type_ = resolve_type(identifier, position, &parts, schema_cache)?; + + Some((m.get_byte_range(), type_, identifier.type_.is_array)) + }) + .collect(); + + let mut result = sql.to_string(); + + let mut valid_positions = true; + + // Apply replacements in reverse order to maintain correct byte offsets + for (range, type_, is_array) in replacements.into_iter().rev() { + let default_value = get_formatted_default_value(type_, is_array); + + // if the default_value is shorter than "range", fill it up with spaces + let default_value = if default_value.len() < range.end - range.start { + format!("{: range.end - range.start { + valid_positions = false; + } + + result.replace_range(range, &default_value); + } + + (result, valid_positions) +} + +/// Format the default value based on the type and whether it's an array +fn get_formatted_default_value(pg_type: &PostgresType, is_array: bool) -> String { + // Get the base default value for this type + let default = resolve_default_value(pg_type); + + let default = if default.len() > "NULL".len() { + // If the default value is longer than "NULL", use "NULL" instead + "NULL".to_string() + } else { + // Otherwise, use the default value + default + }; + + // For arrays, wrap the default in array syntax + if is_array { + format!("'{{{}}}'", default) + } else { + default + } +} + +/// Resolve the default value for a given Postgres type. +/// +/// * `pg_type`: The type to return the default value for. +pub fn resolve_default_value(pg_type: &PostgresType) -> String { + // Handle ENUM types by returning the first variant + if !pg_type.enums.values.is_empty() { + return format!("'{}'", pg_type.enums.values[0]); + } + + match pg_type.name.as_str() { + // Numeric types + "smallint" | "int2" | "integer" | "int" | "int4" | "bigint" | "int8" | "decimal" + | "numeric" | "real" | "float4" | "double precision" | "float8" | "smallserial" + | "serial2" | "serial" | "serial4" | "bigserial" | "serial8" => "0".to_string(), + + // Boolean type + "boolean" | "bool" => "false".to_string(), + + // Character types + "character" | "char" | "character varying" | "varchar" | "text" => "''".to_string(), + + // Date/time types + "date" => "'1970-01-01'".to_string(), + "time" | "time without time zone" => "'00:00:00'".to_string(), + "time with time zone" | "timetz" => "'00:00:00+00'".to_string(), + "timestamp" | "timestamp without time zone" => "'1970-01-01 00:00:00'".to_string(), + "timestamp with time zone" | "timestamptz" => "'1970-01-01 00:00:00+00'".to_string(), + "interval" => "'0'".to_string(), + + // JSON types + "json" | "jsonb" => "'null'".to_string(), + + // UUID + "uuid" => "'00000000-0000-0000-0000-000000000000'".to_string(), + + // Byte array + "bytea" => "'\\x'".to_string(), + + // Network types + "inet" => "'0.0.0.0'".to_string(), + "cidr" => "'0.0.0.0/0'".to_string(), + "macaddr" => "'00:00:00:00:00:00'".to_string(), + "macaddr8" => "'00:00:00:00:00:00:00:00'".to_string(), + + // Monetary type + "money" => "'0.00'".to_string(), + + // Geometric types + "point" => "'(0,0)'".to_string(), + "line" => "'{0,0,0}'".to_string(), + "lseg" => "'[(0,0),(0,0)]'".to_string(), + "box" => "'((0,0),(0,0))'".to_string(), + "path" => "'((0,0),(0,0))'".to_string(), + "polygon" => "'((0,0),(0,0),(0,0))'".to_string(), + "circle" => "'<(0,0),0>'".to_string(), + + // Text search types + "tsvector" => "''".to_string(), + "tsquery" => "''".to_string(), + + // XML + "xml" => "''".to_string(), + + // Log sequence number + "pg_lsn" => "'0/0'".to_string(), + + // Snapshot types + "txid_snapshot" | "pg_snapshot" => "NULL".to_string(), + + // Fallback for unrecognized types + _ => "NULL".to_string(), + } +} + +// Helper function to find the matching identifier and its position in the path +fn find_matching_identifier<'a>( + parts: &[&str], + identifiers: &'a [TypedIdentifier], +) -> Option<(&'a TypedIdentifier, usize)> { + // Case 1: Parameter reference (e.g., $2) + if parts.len() == 1 && parts[0].starts_with('$') { + let idx = parts[0][1..].parse::().ok()?; + let identifier = identifiers.get(idx - 1)?; + return Some((identifier, idx)); + } + + // Case 2: Named reference (e.g., fn_name.custom_type.v_test2) + identifiers.iter().find_map(|identifier| { + let name = identifier.name.as_ref()?; + + parts + .iter() + .enumerate() + .find(|(_idx, part)| **part == name) + .map(|(idx, _)| (identifier, idx)) + }) +} + +// Helper function to resolve the type based on the identifier and path +fn resolve_type<'a>( + identifier: &TypedIdentifier, + position: usize, + parts: &[&str], + schema_cache: &'a pgt_schema_cache::SchemaCache, +) -> Option<&'a PostgresType> { + if position < parts.len() - 1 { + // Find the composite type + let schema_type = schema_cache.types.iter().find(|t| { + identifier + .type_ + .schema + .as_ref() + .is_none_or(|s| t.schema == *s) + && t.name == *identifier.type_.name + })?; + + // Find the field within the composite type + let field_name = parts.last().unwrap(); + let field = schema_type + .attributes + .attrs + .iter() + .find(|a| a.name == *field_name)?; + + // Find the field's type + schema_cache.types.iter().find(|t| t.id == field.type_id) + } else { + // Direct type reference + schema_cache.find_type(&identifier.type_.name, identifier.type_.schema.as_deref()) + } +} + +#[cfg(test)] +mod tests { + use pgt_test_utils::test_database::get_new_test_db; + use sqlx::Executor; + + #[tokio::test] + async fn test_apply_identifiers() { + let input = "select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3 + fn_name.v_test2 + enum_type"; + + let identifiers = vec![ + super::TypedIdentifier { + path: "fn_name".to_string(), + name: Some("v_test".to_string()), + type_: super::IdentifierType { + schema: None, + name: "int4".to_string(), + is_array: false, + }, + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: Some("custom_type".to_string()), + type_: super::IdentifierType { + schema: Some("public".to_string()), + name: "custom_type".to_string(), + is_array: false, + }, + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: Some("another".to_string()), + type_: super::IdentifierType { + schema: None, + name: "numeric".to_string(), + is_array: false, + }, + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: Some("custom_type".to_string()), + type_: super::IdentifierType { + schema: Some("public".to_string()), + name: "custom_type".to_string(), + is_array: false, + }, + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: Some("v_test2".to_string()), + type_: super::IdentifierType { + schema: None, + name: "int4".to_string(), + is_array: false, + }, + }, + super::TypedIdentifier { + path: "fn_name".to_string(), + name: Some("enum_type".to_string()), + type_: super::IdentifierType { + schema: Some("public".to_string()), + name: "enum_type".to_string(), + is_array: false, + }, + }, + ]; + + let test_db = get_new_test_db().await; + + let setup = r#" + CREATE TYPE "public"."custom_type" AS ( + v_test2 integer, + v_test3 integer + ); + + CREATE TYPE "public"."enum_type" AS ENUM ( + 'critical', + 'high', + 'default', + 'low', + 'very_low' + ); + "#; + + test_db + .execute(setup) + .await + .expect("Failed to setup test database"); + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db) + .await + .expect("Failed to load Schema Cache"); + + let tree = parser.parse(input, None).unwrap(); + + let (sql_out, valid_pos) = + super::apply_identifiers(identifiers, &schema_cache, &tree, input); + + assert!(valid_pos); + assert_eq!( + sql_out, + "select 0 + 0 + 0 + 0 + 0 + NULL " + ); + } +} diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 4c780d74..9628962d 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -7,19 +7,25 @@ use pgt_test_utils::test_database::get_new_test_db; use pgt_typecheck::{TypecheckParams, check_sql}; use sqlx::Executor; -async fn test(name: &str, query: &str, setup: &str) { +async fn test(name: &str, query: &str, setup: Option<&str>) { let test_db = get_new_test_db().await; - test_db - .execute(setup) - .await - .expect("Failed to setup test database"); + if let Some(setup) = setup { + test_db + .execute(setup) + .await + .expect("Failed to setup test database"); + } let mut parser = tree_sitter::Parser::new(); parser .set_language(tree_sitter_sql::language()) .expect("Error loading sql language"); + let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db) + .await + .expect("Failed to load Schema Cache"); + let root = pgt_query_ext::parse(query).unwrap(); let tree = parser.parse(query, None).unwrap(); @@ -29,6 +35,8 @@ async fn test(name: &str, query: &str, setup: &str) { sql: query, ast: &root, tree: &tree, + schema_cache: &schema_cache, + identifiers: vec![], }) .await; @@ -55,7 +63,8 @@ async fn invalid_column() { test( "invalid_column", "select id, unknown from contacts;", - r#" + Some( + r#" create table public.contacts ( id serial primary key, name varchar(255) not null, @@ -63,6 +72,7 @@ async fn invalid_column() { middle_name varchar(255) ); "#, + ), ) .await; } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 3c14f352..acb89f1d 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -1,4 +1,9 @@ -use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock}; +use std::{ + fs, + panic::RefUnwindSafe, + path::Path, + sync::{Arc, RwLock}, +}; use analyser::AnalyserVisitorBuilder; use async_helper::run_async; @@ -16,7 +21,7 @@ use pgt_diagnostics::{ Diagnostic, DiagnosticExt, Error, Severity, serde::Diagnostic as SDiagnostic, }; use pgt_fs::{ConfigName, PgTPath}; -use pgt_typecheck::TypecheckParams; +use pgt_typecheck::{IdentifierType, TypecheckParams, TypedIdentifier}; use schema_cache_manager::SchemaCacheManager; use sqlx::Executor; use tracing::info; @@ -365,12 +370,16 @@ impl Workspace for WorkspaceServer { .get_pool() { let path_clone = params.path.clone(); + let schema_cache = self.schema_cache.load(pool.clone())?; + let schema_cache_arc = Arc::new(schema_cache.as_ref().clone()); let input = parser.iter(AsyncDiagnosticsMapper).collect::>(); + // sorry for the ugly code :( let async_results = run_async(async move { stream::iter(input) - .map(|(_id, range, content, ast, cst)| { + .map(|(_id, range, content, ast, cst, sign)| { let pool = pool.clone(); let path = path_clone.clone(); + let schema_cache = Arc::clone(&schema_cache_arc); async move { if let Some(ast) = ast { pgt_typecheck::check_sql(TypecheckParams { @@ -378,6 +387,23 @@ impl Workspace for WorkspaceServer { sql: &content, ast: &ast, tree: &cst, + schema_cache: schema_cache.as_ref(), + identifiers: sign + .map(|s| { + s.args + .iter() + .map(|a| TypedIdentifier { + path: s.name.1.clone(), + name: a.name.clone(), + type_: IdentifierType { + schema: a.type_.schema.clone(), + name: a.type_.name.clone(), + is_array: a.type_.is_array, + }, + }) + .collect::>() + }) + .unwrap_or_default(), }) .await .map(|d| { diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index 67ed991c..ed0ca40f 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -34,6 +34,13 @@ impl Document { } } + pub fn statement_content(&self, id: &StatementId) -> Option<&str> { + self.positions + .iter() + .find(|(statement_id, _)| statement_id == id) + .map(|(_, range)| &self.content[*range]) + } + /// Returns true if there is at least one fatal error in the diagnostics /// /// A fatal error is a scan error that prevents the document from being used diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs index 92f33926..2b81faba 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -12,7 +12,7 @@ use super::{ change::StatementChange, document::{Document, StatementIterator}, pg_query::PgQueryStore, - sql_function::SQLFunctionBodyStore, + sql_function::{SQLFunctionSignature, get_sql_fn_body, get_sql_fn_signature}, statement_identifier::StatementId, tree_sitter::TreeSitterStore, }; @@ -24,7 +24,6 @@ pub struct ParsedDocument { doc: Document, ast_db: PgQueryStore, cst_db: TreeSitterStore, - sql_fn_db: SQLFunctionBodyStore, annotation_db: AnnotationStore, } @@ -34,7 +33,6 @@ impl ParsedDocument { let cst_db = TreeSitterStore::new(); let ast_db = PgQueryStore::new(); - let sql_fn_db = SQLFunctionBodyStore::new(); let annotation_db = AnnotationStore::new(); doc.iter().for_each(|(stmt, _, content)| { @@ -46,7 +44,6 @@ impl ParsedDocument { doc, ast_db, cst_db, - sql_fn_db, annotation_db, } } @@ -72,7 +69,6 @@ impl ParsedDocument { tracing::debug!("Deleting statement: id {:?}", s,); self.cst_db.remove_statement(s); self.ast_db.clear_statement(s); - self.sql_fn_db.clear_statement(s); self.annotation_db.clear_statement(s); } StatementChange::Modified(s) => { @@ -88,7 +84,6 @@ impl ParsedDocument { self.cst_db.modify_statement(s); self.ast_db.clear_statement(&s.old_stmt); - self.sql_fn_db.clear_statement(&s.old_stmt); self.annotation_db.clear_statement(&s.old_stmt); } } @@ -197,11 +192,7 @@ where .as_ref() { // Check if this is a SQL function definition with a body - if let Some(sub_statement) = - self.parser - .sql_fn_db - .get_function_body(&root_id, ast, &content_owned) - { + if let Some(sub_statement) = get_sql_fn_body(ast, &content_owned) { // Add sub-statements to our pending queue self.pending_sub_statements.push(( root_id.create_child(), @@ -274,6 +265,7 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { String, Option, Arc, + Option, ); fn map( @@ -293,7 +285,26 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper { let cst_result = parser.cst_db.get_or_cache_tree(&id, &content_owned); - (id, range, content_owned, ast_option, cst_result) + let sql_fn_sig = id + .parent() + .and_then(|root| { + let c = parser.doc.statement_content(&root)?; + Some((root, c)) + }) + .and_then(|(root, c)| { + let ast_option = parser + .ast_db + .get_or_cache_ast(&root, c) + .as_ref() + .clone() + .ok(); + + let ast_option = ast_option.as_ref()?; + + get_sql_fn_signature(ast_option) + }); + + (id, range, content_owned, ast_option, cst_result, sql_fn_sig) } } @@ -413,7 +424,7 @@ mod tests { #[test] fn sql_function_body() { - let input = "CREATE FUNCTION add(integer, integer) RETURNS integer + let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 777210d5..671e2a8b 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -1,56 +1,80 @@ -use std::sync::Arc; - -use dashmap::DashMap; use pgt_text_size::TextRange; -use super::statement_identifier::StatementId; +#[derive(Debug, Clone)] +pub struct ArgType { + pub schema: Option, + pub name: String, + pub is_array: bool, +} #[derive(Debug, Clone)] -pub struct SQLFunctionBody { - pub range: TextRange, - pub body: String, +pub struct SQLFunctionArgs { + pub name: Option, + pub type_: ArgType, } -pub struct SQLFunctionBodyStore { - db: DashMap>>, +#[derive(Debug, Clone)] +pub struct SQLFunctionSignature { + pub name: (Option, String), + pub args: Vec, } -impl SQLFunctionBodyStore { - pub fn new() -> SQLFunctionBodyStore { - SQLFunctionBodyStore { db: DashMap::new() } - } +#[derive(Debug, Clone)] +pub struct SQLFunctionBody { + pub range: TextRange, + pub body: String, +} - pub fn get_function_body( - &self, - statement: &StatementId, - ast: &pgt_query_ext::NodeEnum, - content: &str, - ) -> Option> { - // First check if we already have this statement cached - if let Some(existing) = self.db.get(statement).map(|x| x.clone()) { - return existing; - } +/// Extracts the function signature from a SQL function definition +pub fn get_sql_fn_signature(ast: &pgt_query_ext::NodeEnum) -> Option { + let create_fn = match ast { + pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, + _ => return None, + }; - // If not cached, try to extract it from the AST - let fn_body = get_sql_fn(ast, content).map(Arc::new); + // Extract language from function options + let language = find_option_value(create_fn, "language")?; - // Cache the result and return it - self.db.insert(statement.clone(), fn_body.clone()); - fn_body + // Only process SQL functions + if language != "sql" { + return None; } - pub fn clear_statement(&self, id: &StatementId) { - self.db.remove(id); - - if let Some(child_id) = id.get_child_id() { - self.db.remove(&child_id); + let fn_name = parse_name(&create_fn.funcname)?; + + // we return None if anything is not expected + let mut fn_args = Vec::new(); + for arg in &create_fn.parameters { + if let Some(pgt_query_ext::NodeEnum::FunctionParameter(node)) = &arg.node { + let arg_name = (!node.name.is_empty()).then_some(node.name.clone()); + + let arg_type = node.arg_type.as_ref()?; + let type_name = parse_name(&arg_type.names)?; + fn_args.push(SQLFunctionArgs { + name: arg_name, + type_: ArgType { + schema: type_name.0, + name: type_name.1, + is_array: node + .arg_type + .as_ref() + .map(|t| !t.array_bounds.is_empty()) + .unwrap_or(false), + }, + }); + } else { + return None; } } + + Some(SQLFunctionSignature { + name: fn_name, + args: fn_args, + }) } -/// Extracts SQL function body and its text range from a CreateFunctionStmt node. -/// Returns None if the function is not an SQL function or if the body can't be found. -fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option { +/// Extracts the SQL body from a function definition +pub fn get_sql_fn_body(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option { let create_fn = match ast { pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf, _ => return None, @@ -120,3 +144,78 @@ fn find_option_value( } }) } + +fn parse_name(nodes: &[pgt_query_ext::protobuf::Node]) -> Option<(Option, String)> { + let names = nodes + .iter() + .map(|n| match &n.node { + Some(pgt_query_ext::NodeEnum::String(s)) => Some(s.sval.clone()), + _ => None, + }) + .collect::>(); + + match names.as_slice() { + [Some(schema), Some(name)] => Some((Some(schema.clone()), name.clone())), + [Some(name)] => Some((None, name.clone())), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sql_function_signature() { + let input = "CREATE FUNCTION add(test0 integer, test1 integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; + + let ast = pgt_query_ext::parse(input).unwrap(); + + let sig = get_sql_fn_signature(&ast); + + assert!(sig.is_some()); + + let sig = sig.unwrap(); + + let arg1 = sig.args.first().unwrap(); + + assert_eq!(arg1.name, Some("test0".to_string())); + assert_eq!(arg1.type_.name, "int4"); + + let arg2 = sig.args.get(1).unwrap(); + assert_eq!(arg2.name, Some("test1".to_string())); + assert_eq!(arg2.type_.name, "int4"); + } + + #[test] + fn array_type() { + let input = "CREATE FUNCTION add(test0 integer[], test1 integer) RETURNS integer + AS 'select $1 + $2;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;"; + + let ast = pgt_query_ext::parse(input).unwrap(); + + let sig = get_sql_fn_signature(&ast); + + assert!(sig.is_some()); + + let sig = sig.unwrap(); + + assert!( + sig.args + .iter() + .find(|arg| arg.type_.is_array) + .map(|arg| { + assert_eq!(arg.type_.name, "int4"); + assert!(arg.type_.is_array); + }) + .is_some() + ); + } +} diff --git a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs index 8c02814d..7c7d76f0 100644 --- a/crates/pgt_workspace/src/workspace/server/statement_identifier.rs +++ b/crates/pgt_workspace/src/workspace/server/statement_identifier.rs @@ -57,6 +57,21 @@ impl StatementId { StatementId::Child(s) => s.inner, } } + + pub fn is_root(&self) -> bool { + matches!(self, StatementId::Root(_)) + } + + pub fn is_child(&self) -> bool { + matches!(self, StatementId::Child(_)) + } + + pub fn parent(&self) -> Option { + match self { + StatementId::Root(_) => None, + StatementId::Child(id) => Some(StatementId::Root(id.clone())), + } + } } /// Helper struct to generate unique statement ids