|
1 | 1 | use std::collections::BTreeSet;
|
2 | 2 |
|
3 | 3 | use proc_macro::TokenStream;
|
4 |
| -use proc_macro2::TokenStream as TokenStream2; |
| 4 | +use proc_macro2::{Literal, TokenStream as TokenStream2}; |
5 | 5 | use quote::{quote, ToTokens};
|
6 | 6 | use syn::{
|
| 7 | + braced, |
7 | 8 | parse::{Parse, ParseStream},
|
8 | 9 | parse_macro_input,
|
9 | 10 | punctuated::Punctuated,
|
10 |
| - Data, DeriveInput, Fields, Ident, Token, Type, Variant, |
| 11 | + Data, DeriveInput, Fields, Ident, ItemEnum, ItemStruct, Token, Type, Variant, |
11 | 12 | };
|
12 | 13 |
|
13 | 14 | fn extract_enum_variants(input: &DeriveInput) -> syn::Result<Vec<(&syn::Ident, &syn::Type)>> {
|
@@ -234,3 +235,107 @@ pub fn enum_conversions(attr: TokenStream, item: TokenStream) -> TokenStream {
|
234 | 235 | };
|
235 | 236 | TokenStream::from(expanded)
|
236 | 237 | }
|
| 238 | + |
| 239 | +// Custom struct to parse arbitrary content inside the attribute brackets |
| 240 | +struct CommonCode { |
| 241 | + content: TokenStream2, |
| 242 | +} |
| 243 | + |
| 244 | +impl Parse for CommonCode { |
| 245 | + fn parse(input: ParseStream) -> syn::Result<Self> { |
| 246 | + // Parse everything between the braces as a raw token stream |
| 247 | + let content; |
| 248 | + braced!(content in input); |
| 249 | + let content = content.parse()?; |
| 250 | + Ok(CommonCode { content }) |
| 251 | + } |
| 252 | +} |
| 253 | + |
| 254 | +/// Usage example: |
| 255 | +/// |
| 256 | +/// #[common_fields({ |
| 257 | +/// /// Common size field for all variants |
| 258 | +/// #[serde(default)] |
| 259 | +/// pub size: u64 |
| 260 | +/// })] |
| 261 | +/// enum Test { |
| 262 | +/// A { } |
| 263 | +/// B { x: bool } |
| 264 | +/// } |
| 265 | +/// |
| 266 | +/// Becomes: |
| 267 | +/// |
| 268 | +/// enum Test { |
| 269 | +/// A { |
| 270 | +/// /// Common size field for all variants |
| 271 | +/// #[serde(default)] |
| 272 | +/// pub size: u64 |
| 273 | +/// } |
| 274 | +/// B { |
| 275 | +/// x: bool, |
| 276 | +/// /// Common size field for all variants |
| 277 | +/// #[serde(default)] |
| 278 | +/// pub size: u64 |
| 279 | +/// } |
| 280 | +/// } |
| 281 | +#[proc_macro_attribute] |
| 282 | +pub fn common_fields(attr: TokenStream, item: TokenStream) -> TokenStream { |
| 283 | + // Parse the common code from the attribute |
| 284 | + let common_code = parse_macro_input!(attr as CommonCode); |
| 285 | + let common_fields_tokens = common_code.content; |
| 286 | + |
| 287 | + // Parse the input enum |
| 288 | + let mut input_enum = parse_macro_input!(item as ItemEnum); |
| 289 | + |
| 290 | + // Parse common fields by creating a temporary struct |
| 291 | + let temp_struct_tokens = quote! { |
| 292 | + struct TempStruct { |
| 293 | + #common_fields_tokens |
| 294 | + } |
| 295 | + }; |
| 296 | + |
| 297 | + // Parse the temporary struct |
| 298 | + let temp_struct: Result<ItemStruct, syn::Error> = syn::parse2(temp_struct_tokens); |
| 299 | + |
| 300 | + // Check for parsing errors |
| 301 | + if let Err(err) = temp_struct { |
| 302 | + // Create a literal from the error message string |
| 303 | + let error_string = err.to_string(); |
| 304 | + let error_lit = Literal::string(&error_string); |
| 305 | + |
| 306 | + return TokenStream::from(quote! { |
| 307 | + compile_error!(#error_lit); |
| 308 | + }); |
| 309 | + } |
| 310 | + |
| 311 | + // Unwrap the struct now that we know it's Ok |
| 312 | + let temp_struct = temp_struct.unwrap(); |
| 313 | + |
| 314 | + // Extract fields from the temporary struct |
| 315 | + let common_fields = match temp_struct.fields { |
| 316 | + Fields::Named(named) => named.named, |
| 317 | + _ => { |
| 318 | + let error_lit = Literal::string("Expected named fields in common code block"); |
| 319 | + return TokenStream::from(quote! { |
| 320 | + compile_error!(#error_lit); |
| 321 | + }); |
| 322 | + } |
| 323 | + }; |
| 324 | + |
| 325 | + // Process each variant of the enum |
| 326 | + for variant in &mut input_enum.variants { |
| 327 | + // We only care about struct variants (named fields) |
| 328 | + if let Fields::Named(ref mut fields) = variant.fields { |
| 329 | + // Add each common field to this variant |
| 330 | + for field in common_fields.iter() { |
| 331 | + fields.named.push(field.clone()); |
| 332 | + } |
| 333 | + } |
| 334 | + } |
| 335 | + |
| 336 | + // Return the updated enum |
| 337 | + quote! { |
| 338 | + #input_enum |
| 339 | + } |
| 340 | + .into() |
| 341 | +} |
0 commit comments