@@ -6,6 +6,7 @@ use sqlparser::ast::{
66 ValueWithSpan , VisitMut , VisitorMut ,
77} ;
88use sqlparser:: parser:: Parser ;
9+ use std:: fmt:: Display ;
910use std:: result:: Result ;
1011use std:: { collections:: HashSet , ops:: ControlFlow } ;
1112
@@ -40,14 +41,47 @@ pub enum Error {
4041 InternalError ( String ) ,
4142}
4243
44+ /// A wrapper around table names that correctly handles quoted vs unquoted
45+ /// comparisons of names
46+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
47+ struct TableName ( String ) ;
48+
49+ impl TableName {
50+ fn as_str ( & self ) -> & str {
51+ & self . 0
52+ }
53+ }
54+
55+ impl From < & Ident > for TableName {
56+ fn from ( ident : & Ident ) -> Self {
57+ let Ident {
58+ value,
59+ quote_style,
60+ span : _,
61+ } = ident;
62+ // Use quoted names verbatim, and normalize unquoted names to
63+ // lowercase
64+ match quote_style {
65+ Some ( _) => Self ( value. clone ( ) ) ,
66+ None => Self ( value. to_lowercase ( ) ) ,
67+ }
68+ }
69+ }
70+
71+ impl Display for TableName {
72+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
73+ write ! ( f, "{}" , self . 0 )
74+ }
75+ }
76+
4377/// Helper to track CTEs introduced by the main query or subqueries. Every
4478/// time we enter a query, we need to track a new set of CTEs which must be
4579/// discarded once we are done with that query. Otherwise, we might allow
4680/// access to forbidden tables with a query like `select *, (with pg_user as
4781/// (select 1) select 1) as one from pg_user`
4882#[ derive( Default ) ]
4983struct CteStack {
50- stack : Vec < HashSet < String > > ,
84+ stack : Vec < HashSet < TableName > > ,
5185}
5286
5387impl CteStack {
@@ -59,9 +93,9 @@ impl CteStack {
5993 self . stack . pop ( ) ;
6094 }
6195
62- fn contains ( & self , name : & str ) -> bool {
96+ fn contains ( & self , name : & TableName ) -> bool {
6397 for entry in self . stack . iter ( ) . rev ( ) {
64- if entry. contains ( & name. to_lowercase ( ) ) {
98+ if entry. contains ( name) {
6599 return true ;
66100 }
67101 }
@@ -77,7 +111,7 @@ impl CteStack {
77111 return ControlFlow :: Break ( Error :: InternalError ( "CTE stack is empty" . into ( ) ) ) ;
78112 } ;
79113 for cte in ctes {
80- entry. insert ( cte. alias . name . value . to_lowercase ( ) ) ;
114+ entry. insert ( TableName :: from ( & cte. alias . name ) ) ;
81115 }
82116 ControlFlow :: Continue ( ( ) )
83117 }
@@ -254,20 +288,20 @@ impl VisitorMut for Validator<'_> {
254288 return ControlFlow :: Break ( Error :: NoQualifiedTables ( name. to_string ( ) ) ) ;
255289 }
256290 let table_name = match & name. 0 [ 0 ] {
257- ObjectNamePart :: Identifier ( ident) => & ident. value ,
291+ ObjectNamePart :: Identifier ( ident) => TableName :: from ( ident) ,
258292 ObjectNamePart :: Function ( _) => {
259293 return ControlFlow :: Break ( Error :: NoQualifiedTables ( name. to_string ( ) ) ) ;
260294 }
261295 } ;
262296
263297 // CTES override subgraph tables
264- if self . ctes . contains ( & table_name. to_lowercase ( ) ) && args. is_none ( ) {
298+ if self . ctes . contains ( & table_name) && args. is_none ( ) {
265299 return ControlFlow :: Continue ( ( ) ) ;
266300 }
267301
268- let table = match ( self . layout . table ( table_name) , args) {
302+ let table = match ( self . layout . table ( table_name. as_str ( ) ) , args) {
269303 ( None , None ) => {
270- return ControlFlow :: Break ( Error :: UnknownTable ( table_name. clone ( ) ) ) ;
304+ return ControlFlow :: Break ( Error :: UnknownTable ( table_name. to_string ( ) ) ) ;
271305 }
272306 ( Some ( _) , Some ( _) ) => {
273307 // Table exists but has args, must be a function
@@ -278,7 +312,7 @@ impl VisitorMut for Validator<'_> {
278312 // aggregation table in the form <name>(<interval>) or
279313 // must be a function
280314
281- if !self . layout . has_aggregation ( table_name) {
315+ if !self . layout . has_aggregation ( table_name. as_str ( ) ) {
282316 // Not an aggregation, must be a function
283317 return self . validate_function_name ( & name) ;
284318 }
@@ -287,23 +321,24 @@ impl VisitorMut for Validator<'_> {
287321 if settings. is_some ( ) {
288322 // We do not support settings on aggregation tables
289323 return ControlFlow :: Break ( Error :: InvalidAggregationSyntax (
290- table_name. clone ( ) ,
324+ table_name. to_string ( ) ,
291325 ) ) ;
292326 }
293327 let Some ( intv) = extract_string_arg ( args) else {
294328 // Looks like an aggregation, but argument is not a single string
295329 return ControlFlow :: Break ( Error :: InvalidAggregationSyntax (
296- table_name. clone ( ) ,
330+ table_name. to_string ( ) ,
297331 ) ) ;
298332 } ;
299333 let Some ( intv) = intv. parse :: < AggregationInterval > ( ) . ok ( ) else {
300334 return ControlFlow :: Break ( Error :: UnknownAggregationInterval (
301- table_name. clone ( ) ,
335+ table_name. to_string ( ) ,
302336 intv,
303337 ) ) ;
304338 } ;
305339
306- let Some ( table) = self . layout . aggregation_table ( table_name, intv) else {
340+ let Some ( table) = self . layout . aggregation_table ( table_name. as_str ( ) , intv)
341+ else {
307342 return self . validate_function_name ( & name) ;
308343 } ;
309344 table
@@ -312,7 +347,7 @@ impl VisitorMut for Validator<'_> {
312347 if !table. object . is_object_type ( ) {
313348 // Interfaces and aggregations can not be queried
314349 // with the table name directly
315- return ControlFlow :: Break ( Error :: UnknownTable ( table_name. clone ( ) ) ) ;
350+ return ControlFlow :: Break ( Error :: UnknownTable ( table_name. to_string ( ) ) ) ;
316351 }
317352 table
318353 }
0 commit comments