Skip to content

Commit

Permalink
feat(scorers): Add support for generic types in ScorerBuilder (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
nosideeffects authored Apr 2, 2024
1 parent bce521c commit 98337fa
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
13 changes: 9 additions & 4 deletions derive/src/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ pub fn scorer_builder_impl(input: proc_macro::TokenStream) -> proc_macro::TokenS
let label = get_label(&input);

let component_name = input.ident;
let generics = input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let component_string = component_name.to_string();
let build_method = build_method(&component_name);
let build_method = build_method(&component_name, &ty_generics);
let label_method = label_method(
label.unwrap_or_else(|| LitStr::new(&component_string, component_name.span())),
);

let gen = quote! {
impl ::big_brain::scorers::ScorerBuilder for #component_name {
impl #impl_generics ::big_brain::scorers::ScorerBuilder for #component_name #ty_generics #where_clause {
#build_method
#label_method
}
Expand Down Expand Up @@ -48,10 +51,12 @@ fn get_label(input: &DeriveInput) -> Option<LitStr> {
label
}

fn build_method(component_name: &Ident) -> TokenStream {
fn build_method(component_name: &Ident, ty_generics: &syn::TypeGenerics) -> TokenStream {
let turbofish = ty_generics.as_turbofish();

quote! {
fn build(&self, cmd: &mut ::bevy::ecs::system::Commands, scorer: ::bevy::ecs::entity::Entity, _actor: ::bevy::ecs::entity::Entity) {
cmd.entity(scorer).insert(#component_name::clone(self));
cmd.entity(scorer).insert(#component_name #turbofish::clone(self));
}
}
}
Expand Down
27 changes: 27 additions & 0 deletions tests/derive_scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,30 @@ fn check_macro() {
let scorer = MyScorer;
assert_eq!(scorer.label(), Some("MyLabel"))
}

#[derive(Debug, Clone, Component, ScorerBuilder)]
#[scorer_label = "MyGenericLabel"]
pub struct MyGenericScorer<T: Clone + Send + Sync + std::fmt::Debug + 'static> {
pub value: T,
}

#[test]
fn check_generic_macro() {
let scorer = MyGenericScorer { value: 0 };
assert_eq!(scorer.label(), Some("MyGenericLabel"))
}

#[derive(Debug, Clone, Component, ScorerBuilder)]
#[scorer_label = "MyGenericWhereLabel"]
pub struct MyGenericWhereScorer<T>
where
T: Clone + Send + Sync + std::fmt::Debug + 'static,
{
pub value: T,
}

#[test]
fn check_generic_where_macro() {
let scorer = MyGenericWhereScorer { value: 0 };
assert_eq!(scorer.label(), Some("MyGenericWhereLabel"))
}

0 comments on commit 98337fa

Please sign in to comment.