Skip to content

Commit

Permalink
fix(ampd): directly pass interval to queued broadcaster and ignore fi…
Browse files Browse the repository at this point in the history
…rst tick (axelarnetwork#480)

* fix: directly pass interval to queued broadcaster and pass the first unwanted tick

* refactor: address PR comments

* refactor: simplify steps to take after interval tick in run function
  • Loading branch information
maancham authored Jun 28, 2024
1 parent 65c23ee commit 918f064
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 43 deletions.
3 changes: 2 additions & 1 deletion ampd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use evm::json_rpc::EthereumClient;
use router_api::ChainName;
use thiserror::Error;
use tokio::signal::unix::{signal, SignalKind};
use tokio::time::interval;
use tokio_util::sync::CancellationToken;
use tracing::info;

Expand Down Expand Up @@ -178,7 +179,7 @@ where
broadcaster,
broadcast_cfg.batch_gas_limit,
broadcast_cfg.queue_cap,
broadcast_cfg.broadcast_interval,
interval(broadcast_cfg.broadcast_interval),
);

Self {
Expand Down
80 changes: 38 additions & 42 deletions ampd/src/queue/queued_broadcaster.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
use std::time::Duration;

use async_trait::async_trait;
use axelar_wasm_std::FnExt;
use cosmrs::{Any, Gas};
use error_stack::{self, Report, ResultExt};
use mockall::automock;
use thiserror::Error;
use tokio::select;
use tokio::sync::{mpsc, oneshot};
use tokio::time;
use tokio::time::Interval;
use tracing::info;
use tracing::warn;

Expand Down Expand Up @@ -60,8 +57,8 @@ where
broadcaster: T,
queue: MsgQueue,
batch_gas_limit: Gas,
broadcast_interval: Duration,
channel: Option<(mpsc::Sender<MsgAndResChan>, mpsc::Receiver<MsgAndResChan>)>,
broadcast_interval: Interval,
}

impl<T> QueuedBroadcaster<T>
Expand All @@ -72,14 +69,14 @@ where
broadcaster: T,
batch_gas_limit: Gas,
capacity: usize,
broadcast_interval: Duration,
broadcast_interval: Interval,
) -> Self {
Self {
broadcaster,
queue: MsgQueue::default(),
batch_gas_limit,
broadcast_interval,
channel: Some(mpsc::channel(capacity)),
broadcast_interval,
}
}

Expand All @@ -88,15 +85,17 @@ where
.channel
.take()
.expect("broadcast channel is expected to be set during initialization and must be available when running the broadcaster");
let mut interval = time::interval(self.broadcast_interval);

loop {
select! {
msg = rx.recv() => match msg {
None => break,
Some(msg_and_res_chan) => interval = self.handle_msg(interval, msg_and_res_chan).await?,
Some(msg_and_res_chan) => self.handle_msg(msg_and_res_chan).await?,
},
_ = self.broadcast_interval.tick() => {
self.broadcast_all().await?;
self.broadcast_interval.reset();
},
_ = interval.tick() => self.broadcast_all().await?.then(|_| {interval.reset()}),
}
}

Expand Down Expand Up @@ -133,11 +132,7 @@ where
}
}

async fn handle_msg(
&mut self,
mut interval: time::Interval,
msg_and_res_chan: MsgAndResChan,
) -> Result<time::Interval> {
async fn handle_msg(&mut self, msg_and_res_chan: MsgAndResChan) -> Result<()> {
let (msg, tx) = msg_and_res_chan;

match self.broadcaster.estimate_fee(vec![msg.clone()]).await {
Expand All @@ -151,7 +146,7 @@ where
"exceeded batch gas limit. gas limit can be adjusted in ampd config"
);
self.broadcast_all().await?;
interval.reset();
self.broadcast_interval.reset();
}

let message_type = msg.type_url.clone();
Expand All @@ -171,7 +166,7 @@ where
}
}

Ok(interval)
Ok(())
}
}

Expand All @@ -183,7 +178,7 @@ mod test {
use cosmrs::{bank::MsgSend, tx::Msg, AccountId};
use error_stack::Report;
use tokio::test;
use tokio::time::{sleep, Duration};
use tokio::time::{interval, Duration};

use super::{Error, QueuedBroadcaster};
use crate::broadcaster::{self, MockBroadcaster};
Expand All @@ -196,8 +191,8 @@ mod test {
.expect_estimate_fee()
.return_once(|_| Err(Report::new(broadcaster::Error::FeeEstimation)));

let queued_broadcaster =
QueuedBroadcaster::new(broadcaster, 100, 10, Duration::from_secs(5));
let broadcast_interval = interval(Duration::from_secs(5));
let queued_broadcaster = QueuedBroadcaster::new(broadcaster, 100, 10, broadcast_interval);
let client = queued_broadcaster.client();
let handle = tokio::spawn(queued_broadcaster.run());

Expand All @@ -214,7 +209,7 @@ mod test {
assert!(handle.await.unwrap().is_ok());
}

#[test]
#[test(start_paused = true)]
async fn should_not_broadcast_when_gas_limit_has_not_been_reached() {
let tx_count = 9;
let batch_gas_limit = 100;
Expand All @@ -236,17 +231,17 @@ mod test {
.expect_broadcast()
.once()
.returning(move |msgs| {
assert!(msgs.len() == tx_count);
assert_eq!(msgs.len(), tx_count);

Ok(TxResponse::default())
});

let queued_broadcaster = QueuedBroadcaster::new(
broadcaster,
batch_gas_limit,
tx_count,
Duration::from_secs(5),
);
let mut broadcast_interval = interval(Duration::from_secs(5));
// get rid of tick on startup
broadcast_interval.tick().await;

let queued_broadcaster =
QueuedBroadcaster::new(broadcaster, batch_gas_limit, tx_count, broadcast_interval);
let client = queued_broadcaster.client();
let handle = tokio::spawn(queued_broadcaster.run());

Expand All @@ -258,11 +253,10 @@ mod test {
assert!(handle.await.unwrap().is_ok());
}

#[test]
#[test(start_paused = true)]
async fn should_broadcast_when_broadcast_interval_has_been_reached() {
let tx_count = 9;
let batch_gas_limit = 100;
let broadcast_interval = Duration::from_millis(100);
let gas_limit = 10;

let mut broadcaster = MockBroadcaster::new();
Expand All @@ -281,10 +275,13 @@ mod test {
.expect_broadcast()
.once()
.returning(move |msgs| {
assert!(msgs.len() == tx_count);
assert_eq!(msgs.len(), tx_count);

Ok(TxResponse::default())
});
let mut broadcast_interval = interval(Duration::from_millis(100));
// get rid of tick on startup
broadcast_interval.tick().await;

let queued_broadcaster =
QueuedBroadcaster::new(broadcaster, batch_gas_limit, tx_count, broadcast_interval);
Expand All @@ -294,13 +291,12 @@ mod test {
for _ in 0..tx_count {
client.broadcast(dummy_msg()).await.unwrap();
}
sleep(broadcast_interval).await;
drop(client);

assert!(handle.await.unwrap().is_ok());
}

#[test]
#[test(start_paused = true)]
async fn should_broadcast_when_gas_limit_has_been_reached() {
let tx_count = 10;
let batch_gas_limit = 100;
Expand All @@ -322,31 +318,31 @@ mod test {
.expect_broadcast()
.once()
.returning(move |msgs| {
assert!(msgs.len() == tx_count - 1);
assert_eq!(msgs.len(), tx_count - 1);

Ok(TxResponse::default())
});
broadcaster
.expect_broadcast()
.once()
.returning(move |msgs| {
assert!(msgs.len() == 1);

assert_eq!(msgs.len(), 1);
Ok(TxResponse::default())
});

let queued_broadcaster = QueuedBroadcaster::new(
broadcaster,
batch_gas_limit,
tx_count,
Duration::from_secs(5),
);
let mut broadcast_interval = interval(Duration::from_secs(5));
// get rid of tick on startup
broadcast_interval.tick().await;

let queued_broadcaster =
QueuedBroadcaster::new(broadcaster, batch_gas_limit, tx_count, broadcast_interval);
let client = queued_broadcaster.client();
let handle = tokio::spawn(queued_broadcaster.run());

for _ in 0..tx_count {
client.broadcast(dummy_msg()).await.unwrap();
}

drop(client);

assert!(handle.await.unwrap().is_ok());
Expand Down

0 comments on commit 918f064

Please sign in to comment.