diff --git a/Cargo.lock b/Cargo.lock index 16b1de5e6..5cb3fe2d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2765,7 +2765,7 @@ dependencies = [ "pgt_schema_cache", "pgt_test_utils", "pgt_text_size", - "pgt_treesitter_queries", + "pgt_treesitter", "schemars", "serde", "serde_json", @@ -3047,10 +3047,13 @@ dependencies = [ ] [[package]] -name = "pgt_treesitter_queries" +name = "pgt_treesitter" version = "0.0.0" dependencies = [ "clap 4.5.23", + "pgt_schema_cache", + "pgt_test_utils", + "pgt_text_size", "tree-sitter", "tree_sitter_sql", ] @@ -3074,7 +3077,7 @@ dependencies = [ "pgt_schema_cache", "pgt_test_utils", "pgt_text_size", - "pgt_treesitter_queries", + "pgt_treesitter", "sqlx", "tokio", "tree-sitter", diff --git a/Cargo.toml b/Cargo.toml index 15c6f02ff..a5195d2da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,7 +82,7 @@ pgt_suppressions = { path = "./crates/pgt_suppressions", version = "0. pgt_text_edit = { path = "./crates/pgt_text_edit", version = "0.0.0" } pgt_text_size = { path = "./crates/pgt_text_size", version = "0.0.0" } pgt_tokenizer = { path = "./crates/pgt_tokenizer", version = "0.0.0" } -pgt_treesitter_queries = { path = "./crates/pgt_treesitter_queries", version = "0.0.0" } +pgt_treesitter = { path = "./crates/pgt_treesitter", version = "0.0.0" } pgt_typecheck = { path = "./crates/pgt_typecheck", version = "0.0.0" } pgt_workspace = { path = "./crates/pgt_workspace", version = "0.0.0" } diff --git a/crates/pgt_completions/Cargo.toml b/crates/pgt_completions/Cargo.toml index 916a00209..0ebb8e56e 100644 --- a/crates/pgt_completions/Cargo.toml +++ b/crates/pgt_completions/Cargo.toml @@ -14,18 +14,17 @@ version = "0.0.0" [dependencies] async-std = "1.12.0" -pgt_text_size.workspace = true - - -fuzzy-matcher = "0.3.7" -pgt_schema_cache.workspace = true -pgt_treesitter_queries.workspace = true -schemars = { workspace = true, optional = true } -serde = { workspace = true, features = ["derive"] } -serde_json = { workspace = true } -tracing = { workspace = true } -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +pgt_schema_cache.workspace = true +pgt_text_size.workspace = true +pgt_treesitter.workspace = true + +fuzzy-matcher = "0.3.7" +schemars = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tracing = { workspace = true } +tree-sitter.workspace = true +tree_sitter_sql.workspace = true sqlx.workspace = true diff --git a/crates/pgt_completions/src/builder.rs b/crates/pgt_completions/src/builder.rs index 96576053f..bf8eb66a6 100644 --- a/crates/pgt_completions/src/builder.rs +++ b/crates/pgt_completions/src/builder.rs @@ -1,10 +1,11 @@ use crate::{ CompletionItemKind, CompletionText, - context::CompletionContext, item::CompletionItem, relevance::{filtering::CompletionFilter, scoring::CompletionScore}, }; +use pgt_treesitter::TreesitterContext; + pub(crate) struct PossibleCompletionItem<'a> { pub label: String, pub description: String, @@ -17,11 +18,11 @@ pub(crate) struct PossibleCompletionItem<'a> { pub(crate) struct CompletionBuilder<'a> { items: Vec>, - ctx: &'a CompletionContext<'a>, + ctx: &'a TreesitterContext<'a>, } impl<'a> CompletionBuilder<'a> { - pub fn new(ctx: &'a CompletionContext) -> Self { + pub fn new(ctx: &'a TreesitterContext) -> Self { CompletionBuilder { items: vec![], ctx } } diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index bd5efd19d..e18589af0 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -1,8 +1,9 @@ use pgt_text_size::TextSize; +use pgt_treesitter::{TreeSitterContextParams, context::TreesitterContext}; + use crate::{ builder::CompletionBuilder, - context::CompletionContext, item::CompletionItem, providers::{ complete_columns, complete_functions, complete_policies, complete_roles, complete_schemas, @@ -28,16 +29,20 @@ pub struct CompletionParams<'a> { pub fn complete(params: CompletionParams) -> Vec { let sanitized_params = SanitizedCompletionParams::from(params); - let ctx = CompletionContext::new(&sanitized_params); + let ctx = TreesitterContext::new(TreeSitterContextParams { + position: sanitized_params.position, + text: &sanitized_params.text, + tree: &sanitized_params.tree, + }); let mut builder = CompletionBuilder::new(&ctx); - complete_tables(&ctx, &mut builder); - complete_functions(&ctx, &mut builder); - complete_columns(&ctx, &mut builder); - complete_schemas(&ctx, &mut builder); - complete_policies(&ctx, &mut builder); - complete_roles(&ctx, &mut builder); + complete_tables(&ctx, sanitized_params.schema, &mut builder); + complete_functions(&ctx, sanitized_params.schema, &mut builder); + complete_columns(&ctx, sanitized_params.schema, &mut builder); + complete_schemas(&ctx, sanitized_params.schema, &mut builder); + complete_policies(&ctx, sanitized_params.schema, &mut builder); + complete_roles(&ctx, sanitized_params.schema, &mut builder); builder.finish() } diff --git a/crates/pgt_completions/src/lib.rs b/crates/pgt_completions/src/lib.rs index f8ca1a550..c4e592eef 100644 --- a/crates/pgt_completions/src/lib.rs +++ b/crates/pgt_completions/src/lib.rs @@ -1,6 +1,5 @@ mod builder; mod complete; -mod context; mod item; mod providers; mod relevance; diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 04d0af656..ba3b24813 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -1,14 +1,20 @@ +use pgt_schema_cache::SchemaCache; +use pgt_treesitter::{TreesitterContext, WrappingClause}; + use crate::{ CompletionItemKind, builder::{CompletionBuilder, PossibleCompletionItem}, - context::{CompletionContext, WrappingClause}, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; use super::helper::{find_matching_alias_for_table, get_completion_text_with_schema_or_alias}; -pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { - let available_columns = &ctx.schema_cache.columns; +pub fn complete_columns<'a>( + ctx: &TreesitterContext<'a>, + schema_cache: &'a SchemaCache, + builder: &mut CompletionBuilder<'a>, +) { + let available_columns = &schema_cache.columns; for col in available_columns { let relevance = CompletionRelevanceData::Column(col); @@ -49,11 +55,13 @@ mod tests { use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ - CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, - assert_no_complete_results, get_test_deps, get_test_params, + CompletionAssertion, assert_complete_results, assert_no_complete_results, + get_test_deps, get_test_params, }, }; + use pgt_test_utils::QueryWithCursorPosition; + struct TestCase { query: String, message: &'static str, @@ -62,7 +70,7 @@ mod tests { } impl TestCase { - fn get_input_query(&self) -> InputQuery { + fn get_input_query(&self) -> QueryWithCursorPosition { let strs: Vec<&str> = self.query.split_whitespace().collect(); strs.join(" ").as_str().into() } @@ -94,7 +102,10 @@ mod tests { let queries: Vec = vec![ TestCase { message: "correctly prefers the columns of present tables", - query: format!(r#"select na{} from public.audio_books;"#, CURSOR_POS), + query: format!( + r#"select na{} from public.audio_books;"#, + QueryWithCursorPosition::cursor_marker() + ), label: "narrator", description: "public.audio_books", }, @@ -111,14 +122,17 @@ mod tests { join public.users u on u.id = subquery.id; "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), label: "narrator_id", description: "private.audio_books", }, TestCase { message: "works without a schema", - query: format!(r#"select na{} from users;"#, CURSOR_POS), + query: format!( + r#"select na{} from users;"#, + QueryWithCursorPosition::cursor_marker() + ), label: "name", description: "public.users", }, @@ -165,7 +179,7 @@ mod tests { pool.execute(setup).await.unwrap(); let case = TestCase { - query: format!(r#"select n{};"#, CURSOR_POS), + query: format!(r#"select n{};"#, QueryWithCursorPosition::cursor_marker()), description: "", label: "", message: "", @@ -220,7 +234,10 @@ mod tests { let test_case = TestCase { message: "suggests user created tables first", - query: format!(r#"select {} from users"#, CURSOR_POS), + query: format!( + r#"select {} from users"#, + QueryWithCursorPosition::cursor_marker() + ), label: "", description: "", }; @@ -270,7 +287,10 @@ mod tests { let test_case = TestCase { message: "suggests user created tables first", - query: format!(r#"select * from private.{}"#, CURSOR_POS), + query: format!( + r#"select * from private.{}"#, + QueryWithCursorPosition::cursor_marker() + ), label: "", description: "", }; @@ -311,7 +331,11 @@ mod tests { pool.execute(setup).await.unwrap(); assert_complete_results( - format!(r#"select {} from users"#, CURSOR_POS).as_str(), + format!( + r#"select {} from users"#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("address2".into()), CompletionAssertion::Label("email2".into()), @@ -324,7 +348,11 @@ mod tests { .await; assert_complete_results( - format!(r#"select {} from private.users"#, CURSOR_POS).as_str(), + format!( + r#"select {} from private.users"#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("address1".into()), CompletionAssertion::Label("email1".into()), @@ -338,7 +366,11 @@ mod tests { // asserts fuzzy finding for "settings" assert_complete_results( - format!(r#"select sett{} from private.users"#, CURSOR_POS).as_str(), + format!( + r#"select sett{} from private.users"#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![CompletionAssertion::Label("user_settings".into())], None, &pool, @@ -372,7 +404,7 @@ mod tests { assert_complete_results( format!( "select u.id, p.{} from auth.users u join auth.posts p on u.id = p.user_id;", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -394,7 +426,7 @@ mod tests { assert_complete_results( format!( "select u.id, p.content from auth.users u join auth.posts p on u.id = p.{};", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -440,7 +472,7 @@ mod tests { assert_complete_results( format!( "select u.id, p.content from auth.users u join auth.{}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -479,7 +511,7 @@ mod tests { assert_complete_results( format!( "select u.id, auth.posts.content from auth.users u join auth.posts on u.{}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -496,7 +528,7 @@ mod tests { assert_complete_results( format!( "select u.id, p.content from auth.users u join auth.posts p on p.user_id = u.{}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -536,7 +568,7 @@ mod tests { assert_complete_results( format!( "select {} from public.one o join public.two on o.id = t.id;", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -555,7 +587,7 @@ mod tests { assert_complete_results( format!( "select a, {} from public.one o join public.two on o.id = t.id;", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -577,7 +609,7 @@ mod tests { assert_complete_results( format!( "select o.id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -593,7 +625,7 @@ mod tests { assert_complete_results( format!( "select id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![CompletionAssertion::Label("z".to_string())], @@ -625,7 +657,11 @@ mod tests { // are lower in the alphabet assert_complete_results( - format!("insert into instruments ({})", CURSOR_POS).as_str(), + format!( + "insert into instruments ({})", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("id".to_string()), CompletionAssertion::Label("name".to_string()), @@ -637,7 +673,11 @@ mod tests { .await; assert_complete_results( - format!("insert into instruments (id, {})", CURSOR_POS).as_str(), + format!( + "insert into instruments (id, {})", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("name".to_string()), CompletionAssertion::Label("z".to_string()), @@ -648,7 +688,11 @@ mod tests { .await; assert_complete_results( - format!("insert into instruments (id, {}, name)", CURSOR_POS).as_str(), + format!( + "insert into instruments (id, {}, name)", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![CompletionAssertion::Label("z".to_string())], None, &pool, @@ -659,7 +703,7 @@ mod tests { assert_complete_results( format!( "insert into instruments (name, {}) values ('my_bass');", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -673,7 +717,11 @@ mod tests { // no completions in the values list! assert_no_complete_results( - format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(), + format!( + "insert into instruments (id, name) values ({})", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), None, &pool, ) @@ -700,7 +748,11 @@ mod tests { pool.execute(setup).await.unwrap(); assert_complete_results( - format!("select name from instruments where {} ", CURSOR_POS).as_str(), + format!( + "select name from instruments where {} ", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("created_at".into()), CompletionAssertion::Label("id".into()), @@ -715,7 +767,7 @@ mod tests { assert_complete_results( format!( "select name from instruments where z = 'something' and created_at > {}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), // simply do not complete columns + schemas; functions etc. are ok @@ -732,7 +784,7 @@ mod tests { assert_complete_results( format!( "select name from instruments where id = 'something' and {}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -749,7 +801,7 @@ mod tests { assert_complete_results( format!( "select name from instruments i join others o on i.z = o.a where i.{}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -783,22 +835,37 @@ mod tests { pool.execute(setup).await.unwrap(); let queries = vec![ - format!("alter table instruments drop column {}", CURSOR_POS), + format!( + "alter table instruments drop column {}", + QueryWithCursorPosition::cursor_marker() + ), format!( "alter table instruments drop column if exists {}", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), format!( "alter table instruments alter column {} set default", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() + ), + format!( + "alter table instruments alter {} set default", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "alter table public.instruments alter column {}", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "alter table instruments alter {}", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "alter table instruments rename {} to new_col", + QueryWithCursorPosition::cursor_marker() ), - format!("alter table instruments alter {} set default", CURSOR_POS), - format!("alter table public.instruments alter column {}", CURSOR_POS), - format!("alter table instruments alter {}", CURSOR_POS), - format!("alter table instruments rename {} to new_col", CURSOR_POS), format!( "alter table public.instruments rename column {} to new_col", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), ]; @@ -834,19 +901,19 @@ mod tests { let col_queries = vec![ format!( r#"create policy "my_pol" on public.instruments for select using ({})"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), format!( r#"create policy "my_pol" on public.instruments for insert with check ({})"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), format!( r#"create policy "my_pol" on public.instruments for update using (id = 1 and {})"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), format!( r#"create policy "my_pol" on public.instruments for insert with check (id = 1 and {})"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ), ]; diff --git a/crates/pgt_completions/src/providers/functions.rs b/crates/pgt_completions/src/providers/functions.rs index 615e4f951..b2ac2fae8 100644 --- a/crates/pgt_completions/src/providers/functions.rs +++ b/crates/pgt_completions/src/providers/functions.rs @@ -1,17 +1,21 @@ -use pgt_schema_cache::Function; +use pgt_schema_cache::{Function, SchemaCache}; +use pgt_treesitter::TreesitterContext; use crate::{ CompletionItemKind, CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, - context::CompletionContext, providers::helper::get_range_to_replace, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; use super::helper::get_completion_text_with_schema_or_alias; -pub fn complete_functions<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) { - let available_functions = &ctx.schema_cache.functions; +pub fn complete_functions<'a>( + ctx: &'a TreesitterContext, + schema_cache: &'a SchemaCache, + builder: &mut CompletionBuilder<'a>, +) { + let available_functions = &schema_cache.functions; for func in available_functions { let relevance = CompletionRelevanceData::Function(func); @@ -30,7 +34,7 @@ pub fn complete_functions<'a>(ctx: &'a CompletionContext, builder: &mut Completi } } -fn get_completion_text(ctx: &CompletionContext, func: &Function) -> CompletionText { +fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionText { let range = get_range_to_replace(ctx); let mut text = get_completion_text_with_schema_or_alias(ctx, &func.name, &func.schema) .map(|ct| ct.text) @@ -70,11 +74,12 @@ mod tests { use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ - CURSOR_POS, CompletionAssertion, assert_complete_results, get_test_deps, - get_test_params, + CompletionAssertion, assert_complete_results, get_test_deps, get_test_params, }, }; + use pgt_test_utils::QueryWithCursorPosition; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn completes_fn(pool: PgPool) { let setup = r#" @@ -89,7 +94,7 @@ mod tests { $$; "#; - let query = format!("select coo{}", CURSOR_POS); + let query = format!("select coo{}", QueryWithCursorPosition::cursor_marker()); let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); @@ -122,7 +127,10 @@ mod tests { $$; "#; - let query = format!(r#"select * from coo{}()"#, CURSOR_POS); + let query = format!( + r#"select * from coo{}()"#, + QueryWithCursorPosition::cursor_marker() + ); let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); @@ -156,7 +164,7 @@ mod tests { $$; "#; - let query = format!(r#"select coo{}"#, CURSOR_POS); + let query = format!(r#"select coo{}"#, QueryWithCursorPosition::cursor_marker()); let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); @@ -190,7 +198,10 @@ mod tests { $$; "#; - let query = format!(r#"select * from coo{}()"#, CURSOR_POS); + let query = format!( + r#"select * from coo{}()"#, + QueryWithCursorPosition::cursor_marker() + ); let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); @@ -259,7 +270,7 @@ mod tests { let query = format!( r#"create policy "my_pol" on public.instruments for insert with check (id = {})"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ); assert_complete_results( diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index 811125bd1..cd1046f12 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -1,9 +1,10 @@ use pgt_text_size::{TextRange, TextSize}; +use pgt_treesitter::TreesitterContext; -use crate::{CompletionText, context::CompletionContext, remove_sanitized_token}; +use crate::{CompletionText, remove_sanitized_token}; pub(crate) fn find_matching_alias_for_table( - ctx: &CompletionContext, + ctx: &TreesitterContext, table_name: &str, ) -> Option { for (alias, table) in ctx.mentioned_table_aliases.iter() { @@ -14,7 +15,7 @@ pub(crate) fn find_matching_alias_for_table( None } -pub(crate) fn get_range_to_replace(ctx: &CompletionContext) -> TextRange { +pub(crate) fn get_range_to_replace(ctx: &TreesitterContext) -> TextRange { match ctx.node_under_cursor.as_ref() { Some(node) => { let content = ctx.get_node_under_cursor_content().unwrap_or("".into()); @@ -30,7 +31,7 @@ pub(crate) fn get_range_to_replace(ctx: &CompletionContext) -> TextRange { } pub(crate) fn get_completion_text_with_schema_or_alias( - ctx: &CompletionContext, + ctx: &TreesitterContext, item_name: &str, schema_or_alias_name: &str, ) -> Option { diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index 216fcefaa..a5ffdb43e 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -1,16 +1,21 @@ +use pgt_schema_cache::SchemaCache; use pgt_text_size::{TextRange, TextSize}; +use pgt_treesitter::TreesitterContext; use crate::{ CompletionItemKind, CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, - context::CompletionContext, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; use super::helper::get_range_to_replace; -pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { - let available_policies = &ctx.schema_cache.policies; +pub fn complete_policies<'a>( + ctx: &TreesitterContext<'a>, + schema_cache: &'a SchemaCache, + builder: &mut CompletionBuilder<'a>, +) { + let available_policies = &schema_cache.policies; let surrounded_by_quotes = ctx .get_node_under_cursor_content() @@ -61,7 +66,8 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi mod tests { use sqlx::{Executor, PgPool}; - use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + use crate::test_helper::{CompletionAssertion, assert_complete_results}; + use pgt_test_utils::QueryWithCursorPosition; #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn completes_within_quotation_marks(pool: PgPool) { @@ -89,7 +95,11 @@ mod tests { pool.execute(setup).await.unwrap(); assert_complete_results( - format!("alter policy \"{}\" on private.users;", CURSOR_POS).as_str(), + format!( + "alter policy \"{}\" on private.users;", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("read for public users disallowed".into()), CompletionAssertion::Label("write for public users allowed".into()), @@ -100,7 +110,11 @@ mod tests { .await; assert_complete_results( - format!("alter policy \"w{}\" on private.users;", CURSOR_POS).as_str(), + format!( + "alter policy \"w{}\" on private.users;", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![CompletionAssertion::Label( "write for public users allowed".into(), )], diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 01641543f..b7664349c 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -1,12 +1,17 @@ use crate::{ CompletionItemKind, builder::{CompletionBuilder, PossibleCompletionItem}, - context::CompletionContext, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; +use pgt_schema_cache::SchemaCache; +use pgt_treesitter::TreesitterContext; -pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { - let available_roles = &ctx.schema_cache.roles; +pub fn complete_roles<'a>( + _ctx: &TreesitterContext<'a>, + schema_cache: &'a SchemaCache, + builder: &mut CompletionBuilder<'a>, +) { + let available_roles = &schema_cache.roles; for role in available_roles { let relevance = CompletionRelevanceData::Role(role); @@ -29,7 +34,9 @@ pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionB mod tests { use sqlx::{Executor, PgPool}; - use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + use crate::test_helper::{CompletionAssertion, assert_complete_results}; + + use pgt_test_utils::QueryWithCursorPosition; const SETUP: &str = r#" create table users ( @@ -42,7 +49,7 @@ mod tests { #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn works_in_drop_role(pool: PgPool) { assert_complete_results( - format!("drop role {}", CURSOR_POS).as_str(), + format!("drop role {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( @@ -63,7 +70,7 @@ mod tests { #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn works_in_alter_role(pool: PgPool) { assert_complete_results( - format!("alter role {}", CURSOR_POS).as_str(), + format!("alter role {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( @@ -86,7 +93,7 @@ mod tests { pool.execute(SETUP).await.unwrap(); assert_complete_results( - format!("set role {}", CURSOR_POS).as_str(), + format!("set role {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( @@ -104,7 +111,11 @@ mod tests { .await; assert_complete_results( - format!("set session authorization {}", CURSOR_POS).as_str(), + format!( + "set session authorization {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( @@ -133,7 +144,7 @@ mod tests { for all to {} using (true);"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -157,7 +168,7 @@ mod tests { r#"create policy "my cool policy" on public.users for select to {}"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -186,7 +197,7 @@ mod tests { r#"grant select on table public.users to {}"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -211,7 +222,7 @@ mod tests { r#"grant select on table public.users to owner, {}"#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ @@ -232,7 +243,11 @@ mod tests { .await; assert_complete_results( - format!(r#"grant {} to owner"#, CURSOR_POS).as_str(), + format!( + r#"grant {} to owner"#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ // recognizing already mentioned roles is not supported for now CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), @@ -256,12 +271,30 @@ mod tests { pool.execute(SETUP).await.unwrap(); let queries = vec![ - format!("revoke {} from owner", CURSOR_POS), - format!("revoke admin option for {} from owner", CURSOR_POS), - format!("revoke owner from {}", CURSOR_POS), - format!("revoke all on schema public from {} granted by", CURSOR_POS), - format!("revoke all on schema public from owner, {}", CURSOR_POS), - format!("revoke all on table userse from owner, {}", CURSOR_POS), + format!( + "revoke {} from owner", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "revoke admin option for {} from owner", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "revoke owner from {}", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "revoke all on schema public from {} granted by", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "revoke all on schema public from owner, {}", + QueryWithCursorPosition::cursor_marker() + ), + format!( + "revoke all on table userse from owner, {}", + QueryWithCursorPosition::cursor_marker() + ), ]; for query in queries { diff --git a/crates/pgt_completions/src/providers/schemas.rs b/crates/pgt_completions/src/providers/schemas.rs index 561da0f85..43c523875 100644 --- a/crates/pgt_completions/src/providers/schemas.rs +++ b/crates/pgt_completions/src/providers/schemas.rs @@ -1,11 +1,16 @@ use crate::{ builder::{CompletionBuilder, PossibleCompletionItem}, - context::CompletionContext, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; +use pgt_schema_cache::SchemaCache; +use pgt_treesitter::TreesitterContext; -pub fn complete_schemas<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) { - let available_schemas = &ctx.schema_cache.schemas; +pub fn complete_schemas<'a>( + _ctx: &'a TreesitterContext, + schema_cache: &'a SchemaCache, + builder: &mut CompletionBuilder<'a>, +) { + let available_schemas = &schema_cache.schemas; for schema in available_schemas { let relevance = CompletionRelevanceData::Schema(schema); @@ -31,9 +36,11 @@ mod tests { use crate::{ CompletionItemKind, - test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}, + test_helper::{CompletionAssertion, assert_complete_results}, }; + use pgt_test_utils::QueryWithCursorPosition; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn autocompletes_schemas(pool: PgPool) { let setup = r#" @@ -50,7 +57,7 @@ mod tests { "#; assert_complete_results( - format!("select * from {}", CURSOR_POS).as_str(), + format!("select * from {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("public".to_string(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".to_string(), CompletionItemKind::Schema), @@ -97,7 +104,11 @@ mod tests { "#; assert_complete_results( - format!("select * from u{}", CURSOR_POS).as_str(), + format!( + "select * from u{}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("ultimate".into(), CompletionItemKind::Schema), diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 3fbee8f12..f78b697c9 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -1,14 +1,20 @@ +use pgt_schema_cache::SchemaCache; +use pgt_treesitter::TreesitterContext; + use crate::{ builder::{CompletionBuilder, PossibleCompletionItem}, - context::CompletionContext, item::CompletionItemKind, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; use super::helper::get_completion_text_with_schema_or_alias; -pub fn complete_tables<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) { - let available_tables = &ctx.schema_cache.tables; +pub fn complete_tables<'a>( + ctx: &'a TreesitterContext, + schema_cache: &'a SchemaCache, + builder: &mut CompletionBuilder<'a>, +) { + let available_tables = &schema_cache.tables; for table in available_tables { let relevance = CompletionRelevanceData::Table(table); @@ -47,11 +53,13 @@ mod tests { use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ - CURSOR_POS, CompletionAssertion, assert_complete_results, assert_no_complete_results, + CompletionAssertion, assert_complete_results, assert_no_complete_results, get_test_deps, get_test_params, }, }; + use pgt_test_utils::QueryWithCursorPosition; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn autocompletes_simple_table(pool: PgPool) { let setup = r#" @@ -62,7 +70,10 @@ mod tests { ); "#; - let query = format!("select * from u{}", CURSOR_POS); + let query = format!( + "select * from u{}", + QueryWithCursorPosition::cursor_marker() + ); let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); @@ -98,9 +109,27 @@ mod tests { pool.execute(setup).await.unwrap(); let test_cases = vec![ - (format!("select * from u{}", CURSOR_POS), "users"), - (format!("select * from e{}", CURSOR_POS), "emails"), - (format!("select * from a{}", CURSOR_POS), "addresses"), + ( + format!( + "select * from u{}", + QueryWithCursorPosition::cursor_marker() + ), + "users", + ), + ( + format!( + "select * from e{}", + QueryWithCursorPosition::cursor_marker() + ), + "emails", + ), + ( + format!( + "select * from a{}", + QueryWithCursorPosition::cursor_marker() + ), + "addresses", + ), ]; for (query, expected_label) in test_cases { @@ -142,10 +171,25 @@ mod tests { pool.execute(setup).await.unwrap(); let test_cases = vec![ - (format!("select * from u{}", CURSOR_POS), "user_y"), // user_y is preferred alphanumerically - (format!("select * from private.u{}", CURSOR_POS), "user_z"), ( - format!("select * from customer_support.u{}", CURSOR_POS), + format!( + "select * from u{}", + QueryWithCursorPosition::cursor_marker() + ), + "user_y", + ), // user_y is preferred alphanumerically + ( + format!( + "select * from private.u{}", + QueryWithCursorPosition::cursor_marker() + ), + "user_z", + ), + ( + format!( + "select * from customer_support.u{}", + QueryWithCursorPosition::cursor_marker() + ), "user_y", ), ]; @@ -186,7 +230,10 @@ mod tests { $$; "#; - let query = format!(r#"select * from coo{}"#, CURSOR_POS); + let query = format!( + r#"select * from coo{}"#, + QueryWithCursorPosition::cursor_marker() + ); let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); @@ -213,7 +260,7 @@ mod tests { pool.execute(setup).await.unwrap(); assert_complete_results( - format!("update {}", CURSOR_POS).as_str(), + format!("update {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![CompletionAssertion::LabelAndKind( "public".into(), CompletionItemKind::Schema, @@ -224,7 +271,7 @@ mod tests { .await; assert_complete_results( - format!("update public.{}", CURSOR_POS).as_str(), + format!("update public.{}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![CompletionAssertion::LabelAndKind( "coos".into(), CompletionItemKind::Table, @@ -235,14 +282,22 @@ mod tests { .await; assert_no_complete_results( - format!("update public.coos {}", CURSOR_POS).as_str(), + format!( + "update public.coos {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), None, &pool, ) .await; assert_complete_results( - format!("update coos set {}", CURSOR_POS).as_str(), + format!( + "update coos set {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), @@ -253,7 +308,11 @@ mod tests { .await; assert_complete_results( - format!("update coos set name = 'cool' where {}", CURSOR_POS).as_str(), + format!( + "update coos set name = 'cool' where {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), @@ -275,10 +334,15 @@ mod tests { pool.execute(setup).await.unwrap(); - assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), None, &pool).await; + assert_no_complete_results( + format!("delete {}", QueryWithCursorPosition::cursor_marker()).as_str(), + None, + &pool, + ) + .await; assert_complete_results( - format!("delete from {}", CURSOR_POS).as_str(), + format!("delete from {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("coos".into(), CompletionItemKind::Table), @@ -289,7 +353,11 @@ mod tests { .await; assert_complete_results( - format!("delete from public.{}", CURSOR_POS).as_str(), + format!( + "delete from public.{}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![CompletionAssertion::Label("coos".into())], None, &pool, @@ -297,7 +365,11 @@ mod tests { .await; assert_complete_results( - format!("delete from public.coos where {}", CURSOR_POS).as_str(), + format!( + "delete from public.coos where {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), @@ -329,7 +401,11 @@ mod tests { "#; assert_complete_results( - format!("select * from auth.users u join {}", CURSOR_POS).as_str(), + format!( + "select * from auth.users u join {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), @@ -365,7 +441,7 @@ mod tests { pool.execute(setup).await.unwrap(); assert_complete_results( - format!("alter table {}", CURSOR_POS).as_str(), + format!("alter table {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), @@ -378,7 +454,11 @@ mod tests { .await; assert_complete_results( - format!("alter table if exists {}", CURSOR_POS).as_str(), + format!( + "alter table if exists {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), @@ -391,7 +471,7 @@ mod tests { .await; assert_complete_results( - format!("drop table {}", CURSOR_POS).as_str(), + format!("drop table {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), @@ -404,7 +484,11 @@ mod tests { .await; assert_complete_results( - format!("drop table if exists {}", CURSOR_POS).as_str(), + format!( + "drop table if exists {}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), @@ -432,7 +516,7 @@ mod tests { pool.execute(setup).await.unwrap(); assert_complete_results( - format!("insert into {}", CURSOR_POS).as_str(), + format!("insert into {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), @@ -444,7 +528,11 @@ mod tests { .await; assert_complete_results( - format!("insert into auth.{}", CURSOR_POS).as_str(), + format!( + "insert into auth.{}", + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), vec![CompletionAssertion::LabelAndKind( "users".into(), CompletionItemKind::Table, @@ -458,7 +546,7 @@ mod tests { assert_complete_results( format!( "insert into {} (name, email) values ('jules', 'a@b.com');", - CURSOR_POS + QueryWithCursorPosition::cursor_marker() ) .as_str(), vec![ diff --git a/crates/pgt_completions/src/providers/triggers.rs b/crates/pgt_completions/src/providers/triggers.rs deleted file mode 100644 index 6bc04debc..000000000 --- a/crates/pgt_completions/src/providers/triggers.rs +++ /dev/null @@ -1,169 +0,0 @@ -use crate::{ - CompletionItemKind, - builder::{CompletionBuilder, PossibleCompletionItem}, - context::CompletionContext, - relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, -}; - -use super::helper::get_completion_text_with_schema_or_alias; - -pub fn complete_functions<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) { - let available_functions = &ctx.schema_cache.functions; - - for func in available_functions { - let relevance = CompletionRelevanceData::Function(func); - - let item = PossibleCompletionItem { - label: func.name.clone(), - score: CompletionScore::from(relevance.clone()), - filter: CompletionFilter::from(relevance), - description: format!("Schema: {}", func.schema), - kind: CompletionItemKind::Function, - completion_text: get_completion_text_with_schema_or_alias( - ctx, - &func.name, - &func.schema, - ), - }; - - builder.add_item(item); - } -} - -#[cfg(test)] -mod tests { - use crate::{ - CompletionItem, CompletionItemKind, complete, - test_helper::{CURSOR_POS, get_test_deps, get_test_params}, - }; - - #[tokio::test] - async fn completes_fn() { - let setup = r#" - create or replace function cool() - returns trigger - language plpgsql - security invoker - as $$ - begin - raise exception 'dont matter'; - end; - $$; - "#; - - let query = format!("select coo{}", CURSOR_POS); - - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; - let params = get_test_params(&tree, &cache, query.as_str().into()); - let results = complete(params); - - let CompletionItem { label, .. } = results - .into_iter() - .next() - .expect("Should return at least one completion item"); - - assert_eq!(label, "cool"); - } - - #[tokio::test] - async fn prefers_fn_if_invocation() { - let setup = r#" - create table coos ( - id serial primary key, - name text - ); - - create or replace function cool() - returns trigger - language plpgsql - security invoker - as $$ - begin - raise exception 'dont matter'; - end; - $$; - "#; - - let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; - let params = get_test_params(&tree, &cache, query.as_str().into()); - let results = complete(params); - - let CompletionItem { label, kind, .. } = results - .into_iter() - .next() - .expect("Should return at least one completion item"); - - assert_eq!(label, "cool"); - assert_eq!(kind, CompletionItemKind::Function); - } - - #[tokio::test] - async fn prefers_fn_in_select_clause() { - let setup = r#" - create table coos ( - id serial primary key, - name text - ); - - create or replace function cool() - returns trigger - language plpgsql - security invoker - as $$ - begin - raise exception 'dont matter'; - end; - $$; - "#; - - let query = format!(r#"select coo{}"#, CURSOR_POS); - - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; - let params = get_test_params(&tree, &cache, query.as_str().into()); - let results = complete(params); - - let CompletionItem { label, kind, .. } = results - .into_iter() - .next() - .expect("Should return at least one completion item"); - - assert_eq!(label, "cool"); - assert_eq!(kind, CompletionItemKind::Function); - } - - #[tokio::test] - async fn prefers_function_in_from_clause_if_invocation() { - let setup = r#" - create table coos ( - id serial primary key, - name text - ); - - create or replace function cool() - returns trigger - language plpgsql - security invoker - as $$ - begin - raise exception 'dont matter'; - end; - $$; - "#; - - let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; - let params = get_test_params(&tree, &cache, query.as_str().into()); - let results = complete(params); - - let CompletionItem { label, kind, .. } = results - .into_iter() - .next() - .expect("Should return at least one completion item"); - - assert_eq!(label, "cool"); - assert_eq!(kind, CompletionItemKind::Function); - } -} diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index beea6ddb8..18e3d7ce5 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,6 +1,6 @@ use pgt_schema_cache::ProcKind; -use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause, WrappingNode}; +use pgt_treesitter::context::{NodeUnderCursor, TreesitterContext, WrappingClause, WrappingNode}; use super::CompletionRelevanceData; @@ -16,7 +16,7 @@ impl<'a> From> for CompletionFilter<'a> { } impl CompletionFilter<'_> { - pub fn is_relevant(&self, ctx: &CompletionContext) -> Option<()> { + pub fn is_relevant(&self, ctx: &TreesitterContext) -> Option<()> { self.completable_context(ctx)?; self.check_clause(ctx)?; self.check_invocation(ctx)?; @@ -25,7 +25,7 @@ impl CompletionFilter<'_> { Some(()) } - fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { + fn completable_context(&self, ctx: &TreesitterContext) -> Option<()> { if ctx.wrapping_node_kind.is_none() && ctx.wrapping_clause_type.is_none() { return None; } @@ -70,7 +70,7 @@ impl CompletionFilter<'_> { Some(()) } - fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { + fn check_clause(&self, ctx: &TreesitterContext) -> Option<()> { ctx.wrapping_clause_type .as_ref() .map(|clause| { @@ -208,7 +208,7 @@ impl CompletionFilter<'_> { .and_then(|is_ok| if is_ok { Some(()) } else { None }) } - fn check_invocation(&self, ctx: &CompletionContext) -> Option<()> { + fn check_invocation(&self, ctx: &TreesitterContext) -> Option<()> { if !ctx.is_invocation { return Some(()); } @@ -221,7 +221,7 @@ impl CompletionFilter<'_> { Some(()) } - fn check_mentioned_schema_or_alias(&self, ctx: &CompletionContext) -> Option<()> { + fn check_mentioned_schema_or_alias(&self, ctx: &TreesitterContext) -> Option<()> { if ctx.schema_or_alias_name.is_none() { return Some(()); } @@ -255,9 +255,11 @@ mod tests { use sqlx::{Executor, PgPool}; use crate::test_helper::{ - CURSOR_POS, CompletionAssertion, assert_complete_results, assert_no_complete_results, + CompletionAssertion, assert_complete_results, assert_no_complete_results, }; + use pgt_test_utils::QueryWithCursorPosition; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn completion_after_asterisk(pool: PgPool) { let setup = r#" @@ -270,11 +272,16 @@ mod tests { pool.execute(setup).await.unwrap(); - assert_no_complete_results(format!("select * {}", CURSOR_POS).as_str(), None, &pool).await; + assert_no_complete_results( + format!("select * {}", QueryWithCursorPosition::cursor_marker()).as_str(), + None, + &pool, + ) + .await; // if there s a COMMA after the asterisk, we're good assert_complete_results( - format!("select *, {}", CURSOR_POS).as_str(), + format!("select *, {}", QueryWithCursorPosition::cursor_marker()).as_str(), vec![ CompletionAssertion::Label("address".into()), CompletionAssertion::Label("email".into()), @@ -288,13 +295,20 @@ mod tests { #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn completion_after_create_table(pool: PgPool) { - assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), None, &pool) - .await; + assert_no_complete_results( + format!("create table {}", QueryWithCursorPosition::cursor_marker()).as_str(), + None, + &pool, + ) + .await; } #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn completion_in_column_definitions(pool: PgPool) { - let query = format!(r#"create table instruments ( {} )"#, CURSOR_POS); + let query = format!( + r#"create table instruments ( {} )"#, + QueryWithCursorPosition::cursor_marker() + ); assert_no_complete_results(query.as_str(), None, &pool).await; } } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index a0b5efa53..4bbf325f4 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -1,6 +1,8 @@ use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; -use crate::context::{CompletionContext, WrappingClause, WrappingNode}; +use pgt_treesitter::context::{TreesitterContext, WrappingClause, WrappingNode}; + +use crate::sanitization; use super::CompletionRelevanceData; @@ -24,7 +26,7 @@ impl CompletionScore<'_> { self.score } - pub fn calc_score(&mut self, ctx: &CompletionContext) { + pub fn calc_score(&mut self, ctx: &TreesitterContext) { self.check_is_user_defined(); self.check_matches_schema(ctx); self.check_matches_query_input(ctx); @@ -35,10 +37,10 @@ impl CompletionScore<'_> { self.check_columns_in_stmt(ctx); } - fn check_matches_query_input(&mut self, ctx: &CompletionContext) { + fn check_matches_query_input(&mut self, ctx: &TreesitterContext) { let content = match ctx.get_node_under_cursor_content() { - Some(c) => c.replace('"', ""), - None => return, + Some(c) if !sanitization::is_sanitized_token(c.as_str()) => c.replace('"', ""), + _ => return, }; let name = match self.data { @@ -69,7 +71,7 @@ impl CompletionScore<'_> { } } - fn check_matching_clause_type(&mut self, ctx: &CompletionContext) { + fn check_matching_clause_type(&mut self, ctx: &TreesitterContext) { let clause_type = match ctx.wrapping_clause_type.as_ref() { None => return, Some(ct) => ct, @@ -135,14 +137,16 @@ impl CompletionScore<'_> { } } - fn check_matching_wrapping_node(&mut self, ctx: &CompletionContext) { + fn check_matching_wrapping_node(&mut self, ctx: &TreesitterContext) { let wrapping_node = match ctx.wrapping_node_kind.as_ref() { None => return, Some(wn) => wn, }; let has_mentioned_schema = ctx.schema_or_alias_name.is_some(); - let has_node_text = ctx.get_node_under_cursor_content().is_some(); + let has_node_text = ctx + .get_node_under_cursor_content() + .is_some_and(|txt| !sanitization::is_sanitized_token(txt.as_str())); self.score += match self.data { CompletionRelevanceData::Table(_) => match wrapping_node { @@ -170,7 +174,7 @@ impl CompletionScore<'_> { } } - fn check_is_invocation(&mut self, ctx: &CompletionContext) { + fn check_is_invocation(&mut self, ctx: &TreesitterContext) { self.score += match self.data { CompletionRelevanceData::Function(_) if ctx.is_invocation => 30, CompletionRelevanceData::Function(_) if !ctx.is_invocation => -10, @@ -179,7 +183,7 @@ impl CompletionScore<'_> { }; } - fn check_matches_schema(&mut self, ctx: &CompletionContext) { + fn check_matches_schema(&mut self, ctx: &TreesitterContext) { let schema_name = match ctx.schema_or_alias_name.as_ref() { None => return, Some(n) => n, @@ -228,7 +232,7 @@ impl CompletionScore<'_> { } } - fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) { + fn check_relations_in_stmt(&mut self, ctx: &TreesitterContext) { match self.data { CompletionRelevanceData::Table(_) | CompletionRelevanceData::Function(_) => return, _ => {} @@ -312,7 +316,7 @@ impl CompletionScore<'_> { } } - fn check_columns_in_stmt(&mut self, ctx: &CompletionContext) { + fn check_columns_in_stmt(&mut self, ctx: &TreesitterContext) { if let CompletionRelevanceData::Column(column) = self.data { /* * Columns can be mentioned in one of two ways: diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index bf4d98160..155256c8a 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -23,6 +23,10 @@ pub(crate) fn remove_sanitized_token(it: &str) -> String { it.replace(SANITIZED_TOKEN, "") } +pub(crate) fn is_sanitized_token(txt: &str) -> bool { + txt == SANITIZED_TOKEN +} + #[derive(PartialEq, Eq, Debug)] pub(crate) enum NodeText { Replaced, @@ -118,10 +122,6 @@ where tree: Cow::Borrowed(params.tree), } } - - pub fn is_sanitized_token(txt: &str) -> bool { - txt == SANITIZED_TOKEN - } } /// Checks if the cursor is positioned inbetween two SQL nodes. diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 1bd5229ca..e6c347614 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -1,40 +1,12 @@ -use std::fmt::Display; - use pgt_schema_cache::SchemaCache; +use pgt_test_utils::QueryWithCursorPosition; use sqlx::{Executor, PgPool}; use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete}; -pub static CURSOR_POS: char = '€'; - -#[derive(Clone)] -pub struct InputQuery { - sql: String, - position: usize, -} - -impl From<&str> for InputQuery { - fn from(value: &str) -> Self { - let position = value - .find(CURSOR_POS) - .expect("Insert Cursor Position into your Query."); - - InputQuery { - sql: value.replace(CURSOR_POS, "").trim().to_string(), - position, - } - } -} - -impl Display for InputQuery { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.sql) - } -} - pub(crate) async fn get_test_deps( setup: Option<&str>, - input: InputQuery, + input: QueryWithCursorPosition, test_db: &PgPool, ) -> (tree_sitter::Tree, pgt_schema_cache::SchemaCache) { if let Some(setup) = setup { @@ -63,7 +35,7 @@ pub(crate) async fn get_test_deps( #[allow(dead_code)] pub(crate) async fn test_against_connection_string( conn_str: &str, - input: InputQuery, + input: QueryWithCursorPosition, ) -> (tree_sitter::Tree, pgt_schema_cache::SchemaCache) { let pool = sqlx::PgPool::connect(conn_str) .await @@ -83,16 +55,12 @@ pub(crate) async fn test_against_connection_string( (tree, schema_cache) } -pub(crate) fn get_text_and_position(q: InputQuery) -> (usize, String) { - (q.position, q.sql) -} - pub(crate) fn get_test_params<'a>( tree: &'a tree_sitter::Tree, schema_cache: &'a pgt_schema_cache::SchemaCache, - sql: InputQuery, + sql: QueryWithCursorPosition, ) -> CompletionParams<'a> { - let (position, text) = get_text_and_position(sql); + let (position, text) = sql.get_text_and_position(); CompletionParams { position: (position as u32).into(), @@ -102,46 +70,6 @@ pub(crate) fn get_test_params<'a>( } } -#[cfg(test)] -mod tests { - use crate::test_helper::CURSOR_POS; - - use super::InputQuery; - - #[test] - fn input_query_should_extract_correct_position() { - struct TestCase { - query: String, - expected_pos: usize, - expected_sql_len: usize, - } - - let cases = vec![ - TestCase { - query: format!("select * from{}", CURSOR_POS), - expected_pos: 13, - expected_sql_len: 13, - }, - TestCase { - query: format!("{}select * from", CURSOR_POS), - expected_pos: 0, - expected_sql_len: 13, - }, - TestCase { - query: format!("select {} from", CURSOR_POS), - expected_pos: 7, - expected_sql_len: 12, - }, - ]; - - for case in cases { - let query = InputQuery::from(case.query.as_str()); - assert_eq!(query.position, case.expected_pos); - assert_eq!(query.sql.len(), case.expected_sql_len); - } - } -} - #[derive(Debug, PartialEq, Eq)] pub(crate) enum CompletionAssertion { Label(String), diff --git a/crates/pgt_test_utils/src/lib.rs b/crates/pgt_test_utils/src/lib.rs index e21c6ce4b..11bb1aebe 100644 --- a/crates/pgt_test_utils/src/lib.rs +++ b/crates/pgt_test_utils/src/lib.rs @@ -1 +1,85 @@ +use std::fmt::Display; + pub static MIGRATIONS: sqlx::migrate::Migrator = sqlx::migrate!("./testdb_migrations"); + +static CURSOR_POS: char = '€'; + +#[derive(Clone)] +pub struct QueryWithCursorPosition { + sql: String, + position: usize, +} + +impl QueryWithCursorPosition { + pub fn cursor_marker() -> char { + CURSOR_POS + } + + pub fn get_text_and_position(&self) -> (usize, String) { + (self.position, self.sql.clone()) + } +} + +impl From for QueryWithCursorPosition { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<&str> for QueryWithCursorPosition { + fn from(value: &str) -> Self { + let position = value + .find(CURSOR_POS) + .expect("Use `QueryWithCursorPosition::cursor_marker()` to insert cursor position into your Query."); + + QueryWithCursorPosition { + sql: value.replace(CURSOR_POS, "").trim().to_string(), + position, + } + } +} + +impl Display for QueryWithCursorPosition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.sql) + } +} + +#[cfg(test)] +mod tests { + + use super::QueryWithCursorPosition; + + #[test] + fn input_query_should_extract_correct_position() { + struct TestCase { + query: String, + expected_pos: usize, + expected_sql_len: usize, + } + + let cases = vec![ + TestCase { + query: format!("select * from{}", QueryWithCursorPosition::cursor_marker()), + expected_pos: 13, + expected_sql_len: 13, + }, + TestCase { + query: format!("{}select * from", QueryWithCursorPosition::cursor_marker()), + expected_pos: 0, + expected_sql_len: 13, + }, + TestCase { + query: format!("select {} from", QueryWithCursorPosition::cursor_marker()), + expected_pos: 7, + expected_sql_len: 12, + }, + ]; + + for case in cases { + let query = QueryWithCursorPosition::from(case.query.as_str()); + assert_eq!(query.position, case.expected_pos); + assert_eq!(query.sql.len(), case.expected_sql_len); + } + } +} diff --git a/crates/pgt_treesitter_queries/Cargo.toml b/crates/pgt_treesitter/Cargo.toml similarity index 54% rename from crates/pgt_treesitter_queries/Cargo.toml rename to crates/pgt_treesitter/Cargo.toml index 5806861f5..f2d8b46e1 100644 --- a/crates/pgt_treesitter_queries/Cargo.toml +++ b/crates/pgt_treesitter/Cargo.toml @@ -6,17 +6,20 @@ edition.workspace = true homepage.workspace = true keywords.workspace = true license.workspace = true -name = "pgt_treesitter_queries" +name = "pgt_treesitter" repository.workspace = true version = "0.0.0" [dependencies] -clap = { version = "4.5.23", features = ["derive"] } -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +clap = { version = "4.5.23", features = ["derive"] } +pgt_schema_cache.workspace = true +pgt_text_size.workspace = true +tree-sitter.workspace = true +tree_sitter_sql.workspace = true [dev-dependencies] +pgt_test_utils.workspace = true [lib] doctest = false diff --git a/crates/pgt_completions/src/context/base_parser.rs b/crates/pgt_treesitter/src/context/base_parser.rs similarity index 100% rename from crates/pgt_completions/src/context/base_parser.rs rename to crates/pgt_treesitter/src/context/base_parser.rs diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_treesitter/src/context/grant_parser.rs similarity index 94% rename from crates/pgt_completions/src/context/grant_parser.rs rename to crates/pgt_treesitter/src/context/grant_parser.rs index 14ba882ae..c9aebc33b 100644 --- a/crates/pgt_completions/src/context/grant_parser.rs +++ b/crates/pgt_treesitter/src/context/grant_parser.rs @@ -187,14 +187,15 @@ mod tests { use crate::{ context::base_parser::CompletionStatementParser, context::grant_parser::{GrantContext, GrantParser}, - test_helper::CURSOR_POS, }; + use pgt_test_utils::QueryWithCursorPosition; + fn with_pos(query: String) -> (usize, String) { let mut pos: Option = None; for (p, c) in query.char_indices() { - if c == CURSOR_POS { + if c == QueryWithCursorPosition::cursor_marker() { pos = Some(p); break; } @@ -202,7 +203,9 @@ mod tests { ( pos.expect("Please add cursor position!"), - query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + query + .replace(QueryWithCursorPosition::cursor_marker(), "REPLACED_TOKEN") + .to_string(), ) } @@ -212,7 +215,7 @@ mod tests { r#" grant {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -235,7 +238,7 @@ mod tests { r#" grant select on {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -258,7 +261,7 @@ mod tests { r#" grant select on table {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -281,7 +284,7 @@ mod tests { r#" grant select on public.{} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -304,7 +307,7 @@ mod tests { r#" grant select on table public.{} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -327,7 +330,7 @@ mod tests { r#" grant select on public.users to {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -350,7 +353,7 @@ mod tests { r#" grant select on public.{} to test_role "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -373,7 +376,7 @@ mod tests { r#" grant select on "MySchema"."MyTable" to {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); @@ -396,7 +399,7 @@ mod tests { r#" grant select on public.users to alice, {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = GrantParser::get_context(query.as_str(), pos); diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_treesitter/src/context/mod.rs similarity index 78% rename from crates/pgt_completions/src/context/mod.rs rename to crates/pgt_treesitter/src/context/mod.rs index 01e563b0a..9cfaadea1 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_treesitter/src/context/mod.rs @@ -7,22 +7,14 @@ mod grant_parser; mod policy_parser; mod revoke_parser; -use pgt_schema_cache::SchemaCache; -use pgt_text_size::TextRange; -use pgt_treesitter_queries::{ - TreeSitterQueriesExecutor, - queries::{self, QueryResult}, -}; - -use crate::{ - NodeText, - context::{ - base_parser::CompletionStatementParser, - grant_parser::GrantParser, - policy_parser::{PolicyParser, PolicyStmtKind}, - revoke_parser::RevokeParser, - }, - sanitization::SanitizedCompletionParams, +use crate::queries::{self, QueryResult, TreeSitterQueriesExecutor}; +use pgt_text_size::{TextRange, TextSize}; + +use crate::context::{ + base_parser::CompletionStatementParser, + grant_parser::GrantParser, + policy_parser::{PolicyParser, PolicyStmtKind}, + revoke_parser::RevokeParser, }; #[derive(Debug, PartialEq, Eq, Hash, Clone)] @@ -59,9 +51,9 @@ pub enum WrappingClause<'a> { } #[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub(crate) struct MentionedColumn { - pub(crate) column: String, - pub(crate) alias: Option, +pub struct MentionedColumn { + pub column: String, + pub alias: Option, } /// We can map a few nodes, such as the "update" node, to actual SQL clauses. @@ -81,10 +73,10 @@ pub enum WrappingNode { } #[derive(Debug)] -pub(crate) enum NodeUnderCursor<'a> { +pub enum NodeUnderCursor<'a> { TsNode(tree_sitter::Node<'a>), CustomNode { - text: NodeText, + text: String, range: TextRange, kind: String, previous_node_kind: Option, @@ -150,13 +142,18 @@ impl TryFrom for WrappingNode { } } +pub struct TreeSitterContextParams<'a> { + pub position: TextSize, + pub text: &'a str, + pub tree: &'a tree_sitter::Tree, +} + #[derive(Debug)] -pub(crate) struct CompletionContext<'a> { +pub struct TreesitterContext<'a> { pub node_under_cursor: Option>, pub tree: &'a tree_sitter::Tree, pub text: &'a str, - pub schema_cache: &'a SchemaCache, pub position: usize, /// If the cursor is on a node that uses dot notation @@ -178,6 +175,7 @@ pub(crate) struct CompletionContext<'a> { /// on u.id = i.user_id; /// ``` pub schema_or_alias_name: Option, + pub wrapping_clause_type: Option>, pub wrapping_node_kind: Option, @@ -190,12 +188,11 @@ pub(crate) struct CompletionContext<'a> { pub mentioned_columns: HashMap>, HashSet>, } -impl<'a> CompletionContext<'a> { - pub fn new(params: &'a SanitizedCompletionParams) -> Self { +impl<'a> TreesitterContext<'a> { + pub fn new(params: TreeSitterContextParams<'a>) -> Self { let mut ctx = Self { - tree: params.tree.as_ref(), - text: ¶ms.text, - schema_cache: params.schema, + tree: params.tree, + text: params.text, position: usize::from(params.position), node_under_cursor: None, schema_or_alias_name: None, @@ -211,11 +208,11 @@ impl<'a> CompletionContext<'a> { // policy handling is important to Supabase, but they are a PostgreSQL specific extension, // so the tree_sitter_sql language does not support it. // We infer the context manually. - if PolicyParser::looks_like_matching_stmt(¶ms.text) { + if PolicyParser::looks_like_matching_stmt(params.text) { ctx.gather_policy_context(); - } else if GrantParser::looks_like_matching_stmt(¶ms.text) { + } else if GrantParser::looks_like_matching_stmt(params.text) { ctx.gather_grant_context(); - } else if RevokeParser::looks_like_matching_stmt(¶ms.text) { + } else if RevokeParser::looks_like_matching_stmt(params.text) { ctx.gather_revoke_context(); } else { ctx.gather_tree_context(); @@ -229,7 +226,7 @@ impl<'a> CompletionContext<'a> { let revoke_context = RevokeParser::get_context(self.text, self.position); self.node_under_cursor = Some(NodeUnderCursor::CustomNode { - text: revoke_context.node_text.into(), + text: revoke_context.node_text, range: revoke_context.node_range, kind: revoke_context.node_kind.clone(), previous_node_kind: None, @@ -257,7 +254,7 @@ impl<'a> CompletionContext<'a> { let grant_context = GrantParser::get_context(self.text, self.position); self.node_under_cursor = Some(NodeUnderCursor::CustomNode { - text: grant_context.node_text.into(), + text: grant_context.node_text, range: grant_context.node_range, kind: grant_context.node_kind.clone(), previous_node_kind: None, @@ -285,7 +282,7 @@ impl<'a> CompletionContext<'a> { let policy_context = PolicyParser::get_context(self.text, self.position); self.node_under_cursor = Some(NodeUnderCursor::CustomNode { - text: policy_context.node_text.into(), + text: policy_context.node_text, range: policy_context.node_range, kind: policy_context.node_kind.clone(), previous_node_kind: Some(policy_context.previous_node_kind), @@ -397,29 +394,18 @@ impl<'a> CompletionContext<'a> { } } - fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option { + fn get_ts_node_content(&self, ts_node: &tree_sitter::Node<'a>) -> Option { let source = self.text; - ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { - if SanitizedCompletionParams::is_sanitized_token(txt) { - NodeText::Replaced - } else { - NodeText::Original(txt.into()) - } - }) + ts_node + .utf8_text(source.as_bytes()) + .ok() + .map(|txt| txt.into()) } pub fn get_node_under_cursor_content(&self) -> Option { match self.node_under_cursor.as_ref()? { - NodeUnderCursor::TsNode(node) => { - self.get_ts_node_content(node).and_then(|nt| match nt { - NodeText::Replaced => None, - NodeText::Original(c) => Some(c.to_string()), - }) - } - NodeUnderCursor::CustomNode { text, .. } => match text { - NodeText::Replaced => None, - NodeText::Original(c) => Some(c.to_string()), - }, + NodeUnderCursor::TsNode(node) => self.get_ts_node_content(node), + NodeUnderCursor::CustomNode { text, .. } => Some(text.clone()), } } @@ -501,15 +487,10 @@ impl<'a> CompletionContext<'a> { match current_node_kind { "object_reference" | "field" => { let content = self.get_ts_node_content(¤t_node); - if let Some(node_txt) = content { - match node_txt { - NodeText::Original(txt) => { - let parts: Vec<&str> = txt.split('.').collect(); - if parts.len() == 2 { - self.schema_or_alias_name = Some(parts[0].to_string()); - } - } - NodeText::Replaced => {} + if let Some(txt) = content { + let parts: Vec<&str> = txt.split('.').collect(); + if parts.len() == 2 { + self.schema_or_alias_name = Some(parts[0].to_string()); } } } @@ -638,12 +619,7 @@ impl<'a> CompletionContext<'a> { break; } - if let Some(sibling_content) = - self.get_ts_node_content(&sib).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) - { + if let Some(sibling_content) = self.get_ts_node_content(&sib) { if sibling_content == tokens[idx] { idx += 1; } @@ -674,9 +650,7 @@ impl<'a> CompletionContext<'a> { while let Some(sib) = first_sibling.next_sibling() { match sib.kind() { "object_reference" => { - if let Some(NodeText::Original(txt)) = - self.get_ts_node_content(&sib) - { + if let Some(txt) = self.get_ts_node_content(&sib) { let mut iter = txt.split('.').rev(); let table = iter.next().unwrap().to_string(); let schema = iter.next().map(|s| s.to_string()); @@ -690,9 +664,7 @@ impl<'a> CompletionContext<'a> { } "column" => { - if let Some(NodeText::Original(txt)) = - self.get_ts_node_content(&sib) - { + if let Some(txt) = self.get_ts_node_content(&sib) { let entry = MentionedColumn { column: txt, alias: None, @@ -717,7 +689,7 @@ impl<'a> CompletionContext<'a> { WrappingClause::AlterColumn => { while let Some(sib) = first_sibling.next_sibling() { if sib.kind() == "object_reference" { - if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { + if let Some(txt) = self.get_ts_node_content(&sib) { let mut iter = txt.split('.').rev(); let table = iter.next().unwrap().to_string(); let schema = iter.next().map(|s| s.to_string()); @@ -777,7 +749,7 @@ impl<'a> CompletionContext<'a> { } } - pub(crate) fn parent_matches_one_of_kind(&self, kinds: &[&'static str]) -> bool { + pub fn parent_matches_one_of_kind(&self, kinds: &[&'static str]) -> bool { self.node_under_cursor .as_ref() .is_some_and(|under_cursor| match under_cursor { @@ -788,7 +760,7 @@ impl<'a> CompletionContext<'a> { NodeUnderCursor::CustomNode { .. } => false, }) } - pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool { + pub fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool { self.node_under_cursor.as_ref().is_some_and(|under_cursor| { match under_cursor { NodeUnderCursor::TsNode(node) => { @@ -816,12 +788,9 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { - use crate::{ - NodeText, - context::{CompletionContext, WrappingClause}, - sanitization::SanitizedCompletionParams, - test_helper::{CURSOR_POS, get_text_and_position}, - }; + use crate::context::{TreeSitterContextParams, TreesitterContext, WrappingClause}; + + use pgt_test_utils::QueryWithCursorPosition; use super::NodeUnderCursor; @@ -838,56 +807,82 @@ mod tests { fn identifies_clauses() { let test_cases = vec![ ( - format!("Select {}* from users;", CURSOR_POS), + format!( + "Select {}* from users;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::Select, ), ( - format!("Select * from u{};", CURSOR_POS), + format!( + "Select * from u{};", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::From, ), ( - format!("Select {}* from users where n = 1;", CURSOR_POS), + format!( + "Select {}* from users where n = 1;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::Select, ), ( - format!("Select * from users where {}n = 1;", CURSOR_POS), + format!( + "Select * from users where {}n = 1;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::Where, ), ( - format!("update users set u{} = 1 where n = 2;", CURSOR_POS), + format!( + "update users set u{} = 1 where n = 2;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::Update, ), ( - format!("update users set u = 1 where n{} = 2;", CURSOR_POS), + format!( + "update users set u = 1 where n{} = 2;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::Where, ), ( - format!("delete{} from users;", CURSOR_POS), + format!( + "delete{} from users;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::Delete, ), ( - format!("delete from {}users;", CURSOR_POS), + format!( + "delete from {}users;", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::From, ), ( - format!("select name, age, location from public.u{}sers", CURSOR_POS), + format!( + "select name, age, location from public.u{}sers", + QueryWithCursorPosition::cursor_marker() + ), WrappingClause::From, ), ]; for (query, expected_clause) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: &text, + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); assert_eq!(ctx.wrapping_clause_type, Some(expected_clause)); } @@ -897,29 +892,46 @@ mod tests { fn identifies_schema() { let test_cases = vec![ ( - format!("Select * from private.u{}", CURSOR_POS), + format!( + "Select * from private.u{}", + QueryWithCursorPosition::cursor_marker() + ), Some("private"), ), ( - format!("Select * from private.u{}sers()", CURSOR_POS), + format!( + "Select * from private.u{}sers()", + QueryWithCursorPosition::cursor_marker() + ), Some("private"), ), - (format!("Select * from u{}sers", CURSOR_POS), None), - (format!("Select * from u{}sers()", CURSOR_POS), None), + ( + format!( + "Select * from u{}sers", + QueryWithCursorPosition::cursor_marker() + ), + None, + ), + ( + format!( + "Select * from u{}sers()", + QueryWithCursorPosition::cursor_marker() + ), + None, + ), ]; for (query, expected_schema) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: &text, + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); assert_eq!( ctx.schema_or_alias_name, @@ -931,32 +943,55 @@ mod tests { #[test] fn identifies_invocation() { let test_cases = vec![ - (format!("Select * from u{}sers", CURSOR_POS), false), - (format!("Select * from u{}sers()", CURSOR_POS), true), - (format!("Select cool{};", CURSOR_POS), false), - (format!("Select cool{}();", CURSOR_POS), true), ( - format!("Select upp{}ercase as title from users;", CURSOR_POS), + format!( + "Select * from u{}sers", + QueryWithCursorPosition::cursor_marker() + ), + false, + ), + ( + format!( + "Select * from u{}sers()", + QueryWithCursorPosition::cursor_marker() + ), + true, + ), + ( + format!("Select cool{};", QueryWithCursorPosition::cursor_marker()), + false, + ), + ( + format!("Select cool{}();", QueryWithCursorPosition::cursor_marker()), + true, + ), + ( + format!( + "Select upp{}ercase as title from users;", + QueryWithCursorPosition::cursor_marker() + ), false, ), ( - format!("Select upp{}ercase(name) as title from users;", CURSOR_POS), + format!( + "Select upp{}ercase(name) as title from users;", + QueryWithCursorPosition::cursor_marker() + ), true, ), ]; for (query, is_invocation) in test_cases { - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: text.as_str(), + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); assert_eq!(ctx.is_invocation, is_invocation); } @@ -965,32 +1000,34 @@ mod tests { #[test] fn does_not_fail_on_leading_whitespace() { let cases = vec![ - format!("{} select * from", CURSOR_POS), - format!(" {} select * from", CURSOR_POS), + format!( + "{} select * from", + QueryWithCursorPosition::cursor_marker() + ), + format!( + " {} select * from", + QueryWithCursorPosition::cursor_marker() + ), ]; for query in cases { - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: &text, + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); let node = ctx.node_under_cursor.as_ref().unwrap(); match node { NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("select".into())) - ); + assert_eq!(ctx.get_ts_node_content(node), Some("select".into())); assert_eq!( ctx.wrapping_clause_type, @@ -1004,29 +1041,28 @@ mod tests { #[test] fn does_not_fail_on_trailing_whitespace() { - let query = format!("select * from {}", CURSOR_POS); + let query = format!( + "select * from {}", + QueryWithCursorPosition::cursor_marker() + ); - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: &text, + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); let node = ctx.node_under_cursor.as_ref().unwrap(); match node { NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("from".into())) - ); + assert_eq!(ctx.get_ts_node_content(node), Some("from".into())); } _ => unreachable!(), } @@ -1034,29 +1070,25 @@ mod tests { #[test] fn does_not_fail_with_empty_statements() { - let query = format!("{}", CURSOR_POS); + let query = format!("{}", QueryWithCursorPosition::cursor_marker()); - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: &text, + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); let node = ctx.node_under_cursor.as_ref().unwrap(); match node { NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("".into())) - ); + assert_eq!(ctx.get_ts_node_content(node), Some("".into())); assert_eq!(ctx.wrapping_clause_type, None); } _ => unreachable!(), @@ -1067,29 +1099,25 @@ mod tests { fn does_not_fail_on_incomplete_keywords() { // Instead of autocompleting "FROM", we'll assume that the user // is selecting a certain column name, such as `frozen_account`. - let query = format!("select * fro{}", CURSOR_POS); + let query = format!("select * fro{}", QueryWithCursorPosition::cursor_marker()); - let (position, text) = get_text_and_position(query.as_str().into()); + let (position, text) = QueryWithCursorPosition::from(query).get_text_and_position(); let tree = get_tree(text.as_str()); - let params = SanitizedCompletionParams { + let params = TreeSitterContextParams { position: (position as u32).into(), - text, - tree: std::borrow::Cow::Owned(tree), - schema: &pgt_schema_cache::SchemaCache::default(), + text: &text, + tree: &tree, }; - let ctx = CompletionContext::new(¶ms); + let ctx = TreesitterContext::new(params); let node = ctx.node_under_cursor.as_ref().unwrap(); match node { NodeUnderCursor::TsNode(node) => { - assert_eq!( - ctx.get_ts_node_content(node), - Some(NodeText::Original("fro".into())) - ); + assert_eq!(ctx.get_ts_node_content(node), Some("fro".into())); assert_eq!(ctx.wrapping_clause_type, Some(WrappingClause::Select)); } _ => unreachable!(), diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_treesitter/src/context/policy_parser.rs similarity index 95% rename from crates/pgt_completions/src/context/policy_parser.rs rename to crates/pgt_treesitter/src/context/policy_parser.rs index bcc604990..776645163 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_treesitter/src/context/policy_parser.rs @@ -212,16 +212,17 @@ mod tests { use crate::{ context::base_parser::CompletionStatementParser, context::policy_parser::{PolicyContext, PolicyStmtKind}, - test_helper::CURSOR_POS, }; + use pgt_test_utils::QueryWithCursorPosition; + use super::PolicyParser; fn with_pos(query: String) -> (usize, String) { let mut pos: Option = None; for (p, c) in query.char_indices() { - if c == CURSOR_POS { + if c == QueryWithCursorPosition::cursor_marker() { pos = Some(p); break; } @@ -229,7 +230,9 @@ mod tests { ( pos.expect("Please add cursor position!"), - query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + query + .replace(QueryWithCursorPosition::cursor_marker(), "REPLACED_TOKEN") + .to_string(), ) } @@ -239,7 +242,7 @@ mod tests { r#" create policy {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -265,7 +268,7 @@ mod tests { r#" create policy "my cool policy" {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -291,7 +294,7 @@ mod tests { r#" create policy "my cool policy" on {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -317,7 +320,7 @@ mod tests { r#" create policy "my cool policy" on auth.{} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -344,7 +347,7 @@ mod tests { create policy "my cool policy" on auth.users as {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -372,7 +375,7 @@ mod tests { as permissive {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -400,7 +403,7 @@ mod tests { as permissive to {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -432,7 +435,7 @@ mod tests { to all using (true); "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -464,7 +467,7 @@ mod tests { to all using (true); "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -493,7 +496,7 @@ mod tests { r#" drop policy {} on auth.users; "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -520,7 +523,7 @@ mod tests { r#" drop policy "{}" on auth.users; "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -549,7 +552,7 @@ mod tests { r#" drop policy "{} on auth.users; "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -567,7 +570,7 @@ mod tests { to all using (id = {}) "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -598,7 +601,7 @@ mod tests { to all using ({} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); @@ -629,7 +632,7 @@ mod tests { to all with check ({} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = PolicyParser::get_context(query.as_str(), pos); diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_treesitter/src/context/revoke_parser.rs similarity index 94% rename from crates/pgt_completions/src/context/revoke_parser.rs rename to crates/pgt_treesitter/src/context/revoke_parser.rs index e0c43934c..4f5b09ec8 100644 --- a/crates/pgt_completions/src/context/revoke_parser.rs +++ b/crates/pgt_treesitter/src/context/revoke_parser.rs @@ -180,14 +180,15 @@ mod tests { use crate::{ context::base_parser::CompletionStatementParser, context::revoke_parser::{RevokeContext, RevokeParser}, - test_helper::CURSOR_POS, }; + use pgt_test_utils::QueryWithCursorPosition; + fn with_pos(query: String) -> (usize, String) { let mut pos: Option = None; for (p, c) in query.char_indices() { - if c == CURSOR_POS { + if c == QueryWithCursorPosition::cursor_marker() { pos = Some(p); break; } @@ -195,7 +196,9 @@ mod tests { ( pos.expect("Please add cursor position!"), - query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + query + .replace(QueryWithCursorPosition::cursor_marker(), "REPLACED_TOKEN") + .to_string(), ) } @@ -205,7 +208,7 @@ mod tests { r#" revoke {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = RevokeParser::get_context(query.as_str(), pos); @@ -228,7 +231,7 @@ mod tests { r#" revoke select on {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = RevokeParser::get_context(query.as_str(), pos); @@ -251,7 +254,7 @@ mod tests { r#" revoke select on public.{} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = RevokeParser::get_context(query.as_str(), pos); @@ -274,7 +277,7 @@ mod tests { r#" revoke select on public.users from {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = RevokeParser::get_context(query.as_str(), pos); @@ -297,7 +300,7 @@ mod tests { r#" revoke select on public.users from alice, {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = RevokeParser::get_context(query.as_str(), pos); @@ -320,7 +323,7 @@ mod tests { r#" revoke select on "MySchema"."MyTable" from {} "#, - CURSOR_POS + QueryWithCursorPosition::cursor_marker() )); let context = RevokeParser::get_context(query.as_str(), pos); diff --git a/crates/pgt_treesitter/src/lib.rs b/crates/pgt_treesitter/src/lib.rs new file mode 100644 index 000000000..6b19db53a --- /dev/null +++ b/crates/pgt_treesitter/src/lib.rs @@ -0,0 +1,5 @@ +pub mod context; +pub mod queries; + +pub use context::*; +pub use queries::*; diff --git a/crates/pgt_treesitter_queries/src/queries/insert_columns.rs b/crates/pgt_treesitter/src/queries/insert_columns.rs similarity index 97% rename from crates/pgt_treesitter_queries/src/queries/insert_columns.rs rename to crates/pgt_treesitter/src/queries/insert_columns.rs index 3e88d998f..94d67b690 100644 --- a/crates/pgt_treesitter_queries/src/queries/insert_columns.rs +++ b/crates/pgt_treesitter/src/queries/insert_columns.rs @@ -1,6 +1,6 @@ use std::sync::LazyLock; -use crate::{Query, QueryResult}; +use crate::queries::{Query, QueryResult}; use super::QueryTryFrom; @@ -51,7 +51,7 @@ impl<'a> QueryTryFrom<'a> for InsertColumnMatch<'a> { } impl<'a> Query<'a> for InsertColumnMatch<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { let mut cursor = tree_sitter::QueryCursor::new(); let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); @@ -73,7 +73,7 @@ impl<'a> Query<'a> for InsertColumnMatch<'a> { #[cfg(test)] mod tests { use super::InsertColumnMatch; - use crate::TreeSitterQueriesExecutor; + use crate::queries::TreeSitterQueriesExecutor; #[test] fn finds_all_insert_columns() { diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter/src/queries/mod.rs similarity index 72% rename from crates/pgt_treesitter_queries/src/lib.rs rename to crates/pgt_treesitter/src/queries/mod.rs index 4bf71e744..1d24f07a4 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter/src/queries/mod.rs @@ -1,8 +1,91 @@ -pub mod queries; +mod insert_columns; +mod parameters; +mod relations; +mod select_columns; +mod table_aliases; +mod where_columns; use std::slice::Iter; -use queries::{Query, QueryResult}; +pub use insert_columns::*; +pub use parameters::*; +pub use relations::*; +pub use select_columns::*; +pub use table_aliases::*; +pub use where_columns::*; + +#[derive(Debug)] +pub enum QueryResult<'a> { + Relation(RelationMatch<'a>), + Parameter(ParameterMatch<'a>), + TableAliases(TableAliasMatch<'a>), + SelectClauseColumns(SelectColumnMatch<'a>), + InsertClauseColumns(InsertColumnMatch<'a>), + WhereClauseColumns(WhereColumnMatch<'a>), +} + +impl QueryResult<'_> { + pub fn within_range(&self, range: &tree_sitter::Range) -> bool { + match self { + QueryResult::Relation(rm) => { + let start = match rm.schema { + Some(s) => s.start_position(), + None => rm.table.start_position(), + }; + + let end = rm.table.end_position(); + + start >= range.start_point && end <= range.end_point + } + Self::Parameter(pm) => { + let node_range = pm.node.range(); + + node_range.start_point >= range.start_point + && node_range.end_point <= range.end_point + } + QueryResult::TableAliases(m) => { + let start = m.table.start_position(); + let end = m.alias.end_position(); + start >= range.start_point && end <= range.end_point + } + Self::SelectClauseColumns(cm) => { + let start = match cm.alias { + Some(n) => n.start_position(), + None => cm.column.start_position(), + }; + + let end = cm.column.end_position(); + + start >= range.start_point && end <= range.end_point + } + Self::WhereClauseColumns(cm) => { + let start = match cm.alias { + Some(n) => n.start_position(), + None => cm.column.start_position(), + }; + + let end = cm.column.end_position(); + + start >= range.start_point && end <= range.end_point + } + Self::InsertClauseColumns(cm) => { + let start = cm.column.start_position(); + let end = cm.column.end_position(); + start >= range.start_point && end <= range.end_point + } + } + } +} + +// This trait enforces that for any `Self` that implements `Query`, +// its &Self must implement TryFrom<&QueryResult> +pub(crate) trait QueryTryFrom<'a>: Sized { + type Ref: for<'any> TryFrom<&'a QueryResult<'a>, Error = String>; +} + +pub(crate) trait Query<'a>: QueryTryFrom<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec>; +} pub struct TreeSitterQueriesExecutor<'a> { root_node: tree_sitter::Node<'a>, @@ -68,9 +151,8 @@ impl<'a> Iterator for QueryResultIter<'a> { #[cfg(test)] mod tests { - use crate::{ - TreeSitterQueriesExecutor, - queries::{ParameterMatch, RelationMatch, TableAliasMatch}, + use crate::queries::{ + ParameterMatch, RelationMatch, TableAliasMatch, TreeSitterQueriesExecutor, }; #[test] diff --git a/crates/pgt_treesitter_queries/src/queries/parameters.rs b/crates/pgt_treesitter/src/queries/parameters.rs similarity index 96% rename from crates/pgt_treesitter_queries/src/queries/parameters.rs rename to crates/pgt_treesitter/src/queries/parameters.rs index 85ea9ad25..0b7f2e3df 100644 --- a/crates/pgt_treesitter_queries/src/queries/parameters.rs +++ b/crates/pgt_treesitter/src/queries/parameters.rs @@ -1,6 +1,6 @@ use std::sync::LazyLock; -use crate::{Query, QueryResult}; +use crate::queries::{Query, QueryResult}; use super::QueryTryFrom; @@ -59,7 +59,7 @@ impl<'a> QueryTryFrom<'a> for ParameterMatch<'a> { } impl<'a> Query<'a> for ParameterMatch<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { let mut cursor = tree_sitter::QueryCursor::new(); let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); diff --git a/crates/pgt_treesitter_queries/src/queries/relations.rs b/crates/pgt_treesitter/src/queries/relations.rs similarity index 98% rename from crates/pgt_treesitter_queries/src/queries/relations.rs rename to crates/pgt_treesitter/src/queries/relations.rs index 2d7e44317..cb6a6bea9 100644 --- a/crates/pgt_treesitter_queries/src/queries/relations.rs +++ b/crates/pgt_treesitter/src/queries/relations.rs @@ -1,6 +1,6 @@ use std::sync::LazyLock; -use crate::{Query, QueryResult}; +use crate::queries::{Query, QueryResult}; use super::QueryTryFrom; @@ -79,7 +79,7 @@ impl<'a> QueryTryFrom<'a> for RelationMatch<'a> { } impl<'a> Query<'a> for RelationMatch<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { let mut cursor = tree_sitter::QueryCursor::new(); let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); @@ -112,8 +112,9 @@ impl<'a> Query<'a> for RelationMatch<'a> { #[cfg(test)] mod tests { + use crate::queries::TreeSitterQueriesExecutor; + use super::RelationMatch; - use crate::TreeSitterQueriesExecutor; #[test] fn finds_table_without_schema() { diff --git a/crates/pgt_treesitter_queries/src/queries/select_columns.rs b/crates/pgt_treesitter/src/queries/select_columns.rs similarity index 97% rename from crates/pgt_treesitter_queries/src/queries/select_columns.rs rename to crates/pgt_treesitter/src/queries/select_columns.rs index 00b6977d0..f232abc38 100644 --- a/crates/pgt_treesitter_queries/src/queries/select_columns.rs +++ b/crates/pgt_treesitter/src/queries/select_columns.rs @@ -1,6 +1,6 @@ use std::sync::LazyLock; -use crate::{Query, QueryResult}; +use crate::queries::{Query, QueryResult}; use super::QueryTryFrom; @@ -63,7 +63,7 @@ impl<'a> QueryTryFrom<'a> for SelectColumnMatch<'a> { } impl<'a> Query<'a> for SelectColumnMatch<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { let mut cursor = tree_sitter::QueryCursor::new(); let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); @@ -96,7 +96,7 @@ impl<'a> Query<'a> for SelectColumnMatch<'a> { #[cfg(test)] mod tests { - use crate::TreeSitterQueriesExecutor; + use crate::queries::TreeSitterQueriesExecutor; use super::SelectColumnMatch; diff --git a/crates/pgt_treesitter_queries/src/queries/table_aliases.rs b/crates/pgt_treesitter/src/queries/table_aliases.rs similarity index 97% rename from crates/pgt_treesitter_queries/src/queries/table_aliases.rs rename to crates/pgt_treesitter/src/queries/table_aliases.rs index 4297a2186..70d4d52ef 100644 --- a/crates/pgt_treesitter_queries/src/queries/table_aliases.rs +++ b/crates/pgt_treesitter/src/queries/table_aliases.rs @@ -1,6 +1,6 @@ use std::sync::LazyLock; -use crate::{Query, QueryResult}; +use crate::queries::{Query, QueryResult}; use super::QueryTryFrom; @@ -69,7 +69,7 @@ impl<'a> QueryTryFrom<'a> for TableAliasMatch<'a> { } impl<'a> Query<'a> for TableAliasMatch<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { let mut cursor = tree_sitter::QueryCursor::new(); let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); diff --git a/crates/pgt_treesitter_queries/src/queries/where_columns.rs b/crates/pgt_treesitter/src/queries/where_columns.rs similarity index 97% rename from crates/pgt_treesitter_queries/src/queries/where_columns.rs rename to crates/pgt_treesitter/src/queries/where_columns.rs index 8e19590de..b683300b6 100644 --- a/crates/pgt_treesitter_queries/src/queries/where_columns.rs +++ b/crates/pgt_treesitter/src/queries/where_columns.rs @@ -1,6 +1,6 @@ use std::sync::LazyLock; -use crate::{Query, QueryResult}; +use crate::queries::{Query, QueryResult}; use super::QueryTryFrom; @@ -64,7 +64,7 @@ impl<'a> QueryTryFrom<'a> for WhereColumnMatch<'a> { } impl<'a> Query<'a> for WhereColumnMatch<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { let mut cursor = tree_sitter::QueryCursor::new(); let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs deleted file mode 100644 index b9f39aed8..000000000 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ /dev/null @@ -1,86 +0,0 @@ -mod insert_columns; -mod parameters; -mod relations; -mod select_columns; -mod table_aliases; -mod where_columns; - -pub use insert_columns::*; -pub use parameters::*; -pub use relations::*; -pub use select_columns::*; -pub use table_aliases::*; -pub use where_columns::*; - -#[derive(Debug)] -pub enum QueryResult<'a> { - Relation(RelationMatch<'a>), - Parameter(ParameterMatch<'a>), - TableAliases(TableAliasMatch<'a>), - SelectClauseColumns(SelectColumnMatch<'a>), - InsertClauseColumns(InsertColumnMatch<'a>), - WhereClauseColumns(WhereColumnMatch<'a>), -} - -impl QueryResult<'_> { - pub fn within_range(&self, range: &tree_sitter::Range) -> bool { - match self { - QueryResult::Relation(rm) => { - let start = match rm.schema { - Some(s) => s.start_position(), - None => rm.table.start_position(), - }; - - let end = rm.table.end_position(); - - start >= range.start_point && end <= range.end_point - } - Self::Parameter(pm) => { - let node_range = pm.node.range(); - - node_range.start_point >= range.start_point - && node_range.end_point <= range.end_point - } - QueryResult::TableAliases(m) => { - let start = m.table.start_position(); - let end = m.alias.end_position(); - start >= range.start_point && end <= range.end_point - } - Self::SelectClauseColumns(cm) => { - let start = match cm.alias { - Some(n) => n.start_position(), - None => cm.column.start_position(), - }; - - let end = cm.column.end_position(); - - start >= range.start_point && end <= range.end_point - } - Self::WhereClauseColumns(cm) => { - let start = match cm.alias { - Some(n) => n.start_position(), - None => cm.column.start_position(), - }; - - let end = cm.column.end_position(); - - start >= range.start_point && end <= range.end_point - } - Self::InsertClauseColumns(cm) => { - let start = cm.column.start_position(); - let end = cm.column.end_position(); - start >= range.start_point && end <= range.end_point - } - } - } -} - -// This trait enforces that for any `Self` that implements `Query`, -// its &Self must implement TryFrom<&QueryResult> -pub(crate) trait QueryTryFrom<'a>: Sized { - type Ref: for<'any> TryFrom<&'a QueryResult<'a>, Error = String>; -} - -pub(crate) trait Query<'a>: QueryTryFrom<'a> { - fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec>; -} diff --git a/crates/pgt_typecheck/Cargo.toml b/crates/pgt_typecheck/Cargo.toml index caacc6d17..175ecd596 100644 --- a/crates/pgt_typecheck/Cargo.toml +++ b/crates/pgt_typecheck/Cargo.toml @@ -12,16 +12,16 @@ version = "0.0.0" [dependencies] -pgt_console.workspace = true -pgt_diagnostics.workspace = true -pgt_query_ext.workspace = true -pgt_schema_cache.workspace = true -pgt_text_size.workspace = true -pgt_treesitter_queries.workspace = true -sqlx.workspace = true -tokio.workspace = true -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +pgt_console.workspace = true +pgt_diagnostics.workspace = true +pgt_query_ext.workspace = true +pgt_schema_cache.workspace = true +pgt_text_size.workspace = true +pgt_treesitter.workspace = true +sqlx.workspace = true +tokio.workspace = true +tree-sitter.workspace = true +tree_sitter_sql.workspace = true [dev-dependencies] insta.workspace = true diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 710b2fe98..1ee4095dc 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,5 +1,5 @@ use pgt_schema_cache::PostgresType; -use pgt_treesitter_queries::{TreeSitterQueriesExecutor, queries::ParameterMatch}; +use pgt_treesitter::queries::{ParameterMatch, TreeSitterQueriesExecutor}; /// A typed identifier is a parameter that has a type associated with it. /// It is used to replace parameters within the SQL string. diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index c6f05c6e2..a41dd06eb 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -82,17 +82,17 @@ mod tests { use super::get_statement_for_completions; - static CURSOR_POSITION: &str = "€"; + use pgt_test_utils::QueryWithCursorPosition; fn get_doc_and_pos(sql: &str) -> (Document, TextSize) { let pos = sql - .find(CURSOR_POSITION) + .find(QueryWithCursorPosition::cursor_marker()) .expect("Please add cursor position to test sql"); let pos: u32 = pos.try_into().unwrap(); ( - Document::new(sql.replace(CURSOR_POSITION, ""), 5), + Document::new(sql.replace(QueryWithCursorPosition::cursor_marker(), ""), 5), TextSize::new(pos), ) } @@ -107,7 +107,7 @@ mod tests { select 1; "#, - CURSOR_POSITION + QueryWithCursorPosition::cursor_marker() ); let (doc, position) = get_doc_and_pos(sql.as_str()); @@ -120,7 +120,7 @@ mod tests { #[test] fn does_not_break_when_no_statements_exist() { - let sql = CURSOR_POSITION.to_string(); + let sql = QueryWithCursorPosition::cursor_marker().to_string(); let (doc, position) = get_doc_and_pos(sql.as_str()); @@ -129,7 +129,10 @@ mod tests { #[test] fn does_not_return_overlapping_statements_if_too_close() { - let sql = format!("select * from {}select 1;", CURSOR_POSITION); + let sql = format!( + "select * from {}select 1;", + QueryWithCursorPosition::cursor_marker() + ); let (doc, position) = get_doc_and_pos(sql.as_str()); @@ -141,7 +144,10 @@ mod tests { #[test] fn is_fine_with_spaces() { - let sql = format!("select * from {} ;", CURSOR_POSITION); + let sql = format!( + "select * from {} ;", + QueryWithCursorPosition::cursor_marker() + ); let (doc, position) = get_doc_and_pos(sql.as_str()); @@ -153,7 +159,7 @@ mod tests { #[test] fn considers_offset() { - let sql = format!("select * from {}", CURSOR_POSITION); + let sql = format!("select * from {}", QueryWithCursorPosition::cursor_marker()); let (doc, position) = get_doc_and_pos(sql.as_str()); @@ -174,7 +180,7 @@ mod tests { select {} from cool; $$; "#, - CURSOR_POSITION + QueryWithCursorPosition::cursor_marker() ); let sql = sql.trim(); @@ -189,7 +195,10 @@ mod tests { #[test] fn does_not_consider_too_far_offset() { - let sql = format!("select * from {}", CURSOR_POSITION); + let sql = format!( + "select * from {}", + QueryWithCursorPosition::cursor_marker() + ); let (doc, position) = get_doc_and_pos(sql.as_str()); @@ -198,7 +207,10 @@ mod tests { #[test] fn does_not_consider_offset_if_statement_terminated_by_semi() { - let sql = format!("select * from users;{}", CURSOR_POSITION); + let sql = format!( + "select * from users;{}", + QueryWithCursorPosition::cursor_marker() + ); let (doc, position) = get_doc_and_pos(sql.as_str());