From df6cc6d6981e3952099c81b8af8e3d9cb018de1b Mon Sep 17 00:00:00 2001 From: Jacob Biggs Date: Sun, 24 Mar 2024 17:37:48 -0500 Subject: [PATCH] Add support for generic types in ScorerBuilder --- derive/src/scorer.rs | 13 +++++++++---- tests/derive_scorer.rs | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/derive/src/scorer.rs b/derive/src/scorer.rs index a4b167a..0920467 100644 --- a/derive/src/scorer.rs +++ b/derive/src/scorer.rs @@ -10,14 +10,17 @@ pub fn scorer_builder_impl(input: proc_macro::TokenStream) -> proc_macro::TokenS let label = get_label(&input); let component_name = input.ident; + let generics = input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let component_string = component_name.to_string(); - let build_method = build_method(&component_name); + let build_method = build_method(&component_name, &ty_generics); let label_method = label_method( label.unwrap_or_else(|| LitStr::new(&component_string, component_name.span())), ); let gen = quote! { - impl ::big_brain::scorers::ScorerBuilder for #component_name { + impl #impl_generics ::big_brain::scorers::ScorerBuilder for #component_name #ty_generics #where_clause { #build_method #label_method } @@ -48,10 +51,12 @@ fn get_label(input: &DeriveInput) -> Option { label } -fn build_method(component_name: &Ident) -> TokenStream { +fn build_method(component_name: &Ident, ty_generics: &syn::TypeGenerics) -> TokenStream { + let turbofish = ty_generics.as_turbofish(); + quote! { fn build(&self, cmd: &mut ::bevy::ecs::system::Commands, scorer: ::bevy::ecs::entity::Entity, _actor: ::bevy::ecs::entity::Entity) { - cmd.entity(scorer).insert(#component_name::clone(self)); + cmd.entity(scorer).insert(#component_name #turbofish::clone(self)); } } } diff --git a/tests/derive_scorer.rs b/tests/derive_scorer.rs index f6881fc..25d995b 100644 --- a/tests/derive_scorer.rs +++ b/tests/derive_scorer.rs @@ -12,3 +12,30 @@ fn check_macro() { let scorer = MyScorer; assert_eq!(scorer.label(), Some("MyLabel")) } + +#[derive(Debug, Clone, Component, ScorerBuilder)] +#[scorer_label = "MyGenericLabel"] +pub struct MyGenericScorer { + pub value: T, +} + +#[test] +fn check_generic_macro() { + let scorer = MyGenericScorer { value: 0 }; + assert_eq!(scorer.label(), Some("MyGenericLabel")) +} + +#[derive(Debug, Clone, Component, ScorerBuilder)] +#[scorer_label = "MyGenericWhereLabel"] +pub struct MyGenericWhereScorer + where + T: Clone + Send + Sync + std::fmt::Debug + 'static, +{ + pub value: T, +} + +#[test] +fn check_generic_where_macro() { + let scorer = MyGenericWhereScorer { value: 0 }; + assert_eq!(scorer.label(), Some("MyGenericWhereLabel")) +}