diff --git a/Cargo.lock b/Cargo.lock index 639045d..95ceaf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "equivalent" @@ -44,7 +44,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "nested_enum_utils" -version = "0.1.0" +version = "0.2.0" dependencies = [ "proc-macro-crate", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 8137589..94631a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "nested_enum_utils" -version = "0.1.0" -edition = "2021" +version = "0.2.0" +edition = "2024" readme = "README.md" description = "Macros to provide conversions for nested enums" license = "MIT OR Apache-2.0" diff --git a/src/lib.rs b/src/lib.rs index 4aa5082..076bec7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,14 @@ use std::collections::BTreeSet; use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Literal, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; use syn::{ + braced, parse::{Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, - Data, DeriveInput, Fields, Ident, Token, Type, Variant, + Data, DeriveInput, Fields, Ident, ItemEnum, ItemStruct, Token, Type, Variant, }; fn extract_enum_variants(input: &DeriveInput) -> syn::Result> { @@ -234,3 +235,107 @@ pub fn enum_conversions(attr: TokenStream, item: TokenStream) -> TokenStream { }; TokenStream::from(expanded) } + +// Custom struct to parse arbitrary content inside the attribute brackets +struct CommonCode { + content: TokenStream2, +} + +impl Parse for CommonCode { + fn parse(input: ParseStream) -> syn::Result { + // Parse everything between the braces as a raw token stream + let content; + braced!(content in input); + let content = content.parse()?; + Ok(CommonCode { content }) + } +} + +/// Usage example: +/// +/// #[common_fields({ +/// /// Common size field for all variants +/// #[serde(default)] +/// pub size: u64 +/// })] +/// enum Test { +/// A { } +/// B { x: bool } +/// } +/// +/// Becomes: +/// +/// enum Test { +/// A { +/// /// Common size field for all variants +/// #[serde(default)] +/// pub size: u64 +/// } +/// B { +/// x: bool, +/// /// Common size field for all variants +/// #[serde(default)] +/// pub size: u64 +/// } +/// } +#[proc_macro_attribute] +pub fn common_fields(attr: TokenStream, item: TokenStream) -> TokenStream { + // Parse the common code from the attribute + let common_code = parse_macro_input!(attr as CommonCode); + let common_fields_tokens = common_code.content; + + // Parse the input enum + let mut input_enum = parse_macro_input!(item as ItemEnum); + + // Parse common fields by creating a temporary struct + let temp_struct_tokens = quote! { + struct TempStruct { + #common_fields_tokens + } + }; + + // Parse the temporary struct + let temp_struct: Result = syn::parse2(temp_struct_tokens); + + // Check for parsing errors + if let Err(err) = temp_struct { + // Create a literal from the error message string + let error_string = err.to_string(); + let error_lit = Literal::string(&error_string); + + return TokenStream::from(quote! { + compile_error!(#error_lit); + }); + } + + // Unwrap the struct now that we know it's Ok + let temp_struct = temp_struct.unwrap(); + + // Extract fields from the temporary struct + let common_fields = match temp_struct.fields { + Fields::Named(named) => named.named, + _ => { + let error_lit = Literal::string("Expected named fields in common code block"); + return TokenStream::from(quote! { + compile_error!(#error_lit); + }); + } + }; + + // Process each variant of the enum + for variant in &mut input_enum.variants { + // We only care about struct variants (named fields) + if let Fields::Named(ref mut fields) = variant.fields { + // Add each common field to this variant + for field in common_fields.iter() { + fields.named.push(field.clone()); + } + } + } + + // Return the updated enum + quote! { + #input_enum + } + .into() +} diff --git a/tests/basic.rs b/tests/basic.rs index 360dee0..baa569e 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,4 +1,4 @@ -use nested_enum_utils::enum_conversions; +use nested_enum_utils::{common_fields, enum_conversions}; #[test] fn test_single_enum() { @@ -83,3 +83,14 @@ fn compile_fail() { let t = trybuild::TestCases::new(); t.compile_fail("tests/compile_fail/*.rs"); } + +#[test] +fn test_common_fields() { + #[common_fields({ id: u64 })] + #[allow(dead_code)] + enum Test { + A { x: u32 }, + B { y: String }, + } + let _v = Test::A { x: 42, id: 1 }; +}