diff --git a/Cargo.lock b/Cargo.lock index d38cfba..5ac3ccb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,7 +543,7 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "fungible-pausable-example" -version = "0.0.0" +version = "0.1.0" dependencies = [ "openzeppelin-fungible-token", "openzeppelin-pausable", @@ -811,14 +811,14 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openzeppelin-fungible-token" -version = "0.0.0" +version = "0.1.0" dependencies = [ "soroban-sdk", ] [[package]] name = "openzeppelin-pausable" -version = "0.0.0" +version = "0.1.0" dependencies = [ "soroban-sdk", ] @@ -852,7 +852,7 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pausable-example" -version = "0.0.0" +version = "0.1.0" dependencies = [ "openzeppelin-pausable", "openzeppelin-pausable-macros", diff --git a/contracts/utils/pausable-macros/src/helper.rs b/contracts/utils/pausable-macros/src/helper.rs index f6cf1ef..4234c94 100644 --- a/contracts/utils/pausable-macros/src/helper.rs +++ b/contracts/utils/pausable-macros/src/helper.rs @@ -1,6 +1,36 @@ -use syn::{FnArg, ItemFn, PatType, Type}; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, FnArg, ItemFn, PatType, Type}; -pub fn check_env_arg(input_fn: &ItemFn) -> (syn::Ident, bool) { +pub fn generate_pause_check(item: TokenStream, check_fn: &str) -> TokenStream { + let input_fn = parse_macro_input!(item as ItemFn); + let (env_ident, is_ref) = check_env_arg(&input_fn); + + let fn_vis = &input_fn.vis; + let fn_sig = &input_fn.sig; + let fn_block = &input_fn.block; + let fn_attrs = &input_fn.attrs; + + let env_arg = if is_ref { + quote! { #env_ident } + } else { + quote! { &#env_ident } + }; + + let check_ident = syn::Ident::new(check_fn, proc_macro2::Span::call_site()); + let output = quote! { + #(#fn_attrs)* // retain other macros + #fn_vis #fn_sig { + openzeppelin_pausable::#check_ident(#env_arg); + + #fn_block + } + }; + + output.into() +} + +fn check_env_arg(input_fn: &ItemFn) -> (syn::Ident, bool) { // Get the first argument let first_arg = input_fn.sig.inputs.first().unwrap_or_else(|| { panic!("function '{}' must have at least one argument", input_fn.sig.ident) diff --git a/contracts/utils/pausable-macros/src/lib.rs b/contracts/utils/pausable-macros/src/lib.rs index 07245c2..ae77c33 100644 --- a/contracts/utils/pausable-macros/src/lib.rs +++ b/contracts/utils/pausable-macros/src/lib.rs @@ -1,8 +1,6 @@ use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, ItemFn}; -use crate::helper::check_env_arg; +use crate::helper::generate_pause_check; mod helper; @@ -28,28 +26,7 @@ mod helper; /// ``` #[proc_macro_attribute] pub fn when_not_paused(_attr: TokenStream, item: TokenStream) -> TokenStream { - let input_fn = parse_macro_input!(item as ItemFn); - let (env_ident, is_ref) = check_env_arg(&input_fn); - - let fn_vis = &input_fn.vis; - let fn_sig = &input_fn.sig; - let fn_block = &input_fn.block; - - let env_arg = if is_ref { - quote! { #env_ident } - } else { - quote! { &#env_ident } - }; - - let output = quote! { - #fn_vis #fn_sig { - openzeppelin_pausable::when_not_paused(#env_arg); - - #fn_block - } - }; - - output.into() + generate_pause_check(item, "when_not_paused") } /// Adds a pause check at the beginning of the function that ensures the @@ -74,26 +51,5 @@ pub fn when_not_paused(_attr: TokenStream, item: TokenStream) -> TokenStream { /// ``` #[proc_macro_attribute] pub fn when_paused(_attr: TokenStream, item: TokenStream) -> TokenStream { - let input_fn = parse_macro_input!(item as ItemFn); - let (env_ident, is_ref) = check_env_arg(&input_fn); - - let fn_vis = &input_fn.vis; - let fn_sig = &input_fn.sig; - let fn_block = &input_fn.block; - - let env_arg = if is_ref { - quote! { #env_ident } - } else { - quote! { &#env_ident } - }; - - let output = quote! { - #fn_vis #fn_sig { - openzeppelin_pausable::when_paused(#env_arg); - - #fn_block - } - }; - - output.into() + generate_pause_check(item, "when_paused") }