Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 35 additions & 1 deletion common/src/rest_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ where
}

// Handle a REST request with query parameters
pub fn handle_rest_with_query_parameter<F, Fut>(
pub fn handle_rest_with_query_parameters<F, Fut>(
context: Arc<Context<Message>>,
topic: &str,
handler: F,
Expand Down Expand Up @@ -142,3 +142,37 @@ impl<T: ToPrimitive> ToCheckedF64 for T {
self.to_f64().ok_or_else(|| anyhow!("Failed to convert {name} to f64"))
}
}

// Macros for extracting and validating REST query parameters
#[macro_export]
macro_rules! extract_strict_query_params {
($params:expr, { $($key:literal => $var:ident : Option<$type:ty>,)* }) => {
$(
let mut $var: Option<$type> = None;
)*

for (k, v) in &$params {
match k.as_str() {
$(
$key => {
$var = match v.parse::<$type>() {
Ok(val) => Some(val),
Err(_) => {
return Ok($crate::messages::RESTResponse::with_text(
400,
concat!("Invalid ", $key, " query parameter: must be a valid type"),
));
}
};
}
)*
_ => {
return Ok($crate::messages::RESTResponse::with_text(
400,
concat!("Unexpected query parameter: only allowed keys are: ", $( $key, " ", )*)
));
}
}
}
};
}
1 change: 1 addition & 0 deletions modules/drdd_state/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ tokio = { version = "1", features = ["full"] }
serde = { version = "1.0.214", features = ["derive"] }
serde_json = "1.0.132"
hex = "0.4.3"
imbl = { version = "5.0.0", features = ["serde"] }

[lib]
path = "src/drdd_state.rs"
6 changes: 3 additions & 3 deletions modules/drdd_state/src/drdd_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! Stores historical DRep delegation distributions
use acropolis_common::{
messages::{CardanoMessage, Message},
rest_helper::handle_rest_with_query_parameter,
rest_helper::handle_rest_with_query_parameters,
};
use anyhow::Result;
use caryatid_sdk::{module, Context, Module};
Expand Down Expand Up @@ -57,7 +57,7 @@ impl DRDDState {
};
match message.as_ref() {
Message::Cardano((_, CardanoMessage::DRepStakeDistribution(msg))) => {
let span = info_span!("spdd_state.handle", epoch = msg.epoch);
let span = info_span!("drdd_state.handle", epoch = msg.epoch);
async {
let mut state = state_handler.lock().await;

Expand Down Expand Up @@ -110,7 +110,7 @@ impl DRDDState {
};

// Register /drdd REST endpoint
handle_rest_with_query_parameter(context.clone(), &handle_drdd_topic, move |params| {
handle_rest_with_query_parameters(context.clone(), &handle_drdd_topic, move |params| {
let state_rest = state_opt.clone();
handle_drdd(state_rest.clone(), params)
});
Expand Down
41 changes: 12 additions & 29 deletions modules/drdd_state/src/rest.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::state::State;
use acropolis_common::{messages::RESTResponse, DRepCredential};
use acropolis_common::{extract_strict_query_params, messages::RESTResponse, DRepCredential};
use anyhow::Result;
use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
Expand Down Expand Up @@ -28,38 +28,21 @@ pub async fn handle_drdd(
}
};

let drdd_opt = if let Some(epoch_str) = params.get("epoch") {
if params.len() > 1 {
return Ok(RESTResponse::with_text(
400,
"Only 'epoch' is a valid query parameter",
));
}
extract_strict_query_params!(params, {
"epoch" => epoch: Option<u64>,
});

match epoch_str.parse::<u64>() {
Ok(epoch) => match locked.get_epoch(epoch) {
Some(drdd) => Some(drdd),
None => {
return Ok(RESTResponse::with_text(
404,
&format!("DRDD not found for epoch {}", epoch),
));
}
},
Err(_) => {
let drdd_opt = match epoch {
Some(epoch) => match locked.get_epoch(epoch) {
Some(drdd) => Some(drdd),
None => {
return Ok(RESTResponse::with_text(
400,
"Invalid epoch query parameter: must be a number",
404,
&format!("DRDD not found for epoch {}", epoch),
));
}
}
} else if params.is_empty() {
locked.get_latest()
} else {
return Ok(RESTResponse::with_text(
400,
"Unexpected query parameter: only 'epoch' is allowed",
));
},
None => locked.get_latest(),
};

if let Some(drdd) = drdd_opt {
Expand Down
17 changes: 8 additions & 9 deletions modules/drdd_state/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
use acropolis_common::DRepCredential;
use std::collections::BTreeMap;
use imbl::OrdMap;
use tracing::info;

pub struct State {
historical_distributions: BTreeMap<u64, DRepDistribution>,
historical_distributions: OrdMap<u64, DRepDistribution>,
}

#[derive(Clone)]
pub struct DRepDistribution {
pub dreps: BTreeMap<DRepCredential, u64>,
pub dreps: OrdMap<DRepCredential, u64>,
pub abstain: u64,
pub no_confidence: u64,
}

impl State {
pub fn new() -> Self {
Self {
historical_distributions: BTreeMap::new(),
historical_distributions: OrdMap::new(),
}
}

pub fn insert_drdd(&mut self, epoch: u64, drdd: DRepDistribution) {
self.historical_distributions.insert(epoch, drdd);
}

pub fn get_latest(&self) -> Option<DRepDistribution> {
self.historical_distributions.last_key_value().map(|(_, map)| map.clone())
pub fn get_latest(&self) -> Option<&DRepDistribution> {
self.historical_distributions.iter().next_back().map(|(_, map)| map)
}

pub fn get_epoch(&self, epoch: u64) -> Option<DRepDistribution> {
self.historical_distributions.get(&epoch).cloned()
pub fn get_epoch(&self, epoch: u64) -> Option<&DRepDistribution> {
self.historical_distributions.get(&epoch)
}

pub async fn tick(&self) -> anyhow::Result<()> {
Expand Down
1 change: 1 addition & 0 deletions modules/spdd_state/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ tracing = "0.1.40"
tokio = { version = "1", features = ["full"] }
serde_json = "1.0.132"
hex = "0.4.3"
imbl = { version = "5.0.0", features = ["serde"] }

[lib]
path = "src/spdd_state.rs"
41 changes: 12 additions & 29 deletions modules/spdd_state/src/rest.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::state::State;
use acropolis_common::messages::RESTResponse;
use acropolis_common::serialization::Bech32WithHrp;
use acropolis_common::{extract_strict_query_params, messages::RESTResponse};
use anyhow::Result;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
Expand All @@ -20,38 +20,21 @@ pub async fn handle_spdd(
}
};

let spdd_opt = if let Some(epoch_str) = params.get("epoch") {
if params.len() > 1 {
return Ok(RESTResponse::with_text(
400,
"Only 'epoch' is a valid query parameter",
));
}
extract_strict_query_params!(params, {
"epoch" => epoch: Option<u64>,
});

match epoch_str.parse::<u64>() {
Ok(epoch) => match locked.get_epoch(epoch) {
Some(spdd) => Some(spdd),
None => {
return Ok(RESTResponse::with_text(
404,
&format!("SPDD not found for epoch {}", epoch),
));
}
},
Err(_) => {
let spdd_opt = match epoch {
Some(epoch) => match locked.get_epoch(epoch) {
Some(spdd) => Some(spdd),
None => {
return Ok(RESTResponse::with_text(
400,
"Invalid epoch query parameter: must be a number",
404,
&format!("SPDD not found for epoch {}", epoch),
));
}
}
} else if params.is_empty() {
locked.get_latest()
} else {
return Ok(RESTResponse::with_text(
400,
"Unexpected query parameter: only 'epoch' is allowed",
));
},
None => locked.get_latest(),
};

if let Some(spdd) = spdd_opt {
Expand Down
9 changes: 5 additions & 4 deletions modules/spdd_state/src/spdd_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
//! Stores historical stake pool delegation distributions
use acropolis_common::{
messages::{CardanoMessage, Message},
rest_helper::handle_rest_with_query_parameter,
rest_helper::handle_rest_with_query_parameters,
DelegatedStake, KeyHash,
};
use anyhow::Result;
use caryatid_sdk::{module, Context, Module};
use config::Config;
use std::{collections::BTreeMap, sync::Arc};
use imbl::OrdMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{error, info, info_span, Instrument};
mod state;
Expand Down Expand Up @@ -60,7 +61,7 @@ impl SPDDState {
async {
let mut state = state_handler.lock().await;

let spdd: BTreeMap<KeyHash, DelegatedStake> =
let spdd: OrdMap<KeyHash, DelegatedStake> =
Comment thread
whankinsiv marked this conversation as resolved.
Outdated
msg.spos.iter().map(|(k, v)| (k.clone(), *v)).collect();

state.insert_spdd(msg.epoch, spdd);
Expand Down Expand Up @@ -106,7 +107,7 @@ impl SPDDState {
};

// Register /spdd REST endpoint
handle_rest_with_query_parameter(context.clone(), &handle_spdd_topic, move |params| {
handle_rest_with_query_parameters(context.clone(), &handle_spdd_topic, move |params| {
let state_rest = state_opt.clone();
handle_spdd(state_rest.clone(), params)
});
Expand Down
16 changes: 8 additions & 8 deletions modules/spdd_state/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
use acropolis_common::{DelegatedStake, KeyHash};
use std::collections::BTreeMap;
use imbl::OrdMap;
use tracing::info;

pub struct State {
historical_distributions: BTreeMap<u64, BTreeMap<KeyHash, DelegatedStake>>,
historical_distributions: OrdMap<u64, OrdMap<KeyHash, DelegatedStake>>,
}

impl State {
pub fn new() -> Self {
Self {
historical_distributions: BTreeMap::new(),
historical_distributions: OrdMap::new(),
}
}

pub fn insert_spdd(&mut self, epoch: u64, spdd: BTreeMap<KeyHash, DelegatedStake>) {
pub fn insert_spdd(&mut self, epoch: u64, spdd: OrdMap<KeyHash, DelegatedStake>) {
self.historical_distributions.insert(epoch, spdd);
}

pub fn get_latest(&self) -> Option<BTreeMap<KeyHash, DelegatedStake>> {
self.historical_distributions.last_key_value().map(|(_, map)| map.clone())
pub fn get_latest(&self) -> Option<&OrdMap<KeyHash, DelegatedStake>> {
self.historical_distributions.iter().next_back().map(|(_, map)| map)
}

pub fn get_epoch(&self, epoch: u64) -> Option<BTreeMap<KeyHash, DelegatedStake>> {
self.historical_distributions.get(&epoch).cloned()
pub fn get_epoch(&self, epoch: u64) -> Option<&OrdMap<KeyHash, DelegatedStake>> {
self.historical_distributions.get(&epoch)
}

pub async fn tick(&self) -> anyhow::Result<()> {
Expand Down