Skip to content

Commit dec3c2b

Browse files
tustvoldalamb
andauthored
Add derive based AST visitor (apache#765)
* Add derive based AST visitor * Fix BigDecimal * Fix no visitor feature * Add test * Rename visit_table to visit_relation * Review feedback * Add pre and post visit Co-authored-by: Andrew Lamb <[email protected]>
1 parent 3e99046 commit dec3c2b

16 files changed

+771
-11
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# will have compiled files and executables
33
/target/
44
/sqlparser_bench/target/
5+
/derive/target/
56

67
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
78
# More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock

Diff for: Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ version = "0.28.0"
55
authors = ["Andy Grove <[email protected]>"]
66
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
77
documentation = "https://docs.rs/sqlparser/"
8-
keywords = [ "ansi", "sql", "lexer", "parser" ]
8+
keywords = ["ansi", "sql", "lexer", "parser"]
99
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
1010
license = "Apache-2.0"
1111
include = [
@@ -23,6 +23,7 @@ default = ["std"]
2323
std = []
2424
# Enable JSON output in the `cli` example:
2525
json_example = ["serde_json", "serde"]
26+
visitor = ["sqlparser_derive"]
2627

2728
[dependencies]
2829
bigdecimal = { version = "0.3", features = ["serde"], optional = true }
@@ -32,6 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
3233
# of dev-dependencies because of
3334
# https://github.com/rust-lang/cargo/issues/1596
3435
serde_json = { version = "1.0", optional = true }
36+
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
3537

3638
[dev-dependencies]
3739
simple_logger = "4.0"

Diff for: derive/Cargo.toml

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[package]
2+
name = "sqlparser_derive"
3+
description = "proc macro for sqlparser"
4+
version = "0.1.0"
5+
authors = ["Andy Grove <[email protected]>"]
6+
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
7+
documentation = "https://docs.rs/sqlparser/"
8+
keywords = ["ansi", "sql", "lexer", "parser"]
9+
repository = "https://github.com/sqlparser-rs/sqlparser-rs"
10+
license = "Apache-2.0"
11+
include = [
12+
"src/**/*.rs",
13+
"Cargo.toml",
14+
]
15+
edition = "2021"
16+
17+
[lib]
18+
proc-macro = true
19+
20+
[dependencies]
21+
syn = "1.0"
22+
proc-macro2 = "1.0"
23+
quote = "1.0"

Diff for: derive/README.md

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# SQL Parser Derive Macro
2+
3+
## Visit
4+
5+
This crate contains a procedural macro that can automatically derive implementations of the `Visit` trait
6+
7+
```rust
8+
#[derive(Visit)]
9+
struct Foo {
10+
boolean: bool,
11+
bar: Bar,
12+
}
13+
14+
#[derive(Visit)]
15+
enum Bar {
16+
A(),
17+
B(String, bool),
18+
C { named: i32 },
19+
}
20+
```
21+
22+
Will generate code akin to
23+
24+
```rust
25+
impl Visit for Foo {
26+
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
27+
self.boolean.visit(visitor)?;
28+
self.bar.visit(visitor)?;
29+
ControlFlow::Continue(())
30+
}
31+
}
32+
33+
impl Visit for Bar {
34+
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
35+
match self {
36+
Self::A() => {}
37+
Self::B(_1, _2) => {
38+
_1.visit(visitor)?;
39+
_2.visit(visitor)?;
40+
}
41+
Self::C { named } => {
42+
named.visit(visitor)?;
43+
}
44+
}
45+
ControlFlow::Continue(())
46+
}
47+
}
48+
```
49+
50+
Additionally certain types may wish to call a corresponding method on visitor before recursing
51+
52+
```rust
53+
#[derive(Visit)]
54+
#[visit(with = "visit_expr")]
55+
enum Expr {
56+
A(),
57+
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
58+
}
59+
```
60+
61+
Will generate
62+
63+
```rust
64+
impl Visit for Bar {
65+
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
66+
visitor.visit_expr(self)?;
67+
match self {
68+
Self::A() => {}
69+
Self::B(_1, _2, _3) => {
70+
_1.visit(visitor)?;
71+
visitor.visit_relation(_3)?;
72+
_2.visit(visitor)?;
73+
_3.visit(visitor)?;
74+
}
75+
}
76+
ControlFlow::Continue(())
77+
}
78+
}
79+
```

Diff for: derive/src/lib.rs

+184
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
use proc_macro2::TokenStream;
2+
use quote::{format_ident, quote, quote_spanned, ToTokens};
3+
use syn::spanned::Spanned;
4+
use syn::{
5+
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
6+
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
7+
};
8+
9+
/// Implementation of `[#derive(Visit)]`
10+
#[proc_macro_derive(Visit, attributes(visit))]
11+
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12+
// Parse the input tokens into a syntax tree.
13+
let input = parse_macro_input!(input as DeriveInput);
14+
let name = input.ident;
15+
16+
let attributes = Attributes::parse(&input.attrs);
17+
// Add a bound `T: HeapSize` to every type parameter T.
18+
let generics = add_trait_bounds(input.generics);
19+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20+
21+
let (pre_visit, post_visit) = attributes.visit(quote!(self));
22+
let children = visit_children(&input.data);
23+
24+
let expanded = quote! {
25+
// The generated impl.
26+
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
27+
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
28+
#pre_visit
29+
#children
30+
#post_visit
31+
::std::ops::ControlFlow::Continue(())
32+
}
33+
}
34+
};
35+
36+
proc_macro::TokenStream::from(expanded)
37+
}
38+
39+
/// Parses attributes that can be provided to this macro
40+
///
41+
/// `#[visit(leaf, with = "visit_expr")]`
42+
#[derive(Default)]
43+
struct Attributes {
44+
/// Content for the `with` attribute
45+
with: Option<Ident>,
46+
}
47+
48+
impl Attributes {
49+
fn parse(attrs: &[Attribute]) -> Self {
50+
let mut out = Self::default();
51+
for attr in attrs.iter().filter(|a| a.path.is_ident("visit")) {
52+
let meta = attr.parse_meta().expect("visit attribute");
53+
match meta {
54+
Meta::List(l) => {
55+
for nested in &l.nested {
56+
match nested {
57+
NestedMeta::Meta(Meta::NameValue(v)) => out.parse_name_value(v),
58+
_ => panic!("Expected #[visit(key = \"value\")]"),
59+
}
60+
}
61+
}
62+
_ => panic!("Expected #[visit(...)]"),
63+
}
64+
}
65+
out
66+
}
67+
68+
/// Updates self with a name value attribute
69+
fn parse_name_value(&mut self, v: &MetaNameValue) {
70+
if v.path.is_ident("with") {
71+
match &v.lit {
72+
Lit::Str(s) => self.with = Some(format_ident!("{}", s.value(), span = s.span())),
73+
_ => panic!("Expected a string value, got {}", v.lit.to_token_stream()),
74+
}
75+
return;
76+
}
77+
panic!("Unrecognised kv attribute {}", v.path.to_token_stream())
78+
}
79+
80+
/// Returns the pre and post visit token streams
81+
fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
82+
let pre_visit = self.with.as_ref().map(|m| {
83+
let m = format_ident!("pre_{}", m);
84+
quote!(visitor.#m(#s)?;)
85+
});
86+
let post_visit = self.with.as_ref().map(|m| {
87+
let m = format_ident!("post_{}", m);
88+
quote!(visitor.#m(#s)?;)
89+
});
90+
(pre_visit, post_visit)
91+
}
92+
}
93+
94+
// Add a bound `T: Visit` to every type parameter T.
95+
fn add_trait_bounds(mut generics: Generics) -> Generics {
96+
for param in &mut generics.params {
97+
if let GenericParam::Type(ref mut type_param) = *param {
98+
type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
99+
}
100+
}
101+
generics
102+
}
103+
104+
// Generate the body of the visit implementation for the given type
105+
fn visit_children(data: &Data) -> TokenStream {
106+
match data {
107+
Data::Struct(data) => match &data.fields {
108+
Fields::Named(fields) => {
109+
let recurse = fields.named.iter().map(|f| {
110+
let name = &f.ident;
111+
let attributes = Attributes::parse(&f.attrs);
112+
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
113+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
114+
});
115+
quote! {
116+
#(#recurse)*
117+
}
118+
}
119+
Fields::Unnamed(fields) => {
120+
let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
121+
let index = Index::from(i);
122+
let attributes = Attributes::parse(&f.attrs);
123+
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
124+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
125+
});
126+
quote! {
127+
#(#recurse)*
128+
}
129+
}
130+
Fields::Unit => {
131+
quote!()
132+
}
133+
},
134+
Data::Enum(data) => {
135+
let statements = data.variants.iter().map(|v| {
136+
let name = &v.ident;
137+
match &v.fields {
138+
Fields::Named(fields) => {
139+
let names = fields.named.iter().map(|f| &f.ident);
140+
let visit = fields.named.iter().map(|f| {
141+
let name = &f.ident;
142+
let attributes = Attributes::parse(&f.attrs);
143+
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
144+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
145+
});
146+
147+
quote!(
148+
Self::#name { #(#names),* } => {
149+
#(#visit)*
150+
}
151+
)
152+
}
153+
Fields::Unnamed(fields) => {
154+
let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
155+
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
156+
let name = format_ident!("_{}", i);
157+
let attributes = Attributes::parse(&f.attrs);
158+
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
159+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
160+
});
161+
162+
quote! {
163+
Self::#name ( #(#names),*) => {
164+
#(#visit)*
165+
}
166+
}
167+
}
168+
Fields::Unit => {
169+
quote! {
170+
Self::#name => {}
171+
}
172+
}
173+
}
174+
});
175+
176+
quote! {
177+
match self {
178+
#(#statements),*
179+
}
180+
}
181+
}
182+
Data::Union(_) => unimplemented!(),
183+
}
184+
}

Diff for: src/ast/data_type.rs

+8
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ use core::fmt;
1717
#[cfg(feature = "serde")]
1818
use serde::{Deserialize, Serialize};
1919

20+
#[cfg(feature = "visitor")]
21+
use sqlparser_derive::Visit;
22+
2023
use crate::ast::ObjectName;
2124

2225
use super::value::escape_single_quote_string;
2326

2427
/// SQL data types
2528
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2629
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30+
#[cfg_attr(feature = "visitor", derive(Visit))]
2731
pub enum DataType {
2832
/// Fixed-length character type e.g. CHARACTER(10)
2933
Character(Option<CharacterLength>),
@@ -337,6 +341,7 @@ fn format_datetime_precision_and_tz(
337341
/// guarantee compatibility with the input query we must maintain its exact information.
338342
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
339343
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
344+
#[cfg_attr(feature = "visitor", derive(Visit))]
340345
pub enum TimezoneInfo {
341346
/// No information about time zone. E.g., TIMESTAMP
342347
None,
@@ -384,6 +389,7 @@ impl fmt::Display for TimezoneInfo {
384389
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
385390
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
386391
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
392+
#[cfg_attr(feature = "visitor", derive(Visit))]
387393
pub enum ExactNumberInfo {
388394
/// No additional information e.g. `DECIMAL`
389395
None,
@@ -414,6 +420,7 @@ impl fmt::Display for ExactNumberInfo {
414420
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
415421
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
416422
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
423+
#[cfg_attr(feature = "visitor", derive(Visit))]
417424
pub struct CharacterLength {
418425
/// Default (if VARYING) or maximum (if not VARYING) length
419426
pub length: u64,
@@ -436,6 +443,7 @@ impl fmt::Display for CharacterLength {
436443
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
437444
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
438445
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
446+
#[cfg_attr(feature = "visitor", derive(Visit))]
439447
pub enum CharLengthUnits {
440448
/// CHARACTERS unit
441449
Characters,

0 commit comments

Comments
 (0)