diff --git a/crates/service/src/tap/checks/value_check.rs b/crates/service/src/tap/checks/value_check.rs index 9a40b214..f7bc129c 100644 --- a/crates/service/src/tap/checks/value_check.rs +++ b/crates/service/src/tap/checks/value_check.rs @@ -20,6 +20,8 @@ use tap_core::receipt::{ Context, WithValueAndTimestamp, }; use thegraph_core::DeploymentId; +#[cfg(test)] +use tokio::sync::mpsc; use crate::{ database::cost_model, @@ -55,7 +57,7 @@ pub struct MinimumValue { grace_period: Duration, #[cfg(test)] - notify: std::sync::Arc, + msg_receiver: mpsc::Receiver<()>, } struct CostModelWatcher { @@ -66,7 +68,7 @@ struct CostModelWatcher { updated_at: GracePeriod, #[cfg(test)] - notify: std::sync::Arc, + sender: mpsc::Sender<()>, } impl CostModelWatcher { @@ -77,7 +79,7 @@ impl CostModelWatcher { global_model: GlobalModel, cancel_token: tokio_util::sync::CancellationToken, grace_period: GracePeriod, - #[cfg(test)] notify: std::sync::Arc, + #[cfg(test)] sender: mpsc::Sender<()>, ) { let cost_model_watcher = CostModelWatcher { pgpool, @@ -85,7 +87,7 @@ impl CostModelWatcher { cost_models, updated_at: grace_period, #[cfg(test)] - notify, + sender, }; loop { @@ -119,7 +121,7 @@ impl CostModelWatcher { Err(_) => self.handle_unexpected_notification(payload).await, } #[cfg(test)] - self.notify.notify_one(); + self.sender.send(()).await.expect("Channel failed"); } fn handle_insert(&self, deployment: String, model: String, variables: String) { @@ -212,7 +214,7 @@ impl MinimumValue { ); #[cfg(test)] - let notify = std::sync::Arc::new(tokio::sync::Notify::new()); + let (sender, receiver) = mpsc::channel(10); let watcher_cancel_token = tokio_util::sync::CancellationToken::new(); tokio::spawn(CostModelWatcher::cost_models_watcher( @@ -223,7 +225,7 @@ impl MinimumValue { watcher_cancel_token.clone(), updated_at.clone(), #[cfg(test)] - notify.clone(), + sender, )); Self { global_model, @@ -232,7 +234,7 @@ impl MinimumValue { updated_at, grace_period, #[cfg(test)] - notify, + msg_receiver: receiver, } } @@ -399,14 +401,14 @@ mod tests { #[sqlx::test(migrations = "../../migrations")] async fn should_watch_model_insert(pgpool: PgPool) { - let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; + let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; assert_eq!(check.cost_model_map.read().unwrap().len(), 0); // insert 2 cost models for different deployment_id let test_models = test::test_data(); add_cost_models(&pgpool, to_db_models(test_models.clone())).await; - flush_messages(&check.notify).await; + flush_messages(&mut check.msg_receiver).await; assert_eq!( check.cost_model_map.read().unwrap().len(), @@ -420,7 +422,7 @@ mod tests { let test_models = test::test_data(); add_cost_models(&pgpool, to_db_models(test_models.clone())).await; - let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; + let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; assert_eq!(check.cost_model_map.read().unwrap().len(), 2); // remove @@ -429,7 +431,7 @@ mod tests { .await .unwrap(); - check.notify.notified().await; + check.msg_receiver.recv().await.expect("Channel failed"); assert_eq!(check.cost_model_map.read().unwrap().len(), 0); } @@ -445,12 +447,12 @@ mod tests { #[sqlx::test(migrations = "../../migrations")] async fn should_watch_global_model(pgpool: PgPool) { - let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; + let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; let global_model = global_cost_model(); add_cost_models(&pgpool, vec![global_model.clone()]).await; - check.notify.notified().await; + check.msg_receiver.recv().await.expect("Channel failed"); assert!(check.global_model.read().unwrap().is_some()); } @@ -460,7 +462,7 @@ mod tests { let global_model = global_cost_model(); add_cost_models(&pgpool, vec![global_model.clone()]).await; - let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; + let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await; assert!(check.global_model.read().unwrap().is_some()); sqlx::query!(r#"DELETE FROM "CostModels""#) @@ -468,7 +470,7 @@ mod tests { .await .unwrap(); - check.notify.notified().await; + check.msg_receiver.recv().await.expect("Channel failed"); assert_eq!(check.cost_model_map.read().unwrap().len(), 0); } diff --git a/crates/tap-agent/src/agent/sender_account.rs b/crates/tap-agent/src/agent/sender_account.rs index 283e8f0c..384d2deb 100644 --- a/crates/tap-agent/src/agent/sender_account.rs +++ b/crates/tap-agent/src/agent/sender_account.rs @@ -102,7 +102,7 @@ type Balance = U256; /// Information for Ravs that are abstracted away from the SignedRav itself #[derive(Debug, Default, PartialEq, Eq)] -#[cfg_attr(test, derive(Clone))] +#[cfg_attr(any(test, feature = "test"), derive(Clone))] pub struct RavInformation { /// Allocation Id of a Rav pub allocation_id: Address, @@ -141,8 +141,8 @@ impl From<&tap_graph::v2::SignedRav> for RavInformation { /// /// It has different logic depending on the variant #[derive(Debug)] -#[cfg_attr(test, derive(educe::Educe))] -#[cfg_attr(test, educe(PartialEq, Eq, Clone))] +#[cfg_attr(any(test, feature = "test"), derive(educe::Educe))] +#[cfg_attr(any(test, feature = "test"), educe(PartialEq, Eq, Clone))] pub enum ReceiptFees { /// Adds the receipt value to the fee tracker /// @@ -158,7 +158,10 @@ pub enum ReceiptFees { /// If not, signalize the fee_tracker to apply proper backoff RavRequestResponse( UnaggregatedReceipts, - #[cfg_attr(test, educe(PartialEq(ignore), Clone(method(clone_rav_result))))] + #[cfg_attr( + any(test, feature = "test"), + educe(PartialEq(ignore), Clone(method(clone_rav_result))) + )] anyhow::Result>, ), /// Ignores all logic and simply retry Allow/Deny and Rav Request logic @@ -169,7 +172,7 @@ pub enum ReceiptFees { Retry, } -#[cfg(test)] +#[cfg(any(test, feature = "test"))] fn clone_rav_result( res: &anyhow::Result>, ) -> anyhow::Result> { @@ -181,8 +184,8 @@ fn clone_rav_result( /// Enum containing all types of messages that a [SenderAccount] can receive #[derive(Debug)] -#[cfg_attr(test, derive(educe::Educe))] -#[cfg_attr(test, educe(PartialEq, Eq, Clone))] +#[cfg_attr(any(test, feature = "test"), derive(educe::Educe))] +#[cfg_attr(any(test, feature = "test"), educe(PartialEq, Eq, Clone))] pub enum SenderAccountMessage { /// Updates the sender balance and UpdateBalanceAndLastRavs(Balance, RavMap), @@ -1492,22 +1495,25 @@ pub mod tests { ) .await; - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool) .escrow_subgraph_endpoint(&mock_escrow_subgraph.uri()) .network_subgraph_endpoint(&mock_server.uri()) .call() .await; + let allocation_ids = HashSet::from_iter([AllocationId::Legacy(ALLOCATION_ID_0)]); // we expect it to create a sender allocation sender_account .cast(SenderAccountMessage::UpdateAllocationIds( - vec![AllocationId::Legacy(ALLOCATION_ID_0)] - .into_iter() - .collect(), + allocation_ids.clone(), )) .unwrap(); - notify.notified().await; + let message = msg_receiver.recv().await.expect("Channel failed"); + assert_eq!( + message, + SenderAccountMessage::UpdateAllocationIds(allocation_ids) + ); // verify if create sender account let sender_allocation_id = format!("{}:{}:{}", prefix.clone(), SENDER.1, ALLOCATION_ID_0); @@ -1517,7 +1523,18 @@ pub mod tests { sender_account .cast(SenderAccountMessage::UpdateAllocationIds(HashSet::new())) .unwrap(); - notify.notified().await; + let message = msg_receiver.recv().await.expect("Channel failed"); + assert_eq!( + message, + SenderAccountMessage::UpdateReceiptFees( + ALLOCATION_ID_0, + ReceiptFees::UpdateValue(UnaggregatedReceipts { + value: 0, + last_id: 0, + counter: 0, + }) + ) + ); let actor_ref = ActorRef::::where_is(sender_allocation_id.clone()); assert!(actor_ref.is_some()); @@ -1547,7 +1564,11 @@ pub mod tests { sender_account .cast(SenderAccountMessage::UpdateAllocationIds(HashSet::new())) .unwrap(); - notify.notified().await; + let msg = msg_receiver.recv().await.expect("Channel failed"); + assert_eq!( + msg, + SenderAccountMessage::UpdateAllocationIds(HashSet::new()) + ); let actor_ref = ActorRef::::where_is(sender_allocation_id.clone()); assert!(actor_ref.is_none()); @@ -1580,7 +1601,7 @@ pub mod tests { ) .await; - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool) .escrow_subgraph_endpoint(&mock_escrow_subgraph.uri()) .network_subgraph_endpoint(&mock_server.uri()) @@ -1594,7 +1615,7 @@ pub mod tests { ))) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // verify if create sender account let sender_allocation_id = format!("{}:{}:{}", prefix.clone(), SENDER.1, ALLOCATION_ID_0); @@ -1610,14 +1631,14 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // try to delete sender allocation_id sender_account .cast(SenderAccountMessage::UpdateAllocationIds(HashSet::new())) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // should not delete it because it was not in network subgraph let allocation_ref = @@ -1695,7 +1716,7 @@ pub mod tests { #[sqlx::test(migrations = "../../migrations")] async fn test_update_receipt_fees_trigger_rav(pgpool: PgPool) { - let (sender_account, notify, prefix, _) = + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account().pgpool(pgpool).call().await; // create a fake sender allocation @@ -1714,7 +1735,7 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; assert_not_triggered!(&triggered_rav_request); // wait for it to be outside buffer @@ -1726,14 +1747,14 @@ pub mod tests { ReceiptFees::Retry, )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; assert_triggered!(&triggered_rav_request); } #[sqlx::test(migrations = "../../migrations")] async fn test_counter_greater_limit_trigger_rav(pgpool: PgPool) { - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool.clone()) .rav_request_receipt_limit(2) .call() @@ -1754,7 +1775,7 @@ pub mod tests { ReceiptFees::NewReceipt(1, get_current_timestamp_u64_ns()), )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; assert_not_triggered!(&triggered_rav_request); @@ -1764,7 +1785,7 @@ pub mod tests { ReceiptFees::NewReceipt(1, get_current_timestamp_u64_ns()), )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // wait for it to be outside buffer tokio::time::sleep(BUFFER_DURATION).await; @@ -1775,7 +1796,7 @@ pub mod tests { ReceiptFees::Retry, )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; assert_triggered!(&triggered_rav_request); } @@ -1853,7 +1874,7 @@ pub mod tests { // we set to zero to block the sender, no matter the fee let max_unaggregated_fees_per_sender: u128 = 0; - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool) .max_amount_willing_to_lose_grt(max_unaggregated_fees_per_sender) .call() @@ -1877,7 +1898,7 @@ pub mod tests { ReceiptFees::NewReceipt(TRIGGER_VALUE, get_current_timestamp_u64_ns()), )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // wait to try again so it's outside the buffer tokio::time::sleep(RETRY_DURATION).await; @@ -1897,7 +1918,7 @@ pub mod tests { let max_unaggregated_fees_per_sender: u128 = 1000; // Making sure no RAV is going to be triggered during the test - let (sender_account, notify, _, _) = create_sender_account() + let (sender_account, mut msg_receiver, _, _) = create_sender_account() .pgpool(pgpool.clone()) .rav_request_trigger_value(u128::MAX) .max_amount_willing_to_lose_grt(max_unaggregated_fees_per_sender) @@ -1917,7 +1938,7 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; }; } @@ -1934,7 +1955,7 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; }; } @@ -2036,7 +2057,7 @@ pub mod tests { let trigger_rav_request = ESCROW_VALUE * 2; // initialize with no trigger value and no max receipt deny - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool.clone()) .rav_request_trigger_value(trigger_rav_request) .max_amount_willing_to_lose_grt(u128::MAX) @@ -2068,7 +2089,7 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; }; } @@ -2114,7 +2135,7 @@ pub mod tests { async fn test_trusted_sender(pgpool: PgPool) { let max_amount_willing_to_lose_grt = ESCROW_VALUE / 10; // initialize with no trigger value and no max receipt deny - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool) .trusted_sender(true) .rav_request_trigger_value(u128::MAX) @@ -2143,7 +2164,7 @@ pub mod tests { })) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; }; } @@ -2221,7 +2242,7 @@ pub mod tests { .await .unwrap(); - let (sender_account, notify, _, escrow_accounts_tx) = create_sender_account() + let (sender_account, mut msg_receiver, _, escrow_accounts_tx) = create_sender_account() .pgpool(pgpool.clone()) .max_amount_willing_to_lose_grt(u128::MAX) .escrow_subgraph_endpoint(&mock_server.uri()) @@ -2255,7 +2276,7 @@ pub mod tests { .unwrap(); // wait the actor react to the messages - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // should still be active with a 1 escrow available @@ -2279,7 +2300,7 @@ pub mod tests { .await .unwrap(); - let (sender_account, notify, _, escrow_accounts_tx) = create_sender_account() + let (sender_account, mut msg_receiver, _, escrow_accounts_tx) = create_sender_account() .pgpool(pgpool.clone()) .max_amount_willing_to_lose_grt(u128::MAX) .call() @@ -2296,7 +2317,7 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; let deny = call!(sender_account, SenderAccountMessage::GetDeny).unwrap(); assert!(deny, "should block the sender"); @@ -2309,7 +2330,7 @@ pub mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; let deny = call!(sender_account, SenderAccountMessage::GetDeny).unwrap(); assert!(!deny, "should unblock the sender"); @@ -2322,7 +2343,7 @@ pub mod tests { // we set to 1 to block the sender on a really low value let max_unaggregated_fees_per_sender: u128 = 1; - let (sender_account, notify, prefix, _) = create_sender_account() + let (sender_account, mut msg_receiver, prefix, _) = create_sender_account() .pgpool(pgpool) .max_amount_willing_to_lose_grt(max_unaggregated_fees_per_sender) .call() @@ -2349,7 +2370,14 @@ pub mod tests { ReceiptFees::NewReceipt(TRIGGER_VALUE, get_current_timestamp_u64_ns()), )) .unwrap(); - notify.notified().await; + let msg = msg_receiver.recv().await.expect("Channel failed"); + assert!(matches!( + msg, + SenderAccountMessage::UpdateReceiptFees( + ALLOCATION_ID_0, + ReceiptFees::NewReceipt(TRIGGER_VALUE, _) + ) + )); let deny = call!(sender_account, SenderAccountMessage::GetDeny).unwrap(); assert!(deny, "should be blocked"); diff --git a/crates/tap-agent/src/agent/sender_accounts_manager.rs b/crates/tap-agent/src/agent/sender_accounts_manager.rs index 34945a74..bbf0685c 100644 --- a/crates/tap-agent/src/agent/sender_accounts_manager.rs +++ b/crates/tap-agent/src/agent/sender_accounts_manager.rs @@ -38,7 +38,7 @@ lazy_static! { /// Notification received by pgnotify /// /// This contains a list of properties that are sent by postgres when a receipt is inserted -#[derive(Deserialize, Debug, PartialEq, Eq)] +#[derive(Deserialize, Debug, PartialEq, Eq, Clone)] pub struct NewReceiptNotification { /// id inside the table pub id: u64, @@ -53,6 +53,7 @@ pub struct NewReceiptNotification { } /// Manager Actor +#[derive(Debug, Clone)] pub struct SenderAccountsManager; /// Wrapped AllocationId Address with two possible variants @@ -95,6 +96,7 @@ pub enum SenderType { /// Enum containing all types of messages that a [SenderAccountsManager] can receive #[derive(Debug)] +#[cfg_attr(any(test, feature = "test"), derive(Clone))] pub enum SenderAccountsManagerMessage { /// Spawn and Stop [SenderAccount]s that were added or removed /// in comparison with it current state and updates the state @@ -1035,7 +1037,7 @@ mod tests { #[sqlx::test(migrations = "../../migrations")] async fn test_update_sender_allocation(pgpool: PgPool) { - let (prefix, notify, (actor, join_handle)) = + let (prefix, mut notify, (actor, join_handle)) = create_sender_accounts_manager().pgpool(pgpool).call().await; actor @@ -1044,7 +1046,7 @@ mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut notify).await; assert_while_retry! { ActorRef::::where_is(format!( @@ -1068,7 +1070,7 @@ mod tests { )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut notify).await; sender_ref.wait(None).await.unwrap(); // verify if it gets removed @@ -1140,8 +1142,8 @@ mod tests { // create dummy allocation let (mock_sender_allocation, mut receipts) = MockSenderAllocation::new_with_receipts(); - let actor = TestableActor::new(mock_sender_allocation); - let notify = actor.notify.clone(); + let (tx, mut notify) = mpsc::channel(10); + let actor = TestableActor::new(mock_sender_allocation, tx); let _ = Actor::spawn( Some(format!( "{}:{}:{}", @@ -1192,7 +1194,7 @@ mod tests { .await .unwrap(); } - flush_messages(¬ify).await; + flush_messages(&mut notify).await; // check if receipt notification was sent to the allocation for i in 1..=receipts_count { diff --git a/crates/tap-agent/src/agent/sender_allocation.rs b/crates/tap-agent/src/agent/sender_allocation.rs index a016f676..37d25ef2 100644 --- a/crates/tap-agent/src/agent/sender_allocation.rs +++ b/crates/tap-agent/src/agent/sender_allocation.rs @@ -217,6 +217,8 @@ pub struct SenderAllocationArgs { /// Enum containing all types of messages that a [SenderAllocation] can receive #[derive(Debug)] +#[cfg_attr(any(test, feature = "test"), derive(educe::Educe))] +#[cfg_attr(any(test, feature = "test"), educe(Clone))] pub enum SenderAllocationMessage { /// New receipt message, sent by the task spawned by /// [super::sender_accounts_manager::SenderAccountsManager] @@ -227,7 +229,10 @@ pub enum SenderAllocationMessage { TriggerRavRequest, #[cfg(any(test, feature = "test"))] /// Return the internal state (used for tests) - GetUnaggregatedReceipts(ractor::RpcReplyPort), + GetUnaggregatedReceipts( + #[educe(Clone(method(crate::test::actors::clone_rpc_reply)))] + ractor::RpcReplyPort, + ), } /// Actor implementation for [SenderAllocation] @@ -1305,7 +1310,7 @@ pub mod tests { flush_messages, ALLOCATION_ID_0, TAP_EIP712_DOMAIN as TAP_EIP712_DOMAIN_SEPARATOR, TAP_SENDER as SENDER, TAP_SIGNER as SIGNER, }; - use tokio::sync::{watch, Notify}; + use tokio::sync::{mpsc, watch}; use tonic::{transport::Endpoint, Code}; use wiremock::{ matchers::{body_string_contains, method}, @@ -1416,7 +1421,10 @@ pub mod tests { escrow_subgraph_endpoint: &str, #[builder(default = 1000)] rav_request_receipt_limit: u64, sender_account: Option>, - ) -> (ActorRef, Arc) { + ) -> ( + ActorRef, + mpsc::Receiver, + ) { let args = create_sender_allocation_args() .pgpool(pgpool) .maybe_sender_aggregator_endpoint(sender_aggregator_endpoint) @@ -1425,12 +1433,13 @@ pub mod tests { .rav_request_receipt_limit(rav_request_receipt_limit) .call() .await; - let actor = TestableActor::new(SenderAllocation::default()); - let notify = actor.notify.clone(); + + let (sender, msg_receiver) = mpsc::channel(10); + let actor = TestableActor::new(SenderAllocation::default(), sender); let (allocation_ref, _join_handle) = Actor::spawn(None, actor, args).await.unwrap(); - (allocation_ref, notify) + (allocation_ref, msg_receiver) } #[sqlx::test(migrations = "../../migrations")] @@ -1530,7 +1539,7 @@ pub mod tests { let (mut message_receiver, sender_account) = create_mock_sender_account().await; - let (sender_allocation, notify) = create_sender_allocation() + let (sender_allocation, mut msg_receiver) = create_sender_allocation() .pgpool(pgpool.clone()) .escrow_subgraph_endpoint(&mock_escrow_subgraph_server.uri()) .sender_account(sender_account) @@ -1567,7 +1576,7 @@ pub mod tests { ) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; // should emit update aggregate fees message to sender account let expected_message = SenderAccountMessage::UpdateReceiptFees( @@ -1623,7 +1632,7 @@ pub mod tests { let (mut message_receiver, sender_account) = create_mock_sender_account().await; // Create a sender_allocation. - let (sender_allocation, notify) = create_sender_allocation() + let (sender_allocation, mut msg_receiver_alloc) = create_sender_allocation() .pgpool(pgpool.clone()) .escrow_subgraph_endpoint(&mock_server.uri()) .sender_account(sender_account) @@ -1635,7 +1644,7 @@ pub mod tests { .cast(SenderAllocationMessage::TriggerRavRequest) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver_alloc).await; let total_unaggregated_fees = call!( sender_allocation, @@ -1703,7 +1712,7 @@ pub mod tests { let (mut message_receiver, sender_account) = create_mock_sender_account().await; // Create a sender_allocation. - let (sender_allocation, notify) = create_sender_allocation() + let (sender_allocation, mut msg_receiver_alloc) = create_sender_allocation() .pgpool(pgpool.clone()) .escrow_subgraph_endpoint(&mock_server.uri()) .rav_request_receipt_limit(2000) @@ -1716,7 +1725,7 @@ pub mod tests { .cast(SenderAllocationMessage::TriggerRavRequest) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver_alloc).await; let total_unaggregated_fees = call!( sender_allocation, @@ -2073,7 +2082,7 @@ pub mod tests { let (mut message_receiver, sender_account) = create_mock_sender_account().await; // Create a sender_allocation. - let (sender_allocation, notify) = create_sender_allocation() + let (sender_allocation, mut notify) = create_sender_allocation() .pgpool(pgpool.clone()) .escrow_subgraph_endpoint(&mock_escrow_subgraph_server.uri()) .sender_account(sender_account) @@ -2086,7 +2095,7 @@ pub mod tests { .cast(SenderAllocationMessage::TriggerRavRequest) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut notify).await; // If it is an error then rav request failed @@ -2159,7 +2168,7 @@ pub mod tests { let (mut message_receiver, sender_account) = create_mock_sender_account().await; - let (sender_allocation, notify) = create_sender_allocation() + let (sender_allocation, mut notify) = create_sender_allocation() .pgpool(pgpool.clone()) .escrow_subgraph_endpoint(&mock_server.uri()) .sender_account(sender_account) @@ -2172,7 +2181,7 @@ pub mod tests { .cast(SenderAllocationMessage::TriggerRavRequest) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut notify).await; // If it is an error then rav request failed diff --git a/crates/tap-agent/src/lib.rs b/crates/tap-agent/src/lib.rs index 2cd518d5..a0ce6154 100644 --- a/crates/tap-agent/src/lib.rs +++ b/crates/tap-agent/src/lib.rs @@ -35,7 +35,8 @@ pub mod database; /// Prometheus Metrics server pub mod metrics; pub mod tap; -#[cfg(any(test, feature = "test"))] + /// Test utils to interact with Tap Actors +#[cfg(any(test, feature = "test"))] pub mod test; pub mod tracker; diff --git a/crates/tap-agent/src/test.rs b/crates/tap-agent/src/test.rs index 3639bb30..372576f7 100644 --- a/crates/tap-agent/src/test.rs +++ b/crates/tap-agent/src/test.rs @@ -5,7 +5,6 @@ use std::{ collections::{HashMap, HashSet}, net::SocketAddr, - sync::Arc, time::Duration, }; @@ -32,8 +31,8 @@ use thegraph_core::alloy::{ pub const ALLOCATION_ID_0: Address = test_assets::ALLOCATION_ID_0; pub const ALLOCATION_ID_1: Address = test_assets::ALLOCATION_ID_1; use tokio::sync::{ + mpsc, watch::{self, Sender}, - Notify, }; use tracing::error; @@ -111,7 +110,7 @@ pub async fn create_sender_account( #[builder(default = false)] trusted_sender: bool, ) -> ( ActorRef, - Arc, + mpsc::Receiver, String, Sender, ) { @@ -181,17 +180,17 @@ pub async fn create_sender_account( sender_type: SenderType::Legacy, }; - let actor = TestableActor::new(SenderAccount); - let notify = actor.notify.clone(); + let (sender, mut receiver) = mpsc::channel(100); + let actor = TestableActor::new(SenderAccount, sender); let (sender, _) = Actor::spawn(Some(prefix.clone()), actor, args) .await .unwrap(); // flush all messages - flush_messages(¬ify).await; + flush_messages(&mut receiver).await; - (sender, notify, prefix, escrow_accounts_tx) + (sender, receiver, prefix, escrow_accounts_tx) } #[bon::builder] @@ -203,7 +202,7 @@ pub async fn create_sender_accounts_manager( initial_escrow_accounts_v2: Option, ) -> ( String, - Arc, + mpsc::Receiver, (ActorRef, JoinHandle<()>), ) { let config = get_sender_account_config(); @@ -254,11 +253,11 @@ pub async fn create_sender_accounts_manager( ]), prefix: Some(prefix.clone()), }; - let actor = TestableActor::new(SenderAccountsManager); - let notify = actor.notify.clone(); + let (sender, receiver) = mpsc::channel(100); + let actor = TestableActor::new(SenderAccountsManager, sender); ( prefix, - notify, + receiver, Actor::spawn(None, actor, args).await.unwrap(), ) } @@ -696,7 +695,7 @@ pub async fn store_rav_with_options( } pub mod actors { - use std::sync::Arc; + use std::{fmt::Debug, sync::Arc}; use ractor::{Actor, ActorProcessingErr, ActorRef, SupervisionEvent}; use test_assets::{ALLOCATION_ID_0, TAP_SIGNER}; @@ -711,7 +710,7 @@ pub mod actors { unaggregated_receipts::UnaggregatedReceipts, }; - #[cfg(test)] + #[cfg(any(test, feature = "test"))] pub fn clone_rpc_reply(_: &ractor::RpcReplyPort) -> ractor::RpcReplyPort { ractor::concurrency::oneshot().0.into() } @@ -744,18 +743,15 @@ pub mod actors { T: Actor, { inner: T, - pub notify: Arc, + pub sender: mpsc::Sender, } impl TestableActor where T: Actor, { - pub fn new(inner: T) -> Self { - Self { - inner, - notify: Arc::new(Notify::new()), - } + pub fn new(inner: T, sender: mpsc::Sender) -> Self { + Self { inner, sender } } } @@ -794,6 +790,7 @@ pub mod actors { impl Actor for TestableActor where T: Actor, + T::Msg: Debug + Clone, { type Msg = T::Msg; type State = T::State; @@ -821,8 +818,9 @@ pub mod actors { msg: Self::Msg, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { + let message = msg.clone(); let result = self.inner.handle(myself, msg, state).await; - self.notify.notify_one(); + self.sender.send(message).await.expect("Channel failed"); result } diff --git a/crates/tap-agent/tests/sender_account_manager_test.rs b/crates/tap-agent/tests/sender_account_manager_test.rs index cbe1c5ae..d644dbfd 100644 --- a/crates/tap-agent/tests/sender_account_manager_test.rs +++ b/crates/tap-agent/tests/sender_account_manager_test.rs @@ -62,7 +62,7 @@ async fn sender_account_manager_layer_test(pgpool: PgPool) { )) .await; - let (prefix, notify, (actor, join_handle)) = create_sender_accounts_manager() + let (prefix, mut msg_receiver, (actor, join_handle)) = create_sender_accounts_manager() .pgpool(pgpool.clone()) .network_subgraph(&mock_network_subgraph_server.uri()) .escrow_subgraph(&mock_escrow_subgraph_server.uri()) @@ -78,7 +78,7 @@ async fn sender_account_manager_layer_test(pgpool: PgPool) { vec![SENDER.1].into_iter().collect(), )) .unwrap(); - flush_messages(¬ify).await; + flush_messages(&mut msg_receiver).await; assert_while_retry!({ ActorRef::::where_is(format!( "{}:legacy:{}", diff --git a/crates/tap-agent/tests/sender_account_test.rs b/crates/tap-agent/tests/sender_account_test.rs index c2a1c379..3a13bbed 100644 --- a/crates/tap-agent/tests/sender_account_test.rs +++ b/crates/tap-agent/tests/sender_account_test.rs @@ -39,7 +39,7 @@ async fn sender_account_layer_test(pgpool: PgPool) { .await .unwrap(); - let (sender_account, notify, _, _) = create_sender_account() + let (sender_account, mut msg_receiver, _, _) = create_sender_account() .pgpool(pgpool.clone()) .max_amount_willing_to_lose_grt(TRIGGER_VALUE + 1000) .escrow_subgraph_endpoint(&mock_escrow_subgraph_server.uri()) @@ -48,14 +48,17 @@ async fn sender_account_layer_test(pgpool: PgPool) { .await; // we expect it to create a sender allocation + let allocation_ids = HashSet::from_iter([AllocationId::Legacy(ALLOCATION_ID_0)]); sender_account .cast(SenderAccountMessage::UpdateAllocationIds( - vec![AllocationId::Legacy(ALLOCATION_ID_0)] - .into_iter() - .collect(), + allocation_ids.clone(), )) .unwrap(); - notify.notified().await; + let msg = msg_receiver.recv().await.expect("Channel failed"); + assert_eq!( + msg, + SenderAccountMessage::UpdateAllocationIds(allocation_ids) + ); mock_server .register( diff --git a/crates/tap-agent/tests/tap_agent_test.rs b/crates/tap-agent/tests/tap_agent_test.rs index a58f19e4..c0913d46 100644 --- a/crates/tap-agent/tests/tap_agent_test.rs +++ b/crates/tap-agent/tests/tap_agent_test.rs @@ -4,7 +4,6 @@ use std::{ collections::{HashMap, HashSet}, str::FromStr, - sync::Arc, time::Duration, }; @@ -29,13 +28,13 @@ use test_assets::{ INDEXER_ALLOCATIONS, TAP_EIP712_DOMAIN, TAP_SENDER, TAP_SIGNER, }; use thegraph_core::alloy::primitives::Address; -use tokio::sync::{watch, Notify}; +use tokio::sync::{mpsc, watch}; use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate}; pub async fn start_agent( pgpool: PgPool, ) -> ( - Arc, + mpsc::Receiver, (ActorRef, JoinHandle<()>), ) { let escrow_subgraph_mock_server: MockServer = MockServer::start().await; @@ -108,15 +107,15 @@ pub async fn start_agent( prefix: None, }; - let actor = TestableActor::new(SenderAccountsManager); - let notify = actor.notify.clone(); - (notify, Actor::spawn(None, actor, args).await.unwrap()) + let (sender, receiver) = mpsc::channel(10); + let actor = TestableActor::new(SenderAccountsManager, sender); + (receiver, Actor::spawn(None, actor, args).await.unwrap()) } #[sqlx::test(migrations = "../../migrations")] async fn test_start_tap_agent(pgpool: PgPool) { - let (notify, (_actor_ref, _handle)) = start_agent(pgpool.clone()).await; - flush_messages(¬ify).await; + let (mut msg_receiver, (_actor_ref, _handle)) = start_agent(pgpool.clone()).await; + flush_messages(&mut msg_receiver).await; // verify if create sender account assert_while_retry!(ActorRef::::where_is(format!( diff --git a/crates/test-assets/src/lib.rs b/crates/test-assets/src/lib.rs index 6efde0d7..bb8689e4 100644 --- a/crates/test-assets/src/lib.rs +++ b/crates/test-assets/src/lib.rs @@ -20,7 +20,7 @@ use thegraph_core::{ }, deployment_id, DeploymentId, }; -use tokio::sync::Notify; +use tokio::sync::mpsc; /// Assert something is true while sleeping and retrying /// @@ -378,9 +378,9 @@ pub async fn create_signed_receipt_v2( .unwrap() } -pub async fn flush_messages(notify: &Notify) { +pub async fn flush_messages(notify: &mut mpsc::Receiver) { loop { - if tokio::time::timeout(Duration::from_millis(10), notify.notified()) + if tokio::time::timeout(Duration::from_millis(10), notify.recv()) .await .is_err() {