Skip to content

Commit 15ba3b8

Browse files
authored
Merge pull request #2 from n0-computer/common-fields
Common fields
2 parents 89e05b9 + 0003ee4 commit 15ba3b8

File tree

4 files changed

+123
-7
lines changed

4 files changed

+123
-7
lines changed

Cargo.lock

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "nested_enum_utils"
3-
version = "0.1.0"
4-
edition = "2021"
3+
version = "0.2.0"
4+
edition = "2024"
55
readme = "README.md"
66
description = "Macros to provide conversions for nested enums"
77
license = "MIT OR Apache-2.0"

src/lib.rs

+107-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use std::collections::BTreeSet;
22

33
use proc_macro::TokenStream;
4-
use proc_macro2::TokenStream as TokenStream2;
4+
use proc_macro2::{Literal, TokenStream as TokenStream2};
55
use quote::{quote, ToTokens};
66
use syn::{
7+
braced,
78
parse::{Parse, ParseStream},
89
parse_macro_input,
910
punctuated::Punctuated,
10-
Data, DeriveInput, Fields, Ident, Token, Type, Variant,
11+
Data, DeriveInput, Fields, Ident, ItemEnum, ItemStruct, Token, Type, Variant,
1112
};
1213

1314
fn extract_enum_variants(input: &DeriveInput) -> syn::Result<Vec<(&syn::Ident, &syn::Type)>> {
@@ -234,3 +235,107 @@ pub fn enum_conversions(attr: TokenStream, item: TokenStream) -> TokenStream {
234235
};
235236
TokenStream::from(expanded)
236237
}
238+
239+
// Custom struct to parse arbitrary content inside the attribute brackets
240+
struct CommonCode {
241+
content: TokenStream2,
242+
}
243+
244+
impl Parse for CommonCode {
245+
fn parse(input: ParseStream) -> syn::Result<Self> {
246+
// Parse everything between the braces as a raw token stream
247+
let content;
248+
braced!(content in input);
249+
let content = content.parse()?;
250+
Ok(CommonCode { content })
251+
}
252+
}
253+
254+
/// Usage example:
255+
///
256+
/// #[common_fields({
257+
/// /// Common size field for all variants
258+
/// #[serde(default)]
259+
/// pub size: u64
260+
/// })]
261+
/// enum Test {
262+
/// A { }
263+
/// B { x: bool }
264+
/// }
265+
///
266+
/// Becomes:
267+
///
268+
/// enum Test {
269+
/// A {
270+
/// /// Common size field for all variants
271+
/// #[serde(default)]
272+
/// pub size: u64
273+
/// }
274+
/// B {
275+
/// x: bool,
276+
/// /// Common size field for all variants
277+
/// #[serde(default)]
278+
/// pub size: u64
279+
/// }
280+
/// }
281+
#[proc_macro_attribute]
282+
pub fn common_fields(attr: TokenStream, item: TokenStream) -> TokenStream {
283+
// Parse the common code from the attribute
284+
let common_code = parse_macro_input!(attr as CommonCode);
285+
let common_fields_tokens = common_code.content;
286+
287+
// Parse the input enum
288+
let mut input_enum = parse_macro_input!(item as ItemEnum);
289+
290+
// Parse common fields by creating a temporary struct
291+
let temp_struct_tokens = quote! {
292+
struct TempStruct {
293+
#common_fields_tokens
294+
}
295+
};
296+
297+
// Parse the temporary struct
298+
let temp_struct: Result<ItemStruct, syn::Error> = syn::parse2(temp_struct_tokens);
299+
300+
// Check for parsing errors
301+
if let Err(err) = temp_struct {
302+
// Create a literal from the error message string
303+
let error_string = err.to_string();
304+
let error_lit = Literal::string(&error_string);
305+
306+
return TokenStream::from(quote! {
307+
compile_error!(#error_lit);
308+
});
309+
}
310+
311+
// Unwrap the struct now that we know it's Ok
312+
let temp_struct = temp_struct.unwrap();
313+
314+
// Extract fields from the temporary struct
315+
let common_fields = match temp_struct.fields {
316+
Fields::Named(named) => named.named,
317+
_ => {
318+
let error_lit = Literal::string("Expected named fields in common code block");
319+
return TokenStream::from(quote! {
320+
compile_error!(#error_lit);
321+
});
322+
}
323+
};
324+
325+
// Process each variant of the enum
326+
for variant in &mut input_enum.variants {
327+
// We only care about struct variants (named fields)
328+
if let Fields::Named(ref mut fields) = variant.fields {
329+
// Add each common field to this variant
330+
for field in common_fields.iter() {
331+
fields.named.push(field.clone());
332+
}
333+
}
334+
}
335+
336+
// Return the updated enum
337+
quote! {
338+
#input_enum
339+
}
340+
.into()
341+
}

tests/basic.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use nested_enum_utils::enum_conversions;
1+
use nested_enum_utils::{common_fields, enum_conversions};
22

33
#[test]
44
fn test_single_enum() {
@@ -83,3 +83,14 @@ fn compile_fail() {
8383
let t = trybuild::TestCases::new();
8484
t.compile_fail("tests/compile_fail/*.rs");
8585
}
86+
87+
#[test]
88+
fn test_common_fields() {
89+
#[common_fields({ id: u64 })]
90+
#[allow(dead_code)]
91+
enum Test {
92+
A { x: u32 },
93+
B { y: String },
94+
}
95+
let _v = Test::A { x: 42, id: 1 };
96+
}

0 commit comments

Comments
 (0)