Skip to content

Commit 4eaeb51

Browse files
committed
store: Correctly handle quoted CTE names
1 parent 55b9465 commit 4eaeb51

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

store/postgres/src/sql/parser_tests.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,7 @@
127127
ok: SELECT * FROM (SELECT "id", "timestamp", "sum" FROM "sgd0815"."stats_hour" WHERE block$ <= 2147483647) AS sh
128128
- name: nested query with CTE
129129
sql: select *, (with pg_user as (select 1) select 1) as one from pg_user
130-
err: Unknown table pg_user
130+
err: Unknown table pg_user
131+
- name: Quoted name in CTE
132+
sql: WITH "PG_USER" AS (SELECT 1) SELECT * FROM pg_user;
133+
err: Unknown table pg_user

store/postgres/src/sql/validation.rs

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use sqlparser::ast::{
66
ValueWithSpan, VisitMut, VisitorMut,
77
};
88
use sqlparser::parser::Parser;
9+
use std::fmt::Display;
910
use std::result::Result;
1011
use std::{collections::HashSet, ops::ControlFlow};
1112

@@ -40,14 +41,47 @@ pub enum Error {
4041
InternalError(String),
4142
}
4243

44+
/// A wrapper around table names that correctly handles quoted vs unquoted
45+
/// comparisons of names
///
/// The wrapped string is the normalized form produced by the
/// `From<&Ident>` impl: quoted names are kept verbatim, unquoted names are
/// lowercased. The derived `PartialEq`/`Eq`/`Hash` operate on that
/// normalized string, so a quoted `"PG_USER"` and an unquoted `pg_user` are
/// different names.
46+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
47+
struct TableName(String);
48+
49+
impl TableName {
50+
/// Borrows the normalized name as a string slice.
fn as_str(&self) -> &str {
51+
&self.0
52+
}
53+
}
54+
55+
/// Builds a `TableName` from a parsed identifier. Only the identifier's
/// text and quote style are used; its source span is ignored.
impl From<&Ident> for TableName {
56+
fn from(ident: &Ident) -> Self {
57+
let Ident {
58+
value,
59+
quote_style,
60+
span: _,
61+
} = ident;
62+
// Use quoted names verbatim, and normalize unquoted names to
63+
// lowercase so that unquoted references compare case-insensitively
64+
match quote_style {
65+
Some(_) => Self(value.clone()),
66+
None => Self(value.to_lowercase()),
67+
}
68+
}
69+
}
70+
71+
/// Displays the normalized name exactly as stored, without adding quotes.
impl Display for TableName {
72+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73+
write!(f, "{}", self.0)
74+
}
75+
}
76+
4377
/// Helper to track CTEs introduced by the main query or subqueries. Every
4478
/// time we enter a query, we need to track a new set of CTEs which must be
4579
/// discarded once we are done with that query. Otherwise, we might allow
4680
/// access to forbidden tables with a query like `select *, (with pg_user as
4781
/// (select 1) select 1) as one from pg_user`
4882
#[derive(Default)]
4983
struct CteStack {
50-
stack: Vec<HashSet<String>>,
84+
stack: Vec<HashSet<TableName>>,
5185
}
5286

5387
impl CteStack {
@@ -59,9 +93,9 @@ impl CteStack {
5993
self.stack.pop();
6094
}
6195

62-
fn contains(&self, name: &str) -> bool {
96+
fn contains(&self, name: &TableName) -> bool {
6397
for entry in self.stack.iter().rev() {
64-
if entry.contains(&name.to_lowercase()) {
98+
if entry.contains(name) {
6599
return true;
66100
}
67101
}
@@ -77,7 +111,7 @@ impl CteStack {
77111
return ControlFlow::Break(Error::InternalError("CTE stack is empty".into()));
78112
};
79113
for cte in ctes {
80-
entry.insert(cte.alias.name.value.to_lowercase());
114+
entry.insert(TableName::from(&cte.alias.name));
81115
}
82116
ControlFlow::Continue(())
83117
}
@@ -254,20 +288,20 @@ impl VisitorMut for Validator<'_> {
254288
return ControlFlow::Break(Error::NoQualifiedTables(name.to_string()));
255289
}
256290
let table_name = match &name.0[0] {
257-
ObjectNamePart::Identifier(ident) => &ident.value,
291+
ObjectNamePart::Identifier(ident) => TableName::from(ident),
258292
ObjectNamePart::Function(_) => {
259293
return ControlFlow::Break(Error::NoQualifiedTables(name.to_string()));
260294
}
261295
};
262296

263297
// CTEs override subgraph tables
264-
if self.ctes.contains(&table_name.to_lowercase()) && args.is_none() {
298+
if self.ctes.contains(&table_name) && args.is_none() {
265299
return ControlFlow::Continue(());
266300
}
267301

268-
let table = match (self.layout.table(table_name), args) {
302+
let table = match (self.layout.table(table_name.as_str()), args) {
269303
(None, None) => {
270-
return ControlFlow::Break(Error::UnknownTable(table_name.clone()));
304+
return ControlFlow::Break(Error::UnknownTable(table_name.to_string()));
271305
}
272306
(Some(_), Some(_)) => {
273307
// Table exists but has args, must be a function
@@ -278,7 +312,7 @@ impl VisitorMut for Validator<'_> {
278312
// aggregation table in the form <name>(<interval>) or
279313
// must be a function
280314

281-
if !self.layout.has_aggregation(table_name) {
315+
if !self.layout.has_aggregation(table_name.as_str()) {
282316
// Not an aggregation, must be a function
283317
return self.validate_function_name(&name);
284318
}
@@ -287,23 +321,24 @@ impl VisitorMut for Validator<'_> {
287321
if settings.is_some() {
288322
// We do not support settings on aggregation tables
289323
return ControlFlow::Break(Error::InvalidAggregationSyntax(
290-
table_name.clone(),
324+
table_name.to_string(),
291325
));
292326
}
293327
let Some(intv) = extract_string_arg(args) else {
294328
// Looks like an aggregation, but argument is not a single string
295329
return ControlFlow::Break(Error::InvalidAggregationSyntax(
296-
table_name.clone(),
330+
table_name.to_string(),
297331
));
298332
};
299333
let Some(intv) = intv.parse::<AggregationInterval>().ok() else {
300334
return ControlFlow::Break(Error::UnknownAggregationInterval(
301-
table_name.clone(),
335+
table_name.to_string(),
302336
intv,
303337
));
304338
};
305339

306-
let Some(table) = self.layout.aggregation_table(table_name, intv) else {
340+
let Some(table) = self.layout.aggregation_table(table_name.as_str(), intv)
341+
else {
307342
return self.validate_function_name(&name);
308343
};
309344
table
@@ -312,7 +347,7 @@ impl VisitorMut for Validator<'_> {
312347
if !table.object.is_object_type() {
313348
// Interfaces and aggregations can not be queried
314349
// with the table name directly
315-
return ControlFlow::Break(Error::UnknownTable(table_name.clone()));
350+
return ControlFlow::Break(Error::UnknownTable(table_name.to_string()));
316351
}
317352
table
318353
}

0 commit comments

Comments
 (0)