From 04234f2080b8cbd7f9ff8b7f2ff5ad5d70a9ca17 Mon Sep 17 00:00:00 2001 From: Darnell Andries Date: Thu, 18 Jul 2024 16:22:21 -0700 Subject: [PATCH] Add Nitriding key sync support --- Cargo.lock | 170 ++++++++++++++++++- Cargo.toml | 6 +- src/handler.rs | 145 ++++++++-------- src/instance.rs | 88 ++++++++++ src/main.rs | 43 +++-- src/result.rs | 27 +++ src/state.rs | 439 ++++++++++++++++++++++++++++-------------------- src/tests.rs | 348 +++++++++++++++++++++++++++++++------- src/util.rs | 21 +++ 9 files changed, 948 insertions(+), 339 deletions(-) create mode 100644 src/instance.rs create mode 100644 src/result.rs diff --git a/Cargo.lock b/Cargo.lock index d0c07f0..70de1cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,6 +242,7 @@ checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" dependencies = [ "funty", "radium", + "serde", "tap", "wyz", ] @@ -340,6 +341,22 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.11" @@ -421,6 +438,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + [[package]] name = "fiat-crypto" version = "0.2.5" @@ -688,6 +714,7 @@ dependencies = [ "itoa", "pin-project-lite", "tokio", + "want", ] [[package]] @@ -697,6 +724,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" dependencies = [ "bytes", + "futures-channel", "futures-util", "http 1.0.0", "http-body 1.0.0", @@ -704,6 +732,19 @@ dependencies = [ "pin-project-lite", "socket2 0.5.5", "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", ] [[package]] @@ -994,9 +1035,7 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppoprf" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d97a7a93423a5988e153e29de494b2625528196476213a4401e91140487faf1a" +version = "0.4.0" dependencies = [ "base64 0.13.1", "bincode", @@ -1147,6 +1186,43 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "reqwest" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "http 1.0.0", + "http-body 
1.0.0", + "http-body-util", + "hyper 1.1.0", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.0", + "system-configuration", + "tokio", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "rlimit" version = "0.10.1" @@ -1309,16 +1385,18 @@ dependencies = [ [[package]] name = "star-randsrv" -version = "0.2.0" +version = "0.3.0" dependencies = [ "axum", "axum-prometheus", "base64 0.22.1", + "bincode", "calendar-duration", "clap", "curve25519-dalek", "ppoprf", "rand", + "reqwest", "rlimit", "serde", "serde_json", @@ -1391,6 +1469,27 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384595c11a4e2969895cad5a8c4029115f5ab956a9e5ef4de79d11a426e5f20c" +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tap" version = "1.0.1" @@ -1478,6 +1577,21 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.38.0" @@ -1621,12 +1735,38 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "utf8parse" version = "0.2.1" @@ -1685,6 +1825,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + 
"web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.89" @@ -1812,6 +1964,16 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "winreg" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" +dependencies = [ + "cfg-if", + "windows-sys", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 4a31d51..900d496 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "star-randsrv" -version = "0.2.0" +version = "0.3.0" authors = ["Ralph Giles "] description = "STAR randomness webservice" license = "MPL-2.0" @@ -10,9 +10,11 @@ edition = "2021" axum = "0.7.5" axum-prometheus = "0.6.1" base64 = "0.22.1" +bincode = "1.3.3" calendar-duration = "1.0.0" clap = { version = "4.5.4", features = ["derive"] } -ppoprf = "0.3.1" +ppoprf = { version = "0.4.0", path = "../sta-rs/ppoprf" } +reqwest = { version = "0.12.5", default-features = false, features = ["charset", "macos-system-configuration"] } rlimit = "0.10" serde = "1.0.200" serde_json = "1.0.115" diff --git a/src/handler.rs b/src/handler.rs index 749e677..a68157a 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,14 +1,16 @@ //! STAR Randomness web service route implementation -use std::sync::RwLockReadGuard; - +use axum::body::Bytes; use axum::extract::{Json, Path, State}; use axum::http::StatusCode; use base64::prelude::{Engine as _, BASE64_STANDARD as BASE64}; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLockReadGuard; use tracing::{debug, instrument}; -use crate::state::{OPRFInstance, OPRFState}; +use crate::instance::OPRFInstance; +use crate::result::{Error, Result}; +use crate::state::OPRFState; use ppoprf::ppoprf; /// Request structure for the randomness endpoint @@ -44,7 +46,7 @@ pub struct InfoResponse { /// Timestamp of the next epoch rotation /// This should be a string in RFC 3339 format, /// e.g. 2023-03-14T16:33:05Z. - next_epoch_time: Option, + next_epoch_time: String, /// Maximum number of points accepted in a single request max_points: usize, } @@ -67,47 +69,12 @@ struct ErrorResponse { message: String, } -/// Server error conditions -/// -/// Used to generate an `ErrorResponse` from the `?` operator -/// handling requests. -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("instance '{0}' not found")] - InstanceNotFound(String), - #[error("Couldn't lock state: RwLock poisoned")] - LockFailure, - #[error("Invalid point")] - BadPoint, - #[error("Too many points for a single request")] - TooManyPoints, - #[error("Invalid epoch {0}`")] - BadEpoch(u8), - #[error("Invalid base64 encoding: {0}")] - Base64(#[from] base64::DecodeError), - #[error("PPOPRF error: {0}")] - Oprf(#[from] ppoprf::PPRFError), -} - -/// thiserror doesn't generate a `From` impl without -/// an inner value to wrap. Write one explicitly for -/// `std::sync::PoisonError` to avoid making the -/// whole `Error` struct generic. This allows us to -/// use `?` with `RwLock` methods instead of an -/// explicit `.map_err()`. 
-impl<T> From<std::sync::PoisonError<T>> for Error {
-    fn from(_: std::sync::PoisonError<T>) -> Self {
-        Error::LockFailure
-    }
-}
-
 impl axum::response::IntoResponse for Error {
     /// Construct an http response from our error type
     fn into_response(self) -> axum::response::Response {
         let code = match self {
             Error::InstanceNotFound(_) => StatusCode::NOT_FOUND,
-            // This indicates internal failure.
-            Error::LockFailure => StatusCode::INTERNAL_SERVER_ERROR,
+            Error::PPOPRFNotReady => StatusCode::SERVICE_UNAVAILABLE,
             // Other cases are the client's fault.
             _ => StatusCode::BAD_REQUEST,
         };
@@ -118,17 +85,16 @@ impl axum::response::IntoResponse for Error {
     }
 }
 
-type Result<T> = std::result::Result<T, Error>;
-
-fn get_server_from_state<'a>(
+async fn get_server_from_state<'a>(
     state: &'a OPRFState,
     instance_name: &'a str,
-) -> Result<RwLockReadGuard<'a, OPRFInstance>> {
+) -> Result<RwLockReadGuard<'a, Option<OPRFInstance>>> {
     Ok(state
         .instances
         .get(instance_name)
         .ok_or_else(|| Error::InstanceNotFound(instance_name.to_string()))?
-        .read()?)
+        .read()
+        .await)
 }
 
 /// Process PPOPRF evaluation requests
@@ -139,31 +105,36 @@ async fn randomness(
     request: RandomnessRequest,
 ) -> Result<Json<RandomnessResponse>> {
     debug!("recv: {request:?}");
-    let state = get_server_from_state(&state, &instance_name)?;
-    let epoch = request.epoch.unwrap_or(state.epoch);
-    if epoch != state.epoch {
-        return Err(Error::BadEpoch(epoch));
-    }
-    if request.points.len() > crate::MAX_POINTS {
-        return Err(Error::TooManyPoints);
-    }
-    // Don't support returning proofs until we have a more
-    // space-efficient batch proof implemented in ppoprf.
-    let mut points = Vec::with_capacity(request.points.len());
-    for base64_point in request.points {
-        let input = BASE64.decode(base64_point)?;
-        // FIXME: Point::from is fallible and needs to return a result.
-        // partial work-around: check correct length
-        if input.len() != ppoprf::COMPRESSED_POINT_LEN {
-            return Err(Error::BadPoint);
+    let state_guard = get_server_from_state(&state, &instance_name).await?;
+    match state_guard.as_ref() {
+        None => Err(Error::PPOPRFNotReady),
+        Some(state) => {
+            let epoch = request.epoch.unwrap_or(state.epoch);
+            if epoch != state.epoch {
+                return Err(Error::BadEpoch(epoch));
+            }
+            if request.points.len() > crate::MAX_POINTS {
+                return Err(Error::TooManyPoints);
+            }
+            // Don't support returning proofs until we have a more
+            // space-efficient batch proof implemented in ppoprf.
+            let mut points = Vec::with_capacity(request.points.len());
+            for base64_point in request.points {
+                let input = BASE64.decode(base64_point)?;
+                // FIXME: Point::from is fallible and needs to return a result.
+                // partial work-around: check correct length
+                if input.len() != ppoprf::COMPRESSED_POINT_LEN {
+                    return Err(Error::BadPoint);
+                }
+                let point = ppoprf::Point::from(input.as_slice());
+                let evaluation = state.server.eval(&point, epoch, false)?;
+                points.push(BASE64.encode(evaluation.output.as_bytes()));
+            }
+            let response = RandomnessResponse { points, epoch };
+            debug!("send: {response:?}");
+            Ok(Json(response))
         }
-        let point = ppoprf::Point::from(input.as_slice());
-        let evaluation = state.server.eval(&point, epoch, false)?;
-        points.push(BASE64.encode(evaluation.output.as_bytes()));
     }
-    let response = RandomnessResponse { points, epoch };
-    debug!("send: {response:?}");
-    Ok(Json(response))
 }
 
 /// Process PPOPRF evaluation requests using default instance
@@ -188,17 +159,22 @@ pub async fn specific_instance_randomness(
 #[instrument(skip(state))]
 async fn info(state: OPRFState, instance_name: String) -> Result<Json<InfoResponse>> {
     debug!("recv: info request");
-    let state = get_server_from_state(&state, &instance_name)?;
-    let public_key = state.server.get_public_key().serialize_to_bincode()?;
-    let public_key = BASE64.encode(public_key);
-    let response = InfoResponse {
-        current_epoch: state.epoch,
-        next_epoch_time: state.next_epoch_time.clone(),
-        max_points: crate::MAX_POINTS,
-        public_key,
-    };
-    debug!("send: {response:?}");
-    Ok(Json(response))
+    let state_guard = get_server_from_state(&state, &instance_name).await?;
+    match state_guard.as_ref() {
+        None => Err(Error::PPOPRFNotReady),
+        Some(state) => {
+            let public_key = state.server.get_public_key().serialize_to_bincode()?;
+            let public_key = BASE64.encode(public_key);
+            let response = InfoResponse {
+                current_epoch: state.epoch,
+                next_epoch_time: state.next_epoch_time.clone(),
+                max_points: crate::MAX_POINTS,
+                public_key,
+            };
+            debug!("send: {response:?}");
+            Ok(Json(response))
+        }
+    }
 }
 
 /// Provide PPOPRF epoch and key metadata using default instance
@@ -222,3 +198,14 @@ pub async fn list_instances(State(state): State<OPRFState>) -> Result<Json<Vec<String>>> {
     Ok(Json(state.instances.keys().cloned().collect()))
 }
+
+/// Stores PPOPRF private keys sent by nitriding, sourced from the leader enclave.
+pub async fn set_ppoprf_private_key(State(state): State<OPRFState>, body: Bytes) -> Result<()> {
+    state.set_private_keys(body).await
+}
+
+/// Generates & exports keys so that nitriding can forward the keys to worker enclaves.
+pub async fn get_ppoprf_private_key(State(state): State<OPRFState>) -> Result<Vec<u8>> {
+    state.create_missing_instances().await;
+    state.get_private_keys().await
+}
diff --git a/src/instance.rs b/src/instance.rs
new file mode 100644
index 0000000..8eea9d4
--- /dev/null
+++ b/src/instance.rs
@@ -0,0 +1,88 @@
+use calendar_duration::CalendarDuration;
+use tokio::task::JoinHandle;
+use tracing::info;
+
+use crate::result::Result;
+use crate::{util::format_rfc3339, Config};
+use ppoprf::ppoprf;
+
+/// Internal state of an OPRF instance
+pub struct OPRFInstance {
+    /// oprf implementation
+    pub server: ppoprf::Server,
+    /// currently-valid randomness epoch
+    pub epoch: u8,
+    /// Duration of each epoch
+    pub epoch_duration: CalendarDuration,
+    /// RFC 3339 timestamp of the next epoch rotation
+    pub next_epoch_time: String,
+    /// Handle for the background task associated with the instance
+    pub background_task_handle: Option<JoinHandle<()>>,
+}
+
+impl OPRFInstance {
+    /// Initialize a new OPRFInstance with the given configuration
+    pub fn new(
+        config: &Config,
+        instance_name: &str,
+        puncture_previous_epochs: bool,
+    ) -> Result<Self> {
+        let epochs_range = config.first_epoch..=config.last_epoch;
+        let mut server = ppoprf::Server::new(epochs_range.clone().collect())?;
+
+        // Get epoch duration matching the instance name.
+        let instance_index = config
+            .instance_names
+            .iter()
+            .position(|name| name == instance_name)
+            .unwrap();
+        let epoch_duration = config.epoch_durations[instance_index];
+
+        // Get base time for calculating the current epoch
+        let now = time::OffsetDateTime::now_utc()
+            .replace_millisecond(0)
+            .expect("failed to remove millisecond component from OffsetDateTime");
+        let base_time = config.epoch_base_time.unwrap_or(now);
+
+        assert!(now >= base_time, "epoch-base-time should be in the past");
+
+        // Calculate the total number of epochs elapsed since the base time
+        // and the time of the next rotation, using the epoch_duration to
+        // iterate from the base time until now.
+        let mut elapsed_epoch_count = 0;
+        let mut next_epoch_time = base_time + epoch_duration;
+        while next_epoch_time <= now {
+            next_epoch_time = next_epoch_time + epoch_duration;
+            elapsed_epoch_count += 1;
+        }
+
+        // Calculate the current epoch using modulo arithmetic.
+        let offset = elapsed_epoch_count % epochs_range.len();
+        let current_epoch = config.first_epoch + offset as u8;
+
+        // puncture_previous_epochs should be false if the keys will be
+        // explicitly set after construction, since the synced key will include
+        // punctured information.
+        if current_epoch != config.first_epoch && puncture_previous_epochs {
+            // Advance to the current epoch if base time indicates we started
+            // in the middle of a sequence.
+            info!(
+                "Puncturing obsolete epochs {}..{} to match base time",
+                config.first_epoch, current_epoch
+            );
+            for epoch in config.first_epoch..current_epoch {
+                server
+                    .puncture(epoch)
+                    .expect("Failed to puncture obsolete epoch");
+            }
+        }
+
+        Ok(OPRFInstance {
+            server,
+            epoch: current_epoch,
+            epoch_duration,
+            next_epoch_time: format_rfc3339(&next_epoch_time),
+            background_task_handle: None,
+        })
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 3e9b491..f696c87 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,8 +1,11 @@
 //! STAR Randomness web service
 
-use axum::{routing::get, routing::post, Router};
-use axum_prometheus::PrometheusMetricLayer;
+use axum::{
+    routing::{get, post, put},
+    Router,
+};
 use axum_prometheus::metrics_exporter_prometheus::PrometheusHandle;
+use axum_prometheus::PrometheusMetricLayer;
 use calendar_duration::CalendarDuration;
 use clap::Parser;
 use rlimit::Resource;
@@ -18,6 +21,8 @@ use util::{assert_unique_names, parse_timestamp};
 static GLOBAL: Jemalloc = Jemalloc;
 
 mod handler;
+mod instance;
+mod result;
 mod state;
 mod util;
@@ -62,12 +67,19 @@ pub struct Config {
     /// Enable prometheus metric reporting and listen on specified address.
     #[arg(long)]
     prometheus_listen: Option<String>,
+    /// Enable key synchronization via Nitriding to allow horizontal scaling between
+    /// server replicas.
+    #[arg(long, default_value_t = false)]
+    enclave_key_sync: bool,
+    /// Internal port of Nitriding server within enclave.
+    #[arg(long)]
+    nitriding_internal_port: Option<u16>,
 }
 
 /// Initialize an axum::Router for our web service
 /// Having this as a separate function makes testing easier.
-fn app(oprf_state: OPRFState) -> Router {
-    Router::new()
+fn app(config: &Config, oprf_state: OPRFState) -> Router {
+    let mut router = Router::new()
         // Friendly default route to identify the site
         .route("/", get(|| async { "STAR randomness server\n" }))
         // Endpoints for all instances
@@ -82,8 +94,14 @@ fn app(oprf_state: OPRFState) -> Router {
         .route("/instances", get(handler::list_instances))
         // Endpoints for default instance
         .route("/randomness", post(handler::default_instance_randomness))
-        .route("/info", get(handler::default_instance_info))
-        // Attach shared state
+        .route("/info", get(handler::default_instance_info));
+    if config.enclave_key_sync {
+        router = router
+            .route("/enclave/state", put(handler::set_ppoprf_private_key))
+            .route("/enclave/state", get(handler::get_ppoprf_private_key));
+    }
+    // Attach shared state
+    router
         .with_state(oprf_state)
         // Logging must come after active routes
         .layer(tower_http::trace::TraceLayer::new_for_http())
@@ -95,9 +113,7 @@ fn start_prometheus_server(metrics_handle: PrometheusHandle, addr: String) {
         Router::new().route("/metrics", get(|| async move { metrics_handle.render() }));
         info!("Metrics server listening on {}", addr);
         let listener = TcpListener::bind(addr).await.unwrap();
-        axum::serve(listener, metrics_app)
-            .await
-            .unwrap();
+        axum::serve(listener, metrics_app).await.unwrap();
     });
 }
@@ -150,6 +166,10 @@ async fn main() {
         config.instance_names.len() == config.epoch_durations.len(),
         "instance-name switch count must match epoch-seconds switch count"
     );
+    assert!(
+        !config.enclave_key_sync || config.nitriding_internal_port.is_some(),
+        "nitriding internal port should be defined if key sync is enabled"
+    );
 
     let metric_layer = config.prometheus_listen.as_ref().map(|listen| {
         let (layer, handle) = PrometheusMetricLayer::pair();
@@ -157,12 +177,11 @@ async fn main() {
         layer
     });
 
-    let oprf_state = OPRFServer::new(&config);
-    oprf_state.start_background_tasks(&config);
+    let oprf_state = OPRFServer::new(config.clone()).await;
 
     // Set up routes and middleware
     info!("initializing routes...");
-    let mut app = app(oprf_state);
+    let mut app = app(&config, oprf_state);
     if let Some(metric_layer) = metric_layer {
         app = app.layer(metric_layer);
     }
diff --git a/src/result.rs b/src/result.rs
new file mode 100644
index 0000000..ef2cc59
--- /dev/null
+++ b/src/result.rs
@@ -0,0 +1,27 @@
+/// Server error conditions
+///
+/// Used to generate an `ErrorResponse` from the `?` operator
+/// handling requests.
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("instance '{0}' not found")]
+    InstanceNotFound(String),
+    #[error("Invalid point")]
+    BadPoint,
+    #[error("Too many points for a single request")]
+    TooManyPoints,
+    #[error("Invalid epoch {0}")]
+    BadEpoch(u8),
+    #[error("Invalid base64 encoding: {0}")]
+    Base64(#[from] base64::DecodeError),
+    #[error("PPOPRF error: {0}")]
+    Oprf(#[from] ppoprf::PPRFError),
+    #[error("Key serialization error: {0}")]
+    KeySerialization(bincode::Error),
+    #[error("Invalid private key call")]
+    InvalidPrivateKeyCall,
+    #[error("PPOPRF not ready")]
+    PPOPRFNotReady,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
diff --git a/src/state.rs b/src/state.rs
index dd15e49..e756773 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,229 +1,312 @@
 //! STAR Randomness web service
 //! Epoch and key state and its management
 
-use calendar_duration::CalendarDuration;
+use axum::body::Bytes;
+use serde::{Deserialize, Serialize};
 use std::{
-    collections::HashMap,
-    sync::{Arc, RwLock},
+    collections::{BTreeMap, HashMap},
+    sync::Arc,
 };
-use time::{format_description::well_known::Rfc3339, OffsetDateTime};
-use tracing::{info, instrument};
+use tokio::sync::{OnceCell, RwLock};
+use tracing::{error, info, instrument};
 
-use crate::Config;
+use crate::{
+    instance::OPRFInstance,
+    result::{Error, Result},
+};
+use crate::{
+    util::{format_rfc3339, parse_timestamp, send_private_keys_to_nitriding},
+    Config,
+};
 use ppoprf::ppoprf;
 
-/// Internal state of an OPRF instance
-pub struct OPRFInstance {
-    /// oprf implementation
-    pub server: ppoprf::Server,
-    /// currently-valid randomness epoch
-    pub epoch: u8,
-    /// RFC 3339 timestamp of the next epoch rotation
-    pub next_epoch_time: Option<String>,
-}
-
-impl OPRFInstance {
-    /// Initialize a new OPRFServer state with the given configuration
-    pub fn new(config: &Config) -> Result<Self, ppoprf::PPRFError> {
-        // ppoprf wants a vector, so generate one from our range.
-        let epochs: Vec<u8> = (config.first_epoch..=config.last_epoch).collect();
-        let epoch = epochs[0];
-        let server = ppoprf::Server::new(epochs)?;
-        Ok(OPRFInstance {
-            server,
-            epoch,
-            next_epoch_time: None,
-        })
-    }
-}
-
 /// Container for OPRF instances
 pub struct OPRFServer {
     /// All OPRF instances, keyed by instance name
-    pub instances: HashMap<String, RwLock<OPRFInstance>>,
+    /// If the instance is None, then key sync is enabled
+    /// and the server is waiting for nitriding to prompt
+    /// key generation or restoration.
+    pub instances: HashMap<String, RwLock<Option<OPRFInstance>>>,
     /// The name of the default instance
     pub default_instance: String,
+    /// The config for the server
+    pub config: Config,
+    /// Will only be initialized if key sync is enabled.
+    /// If set, the state will reflect the leader/worker status
+    /// of the server.
+    pub is_leader: OnceCell<bool>,
 }
 
 /// Arc wrapper for OPRFServer
 pub type OPRFState = Arc<OPRFServer>;
 
-struct StartingEpochInfo {
-    elapsed_epoch_count: usize,
-    next_rotation: OffsetDateTime,
+/// Structure containing PPOPRF key information.
+/// Used when deserializing and setting keys.
+#[derive(Deserialize)]
+pub struct KeyInfo {
+    pub key_state: ppoprf::ServerKeyState,
+    pub epoch: u8,
 }
 
-impl StartingEpochInfo {
-    fn calculate(base_time: OffsetDateTime, instance_epoch_duration: CalendarDuration) -> Self {
-        let now = time::OffsetDateTime::now_utc();
-        let mut elapsed_epoch_count = 0;
-        let mut next_rotation = base_time + instance_epoch_duration;
-        while next_rotation < now {
-            next_rotation = next_rotation + instance_epoch_duration;
-            elapsed_epoch_count += 1;
-        }
-        Self {
-            elapsed_epoch_count,
-            next_rotation,
-        }
-    }
+/// Structure containing PPOPRF key information.
+/// Used when getting keys for serialization.
+#[derive(Serialize)]
+pub struct KeyInfoRef<'a> {
+    pub key_state: ppoprf::ServerKeyStateRef<'a>,
+    pub epoch: u8,
 }
 
+/// Map of instance names to KeyInfo.
+/// Used for deserializing and setting keys.
+pub type OPRFKeys = BTreeMap<String, KeyInfo>;
+
+/// Map of instance names to KeyInfoRef.
+/// Used when getting keys for serialization.
+pub type OPRFKeysRef<'a> = BTreeMap<String, KeyInfoRef<'a>>;
+
 impl OPRFServer {
     /// Initialize all OPRF instances with given configuration
-    pub fn new(config: &Config) -> Arc<Self> {
-        let instances = config
-            .instance_names
-            .iter()
-            .map(|instance_name| {
-                // Oblivious function state
-                info!(instance_name, "initializing OPRF state...");
-                let server = OPRFInstance::new(config).expect("Could not initialize PPOPRF state");
-                info!(instance_name, "epoch now {}", server.epoch);
-
-                (instance_name.to_string(), RwLock::new(server))
-            })
-            .collect();
-        Arc::new(OPRFServer {
+    pub async fn new(config: Config) -> Arc<Self> {
+        let mut instances = HashMap::new();
+        for instance_name in &config.instance_names {
+            // If key sync is enabled, we should hold off on creating any instances.
+            // We should wait until GET or PUT /enclave/state is called to either
+            // generate new PPOPRF keys or sync existing keys.
+            let instance = match config.enclave_key_sync {
+                true => None,
+                false => Some(
+                    OPRFInstance::new(&config, &instance_name, true)
+                        .expect("Could not initialize new PPOPRF server"),
+                ),
+            };
+            instances.insert(instance_name.to_string(), RwLock::new(instance));
+        }
+        let enclave_key_sync_enabled = config.enclave_key_sync;
+        let server = Arc::new(OPRFServer {
             instances,
             default_instance: config.instance_names.first().cloned().unwrap(),
-        })
+            config,
+            is_leader: Default::default(),
+        });
+        if !enclave_key_sync_enabled {
+            for instance_name in &server.config.instance_names {
+                server
+                    .start_background_task(instance_name.to_string())
+                    .await;
+            }
+        }
+        server
     }
 
     /// Start background tasks to keep OPRF instances up to date
-    pub fn start_background_tasks(self: &Arc<Self>, config: &Config) {
-        for (instance_name, instance_epoch_duration) in config
-            .instance_names
-            .iter()
-            .cloned()
-            .zip(config.epoch_durations.iter().cloned())
-        {
-            // Spawn a background process to advance the epoch
-            info!(instance_name, "Spawning background epoch rotation task...");
-            let background_state = self.clone();
-            let background_config = config.clone();
-            tokio::spawn(async move {
-                background_state
-                    .epoch_loop(background_config, instance_name, instance_epoch_duration)
-                    .await
-            });
-        }
+    async fn start_background_task(self: &Arc<Self>, instance_name: String) {
+        // Spawn a background process to advance the epoch
+        info!(instance_name, "Spawning background epoch rotation task...");
+        let background_state = self.clone();
+        let mut instance_guard = self.instances.get(&instance_name).unwrap().write().await;
+        instance_guard.as_mut().unwrap().background_task_handle = Some(tokio::spawn(async move {
+            background_state.epoch_loop(instance_name).await
+        }));
     }
 
     /// Advance to the next epoch on a timer
     /// This can be invoked as a background task to handle epoch
    /// advance and key rotation according to the given instance.
-    #[instrument(skip(self, config, instance_epoch_duration))]
-    async fn epoch_loop(
-        self: Arc<Self>,
-        config: Config,
-        instance_name: String,
-        instance_epoch_duration: CalendarDuration,
-    ) {
+    #[instrument(skip(self, instance_name))]
+    async fn epoch_loop(self: Arc<Self>, instance_name: String) {
         let server = self
             .instances
             .get(&instance_name)
             .expect("OPRFServer should exist for instance name");
-        let epochs = config.first_epoch..=config.last_epoch;
-
-        info!("rotating epoch every {instance_epoch_duration}");
-
-        let start_time = OffsetDateTime::now_utc();
-        // Epoch base_time comes from a config argument if given,
-        // otherwise use start_time.
- let base_time = config.epoch_base_time.unwrap_or(start_time); - info!( - "epoch base time = {}", - base_time - .format(&Rfc3339) - .expect("well-known timestamp format should always succeed") - ); - - // Calculate where we are in the epoch schedule relative to the - // base time. We may need to start in the middle of the range. - assert!( - start_time >= base_time, - "epoch-base-time should be in the past" - ); - let StartingEpochInfo { - elapsed_epoch_count, - mut next_rotation, - } = StartingEpochInfo::calculate(base_time, instance_epoch_duration); - - // The `epochs` range is `u8`, so the length can be no more - // than `u8::MAX + 1`, making it safe to truncate the modulo. - let offset = elapsed_epoch_count % epochs.len(); - let current_epoch = epochs.start() + offset as u8; - - // Advance to the current epoch if base time indicates we started - // in the middle of a sequence. - if current_epoch != config.first_epoch { + + let (mut next_epoch_time, epoch_duration) = { + let server = server.read().await; + let s = server.as_ref().unwrap(); info!( - "Puncturing obsolete epochs {}..{} to match base time", - config.first_epoch, current_epoch + "epoch now {}, next rotation = {}", + s.epoch, s.next_epoch_time ); - let mut s = server.write().expect("Failed to lock OPRFServer"); - for epoch in config.first_epoch..current_epoch { - s.server - .puncture(epoch) - .expect("Failed to puncture obsolete epoch"); - } - s.epoch = current_epoch; - info!("epoch now {}, next rotation = {next_rotation}", s.epoch); - } + ( + parse_timestamp(&s.next_epoch_time).unwrap(), + s.epoch_duration, + ) + }; - loop { - // Pre-calculate the next_epoch_time for the InfoResponse hander. - // Truncate to the nearest second. - let timestamp = next_rotation - .replace_millisecond(0) - .expect("should be able to truncate to a fixed ms") - .format(&Rfc3339) - .expect("well-known timestamp format should always succeed"); - { - // Acquire a temporary write lock which should be dropped - // before sleeping. The locking should not fail, but if it - // does we can't set the field back to None, so panic rather - // than report stale information. - let mut s = server - .write() - .expect("should be able to update next_epoch_time"); - s.next_epoch_time = Some(timestamp); - } + let epochs = self.config.first_epoch..=self.config.last_epoch; + loop { // Wait until the current epoch ends. - let sleep_duration = next_rotation - time::OffsetDateTime::now_utc(); + let sleep_duration = next_epoch_time - time::OffsetDateTime::now_utc(); // Negative durations mean we're behind. if sleep_duration.is_positive() { tokio::time::sleep(sleep_duration.unsigned_abs()).await; } - next_rotation = next_rotation + instance_epoch_duration; - - // Acquire exclusive access to the oprf state. - // Panics if this fails, since processing requests with an - // expired epoch weakens user privacy. - let mut s = server.write().expect("Failed to lock OPRFServer"); - - // Puncture the current epoch so it can no longer be used. - let old_epoch = s.epoch; - s.server - .puncture(old_epoch) - .expect("Failed to puncture current epoch"); - - // Advance to the next epoch, checking for overflow - // and out-of-range. - let new_epoch = old_epoch.checked_add(1); - if new_epoch.filter(|e| epochs.contains(e)).is_some() { - // Server is already initialized for this one. - s.epoch = new_epoch.unwrap(); - } else { - info!("Epochs exhausted! Rotating OPRF key"); - // Panics if this fails. 
Puncture should mean we can't
-                // violate privacy through further evaluations, but we
-                // still want to drop the inner state with its private key.
-                *s = OPRFInstance::new(&config).expect("Could not initialize new PPOPRF server");
+            next_epoch_time = next_epoch_time + epoch_duration;
+
+            {
+                // Acquire exclusive access to the oprf state.
+                // Panics if this fails, since processing requests with an
+                // expired epoch weakens user privacy.
+                let mut s_guard = server.write().await;
+                let s = s_guard.as_mut().unwrap();
+
+                // Puncture the current epoch so it can no longer be used.
+                let old_epoch = s.epoch;
+                s.server
+                    .puncture(old_epoch)
+                    .expect("Failed to puncture current epoch");
+
+                // Advance to the next epoch, checking for overflow
+                // and out-of-range.
+                let new_epoch = old_epoch.checked_add(1);
+                if new_epoch.filter(|e| epochs.contains(e)).is_some() {
+                    // Server is already initialized for this one.
+                    s.epoch = new_epoch.unwrap();
+                } else {
+                    if let Some(false) = self.is_leader.get() {
+                        info!("Epochs exhausted, exiting background task. New task will start after leader shares new key.");
+                        *s_guard = None;
+                        return;
+                    } else {
+                        info!("Epochs exhausted! Rotating OPRF key");
+                        // Panics if this fails. Puncture should mean we can't
+                        // violate privacy through further evaluations, but we
+                        // still want to drop the inner state with its private key.
+                        *s = OPRFInstance::new(&self.config, &instance_name, true)
+                            .expect("Could not initialize new PPOPRF server");
+                    }
+                }
+                s.next_epoch_time = format_rfc3339(&next_epoch_time);
+                info!("epoch now {}, next rotation = {next_epoch_time}", s.epoch);
+            }
+
+            if self.config.enclave_key_sync {
+                if let Some(true) = self.is_leader.get() {
+                    // Since a new OPRFInstance was created, we should sync the new key
+                    // to other enclaves if key sync is enabled.
+                    send_private_keys_to_nitriding(
+                        self.config.nitriding_internal_port.unwrap(),
+                        self.get_private_keys()
+                            .await
+                            .expect("failed to get private keys to send to nitriding"),
+                    )
+                    .await
+                    .expect("failed to send updated private keys to nitriding");
+                }
+            }
+        }
+    }
+
+    /// Stores keys sent by nitriding and sourced from the leader enclave.
+    /// If this method is called, this server will assume that it is a worker.
+    /// OPRFInstances will be created, if not created already.
+    pub async fn set_private_keys(self: &Arc<Self>, private_keys_bytes: Bytes) -> Result<()> {
+        assert!(self.config.enclave_key_sync);
+        if let Some(true) = self.is_leader.get() {
+            error!("invalid set_private_keys call on leader");
+            return Err(Error::InvalidPrivateKeyCall);
+        }
+        if !self.is_leader.initialized() {
+            self.is_leader
+                .set(false)
+                .expect("failed to set leader status");
+        }
+        let private_keys: OPRFKeys =
+            bincode::deserialize(&private_keys_bytes).map_err(|e| Error::KeySerialization(e))?;
+        for (instance_name, key_info) in private_keys {
+            if let Some(instance) = self.instances.get(&instance_name) {
+                {
+                    let mut instance_guard = instance.write().await;
+
+                    match instance_guard.as_mut() {
+                        Some(existing_instance) => {
+                            // If the key already matches the stored key, or if the
+                            // epoch from the update does not match the current epoch,
+                            // do not update the instance at this time as there is no need
+                            // to update.
+                            if existing_instance.server.get_private_key()
+                                == key_info.key_state.as_ref()
+                                || key_info.epoch != existing_instance.epoch
+                            {
+                                continue;
+                            }
+                            // Kill existing background task, since we'll create a new one
+                            // after setting the key.
+                            if let Some(handle) = existing_instance.background_task_handle.take() {
+                                handle.abort();
+                            }
+                        }
+                        None => {
+                            let new_instance =
+                                OPRFInstance::new(&self.config, &instance_name, false)
+                                    .expect("Could not initialize PPOPRF state");
+                            if key_info.epoch != new_instance.epoch {
+                                continue;
+                            }
+                            *instance_guard = Some(new_instance);
+                        }
+                    };
+
+                    instance_guard
+                        .as_mut()
+                        .unwrap()
+                        .server
+                        .set_private_key(key_info.key_state);
+                }
+
+                self.start_background_task(instance_name).await;
+            }
+        }
+        Ok(())
+    }
+
+    /// Should be called in GET /enclave/state. Will create OPRFInstances
+    /// and start the background tasks so that the leader keys can be exported
+    /// to nitriding.
+    pub async fn create_missing_instances(self: &Arc<Self>) {
+        assert!(self.config.enclave_key_sync);
+        for (instance_name, instance) in &self.instances {
+            let mut instance = instance.write().await;
+            if instance.is_none() {
+                *instance = Some(
+                    OPRFInstance::new(&self.config, instance_name, true)
+                        .expect("Could not initialize PPOPRF state"),
+                );
+                drop(instance);
+                self.start_background_task(instance_name.to_string()).await;
+            }
+        }
+    }
+
+    /// Exports keys so that nitriding can forward the keys to worker enclaves.
+    /// If this method is called, the server will assume that it is the leader.
+    pub async fn get_private_keys(self: &Arc<Self>) -> Result<Vec<u8>> {
+        assert!(self.config.enclave_key_sync);
+        if let Some(false) = self.is_leader.get() {
+            error!("invalid get_private_keys call on worker");
+            return Err(Error::InvalidPrivateKeyCall);
+        }
+        if !self.is_leader.initialized() {
+            self.is_leader
+                .set(true)
+                .expect("failed to set leader status");
+        }
+        let mut server_guards = Vec::with_capacity(self.instances.len());
+        for (instance_name, instance) in &self.instances {
+            server_guards.push((instance_name, instance.write().await))
+        }
+        let mut private_keys = OPRFKeysRef::default();
+        for (instance_name, instance) in &mut server_guards {
+            let instance = instance.as_ref().unwrap();
+
+            private_keys.insert(
+                instance_name.to_string(),
+                KeyInfoRef {
+                    epoch: instance.epoch,
+                    key_state: instance.server.get_private_key(),
+                },
+            );
+        }
+        bincode::serialize(&private_keys).map_err(|e| Error::KeySerialization(e))
+    }
 }
diff --git a/src/tests.rs b/src/tests.rs
index 620460f..9c4f822 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -1,15 +1,22 @@
 //! STAR Randomness web service tests
 
-use crate::state::OPRFServer;
+use crate::state::{KeyInfoRef, OPRFKeys, OPRFKeysRef, OPRFServer};
 use axum::body::{to_bytes, Body, Bytes};
-use axum::http::Request;
+use axum::extract::State;
 use axum::http::StatusCode;
+use axum::http::{Method, Request};
+use axum::routing::put;
+use axum::Router;
 use base64::prelude::{Engine as _, BASE64_STANDARD as BASE64};
 use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
 use rand::rngs::OsRng;
 use serde_json::{json, Value};
 use std::time::Duration;
 use time::OffsetDateTime;
+use tokio::net::TcpListener;
+use tokio::sync::mpsc;
+use tokio::task::JoinHandle;
+use tokio::time::sleep;
 use tower::Service;
 use tower::ServiceExt;
 
@@ -20,7 +27,7 @@ const NEXT_EPOCH_TIME: &str = "2023-03-22T21:46:35Z";
 /// Maximum size of a response body to consider
 /// This is an approximate bound to allow for crate::MAX_POINTS.
 /// The exact size is 32 bytes per point, plus base64 and json overhead.
-const RESPONSE_MAX: usize = 48*1024;
+const RESPONSE_MAX: usize = 48 * 1024;
 
 struct InstanceConfig {
     instance_name: String,
     epoch_duration: String,
 }
 
 /// Create an app instance for testing
-fn test_app(instance_configs: Option<Vec<InstanceConfig>>) -> crate::Router {
+async fn test_app(instance_configs: Option<Vec<InstanceConfig>>) -> crate::Router {
     let instance_configs = instance_configs.unwrap_or(vec![InstanceConfig {
         instance_name: "main".to_string(),
         epoch_duration: "1s".to_string(),
     }]);
@@ -49,27 +56,29 @@ fn test_app(instance_configs: Option<Vec<InstanceConfig>>) -> crate::Router {
             .into_iter()
            .map(|c| c.instance_name)
            .collect(),
+        enclave_key_sync: false,
+        nitriding_internal_port: None,
     };
 
     // server state
-    let oprf_state = OPRFServer::new(&config);
+    let oprf_state = OPRFServer::new(config.clone()).await;
+
     for instance in oprf_state.instances.values() {
-        instance.write().unwrap().next_epoch_time = Some(NEXT_EPOCH_TIME.to_owned());
+        instance.write().await.as_mut().unwrap().next_epoch_time = NEXT_EPOCH_TIME.to_string();
     }
-
     // attach axum routes and middleware
-    crate::app(oprf_state)
+    crate::app(&config, oprf_state)
 }
 
 /// Create a request for testing
-fn test_request(uri: &str, payload: Option<String>) -> Request<Body> {
+fn test_request(uri: &str, payload: Option<Body>, method: Option<Method>) -> Request<Body> {
     let builder = Request::builder().uri(uri);
     let request = match payload {
-        Some(json) => {
+        Some(payload) => {
             // POST payload body as json
             builder
-                .method("POST")
+                .method(method.unwrap_or(Method::POST))
                .header("Content-Type", "application/json")
-                .body(json.into())
+                .body(payload)
         }
         None => {
             // regular GET request
@@ -81,9 +90,9 @@
 #[tokio::test]
 async fn welcome() {
-    let app = test_app(None);
+    let app = test_app(None).await;
 
-    let request = test_request("/", None);
+    let request = test_request("/", None, None);
     let response = app.oneshot(request).await.unwrap();
 
     // Root should return some identifying text for friendliness.
@@ -125,18 +134,19 @@ async fn info() {
             instance_name: "alternate".to_string(),
             epoch_duration: "1s".to_string(),
         },
-    ]));
+    ]))
+    .await;
 
-    let response = app.call(test_request("/info", None)).await.unwrap();
+    let response = app.call(test_request("/info", None, None)).await.unwrap();
 
     // Info should return the correct epoch, etc.
let default_public_key = validate_info_response_and_return_public_key_b64( response.status(), - to_bytes(response.into_body(), RESPONSE_MAX).await.unwrap() + to_bytes(response.into_body(), RESPONSE_MAX).await.unwrap(), ); let response = app - .call(test_request("/instances/main/info", None)) + .call(test_request("/instances/main/info", None, None)) .await .unwrap(); let specific_default_public_key = validate_info_response_and_return_public_key_b64( @@ -146,7 +156,7 @@ async fn info() { assert_eq!(default_public_key, specific_default_public_key); let response = app - .call(test_request("/instances/alternate/info", None)) + .call(test_request("/instances/alternate/info", None, None)) .await .unwrap(); let alternate_public_key = validate_info_response_and_return_public_key_b64( @@ -156,7 +166,7 @@ async fn info() { assert_ne!(default_public_key, alternate_public_key); let response = app - .call(test_request("/instances/notexisting/info", None)) + .call(test_request("/instances/notexisting/info", None, None)) .await .unwrap(); assert_eq!(response.status(), StatusCode::NOT_FOUND); @@ -173,7 +183,8 @@ async fn randomness() { instance_name: "alternate".to_string(), epoch_duration: "1s".to_string(), }, - ])); + ])) + .await; // Create a single-point randomness request. let point = RistrettoPoint::random(&mut OsRng); @@ -183,7 +194,7 @@ async fn randomness() { .to_string(); // Submit to the hander. - let request = test_request("/randomness", Some(payload.clone())); + let request = test_request("/randomness", Some(payload.clone().into()), None); let response = app.call(request).await.unwrap(); // Verify we receive a successful, well-formed response. assert_eq!(response.status(), StatusCode::OK); @@ -193,7 +204,8 @@ async fn randomness() { let response = app .call(test_request( "/instances/main/randomness", - Some(payload.clone()), + Some(payload.clone().into()), + None, )) .await .unwrap(); @@ -204,7 +216,8 @@ async fn randomness() { let response = app .call(test_request( "/instances/alternate/randomness", - Some(payload.clone()), + Some(payload.clone().into()), + None, )) .await .unwrap(); @@ -216,7 +229,8 @@ async fn randomness() { let response = app .call(test_request( "/instances/notexisting/randomness", - Some(payload), + Some(payload.into()), + None, )) .await .unwrap(); @@ -234,8 +248,8 @@ async fn epoch() { "epoch": EPOCH }) .to_string(); - let request = test_request("/randomness", Some(payload)); - let response = test_app(None).oneshot(request).await.unwrap(); + let request = test_request("/randomness", Some(payload.into()), None); + let response = test_app(None).await.oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = to_bytes(response.into_body(), RESPONSE_MAX).await.unwrap(); verify_randomness_body(&body, points.len()); @@ -247,8 +261,8 @@ async fn epoch() { "epoch": 0 }) .to_string(); - let request = test_request("/randomness", Some(payload)); - let response = test_app(None).oneshot(request).await.unwrap(); + let request = test_request("/randomness", Some(payload.into()), None); + let response = test_app(None).await.oneshot(request).await.unwrap(); assert_eq!(response.status(), StatusCode::BAD_REQUEST); // Verify later epochs are rejected. 
@@ -257,8 +271,8 @@
     let payload = json!({
         "points": points,
        "epoch": EPOCH + 1
     })
     .to_string();
-    let request = test_request("/randomness", Some(payload));
-    let response = test_app(None).oneshot(request).await.unwrap();
+    let request = test_request("/randomness", Some(payload.into()), None);
+    let response = test_app(None).await.oneshot(request).await.unwrap();
     assert_eq!(response.status(), StatusCode::BAD_REQUEST);
 }
 
 /// Verify the info endpoint returns reasonable next_epoch_time
 /// with the correct epoch.
 #[tokio::test]
 async fn epoch_base_time() {
-    let now = OffsetDateTime::now_utc();
-    let delay = Duration::from_secs(5);
+    let now = OffsetDateTime::now_utc()
+        .replace_millisecond(0)
+        .expect("should be able to truncate to a fixed ms");
+    let delay = Duration::from_secs(11);
 
     // Config with explicit base time
     let config = crate::Config {
         listen: "127.0.0.1:8081".to_string(),
-        epoch_durations: vec!["1s".into()],
+        epoch_durations: vec!["10s".into()],
         first_epoch: EPOCH,
         last_epoch: EPOCH * 2,
         epoch_base_time: Some(now - delay),
         increase_nofile_limit: false,
         prometheus_listen: None,
         instance_names: vec!["main".to_string()],
+        enclave_key_sync: false,
+        nitriding_internal_port: None,
     };
 
-    // Verify test parameters are compatible with the
-    // expected_epoch calculation.
-    assert!(EPOCH as u64 + delay.as_secs() < EPOCH as u64 * 2);
-    let expected_epoch = EPOCH + delay.as_secs() as u8;
-    let advance = Duration::from_secs(1);
+    let expected_epoch = EPOCH + 1;
+    let advance = Duration::from_secs(9);
     let expected_time = (now + advance)
-        // Published timestamp is truncated to the second.
-        .replace_millisecond(0)
-        .expect("should be able to truncate to a fixed ms")
         .format(&time::format_description::well_known::Rfc3339)
         .expect("well-known timestamp format should always succeed");
 
     // server state
-    let oprf_state = OPRFServer::new(&config);
-    // background task to manage epoch rotation
-    oprf_state.start_background_tasks(&config);
-
-    // Wait for `epoch_loop` to update `next_epoch_time` as a proxy
-    // for completing epoch schedule initialization. Use a timeout
-    // to avoid hanging test runs.
-    let pause = Duration::from_millis(10);
-    let mut tries = 0;
-    let oprf_instance = oprf_state.instances.get("main").unwrap();
-    while oprf_instance.read().unwrap().next_epoch_time.is_none() {
-        println!("waiting for {pause:?} for initialization {tries}");
-        assert!(tries < 10, "timeout waiting for epoch_loop initialization");
-        tokio::time::sleep(pause).await;
-        tries += 1;
-    }
+    let oprf_state = OPRFServer::new(config.clone()).await;
 
     // attach axum routes and middleware
-    let app = crate::app(oprf_state);
+    let app = crate::app(&config, oprf_state);
 
-    let request = test_request("/info", None);
+    let request = test_request("/info", None, None);
     let response = app.oneshot(request).await.unwrap();
 
     // Info should return the correct epoch, etc.
@@ -365,9 +362,9 @@ fn make_points(count: usize) -> Vec<String> {
 
 /// Verify randomness response to a batch of points
 async fn verify_batch(points: &[String]) {
-    let app = test_app(None);
+    let app = test_app(None).await;
     let payload = json!({ "points": points }).to_string();
-    let request = test_request("/randomness", Some(payload));
+    let request = test_request("/randomness", Some(payload.into()), None);
     let response = app.oneshot(request).await.unwrap();
     assert_eq!(response.status(), StatusCode::OK);
     let body = to_bytes(response.into_body(), RESPONSE_MAX).await.unwrap();
@@ -396,7 +393,230 @@ async fn max_points() {
     // should be rejected.
     let points = make_points(crate::MAX_POINTS + 1);
     let payload = json!({ "points": points }).to_string();
-    let request = test_request("/randomness", Some(payload));
-    let response = test_app(None).oneshot(request).await.unwrap();
+    let request = test_request("/randomness", Some(payload.into()), None);
+    let response = test_app(None).await.oneshot(request).await.unwrap();
     assert_eq!(response.status(), StatusCode::BAD_REQUEST);
 }
+
+#[tokio::test]
+async fn test_enclave_leader() {
+    let config = crate::Config {
+        listen: "127.0.0.1:8082".to_string(),
+        epoch_durations: vec!["1s".into(), "2s".into()],
+        first_epoch: EPOCH,
+        last_epoch: EPOCH * 2,
+        epoch_base_time: None,
+        increase_nofile_limit: false,
+        prometheus_listen: None,
+        instance_names: vec!["main".to_string(), "secondary".to_string()],
+        enclave_key_sync: true,
+        nitriding_internal_port: Some(8083),
+    };
+
+    let oprf_state = OPRFServer::new(config.clone()).await;
+
+    assert!(oprf_state
+        .instances
+        .get("main")
+        .unwrap()
+        .read()
+        .await
+        .is_none());
+    assert!(oprf_state
+        .instances
+        .get("secondary")
+        .unwrap()
+        .read()
+        .await
+        .is_none());
+    assert!(!oprf_state.is_leader.initialized());
+
+    let app = crate::app(&config, oprf_state.clone());
+
+    let request = test_request("/enclave/state", None, None);
+    let response = app.oneshot(request).await.unwrap();
+
+    assert_eq!(response.status(), StatusCode::OK);
+    let body = to_bytes(response.into_body(), RESPONSE_MAX).await.unwrap();
+    assert!(!body.is_empty());
+
+    let private_keys: OPRFKeys =
+        bincode::deserialize(&body).expect("Failed to deserialize private keys");
+
+    assert_eq!(private_keys.len(), 2);
+
+    for (instance_name, key_info) in private_keys.iter() {
+        let instance = oprf_state.instances.get(instance_name).unwrap();
+        let instance_guard = instance.read().await;
+        let instance = instance_guard.as_ref().unwrap();
+
+        assert_eq!(instance.epoch, key_info.epoch);
+        assert_eq!(
+            instance.server.get_private_key(),
+            key_info.key_state.as_ref()
+        );
+    }
+
+    assert_eq!(private_keys.len(), config.instance_names.len());
+    for instance_name in config.instance_names.iter() {
+        assert!(private_keys.contains_key(instance_name));
+    }
+
+    assert_eq!(oprf_state.is_leader.get(), Some(&true));
+}
+
+#[tokio::test]
+async fn test_enclave_worker() {
+    let config = crate::Config {
+        listen: "127.0.0.1:8084".to_string(),
+        epoch_durations: vec!["1s".into(), "2s".into()],
+        first_epoch: EPOCH,
+        last_epoch: EPOCH * 2,
+        epoch_base_time: None,
+        increase_nofile_limit: false,
+        prometheus_listen: None,
+        instance_names: vec!["main".to_string(), "secondary".to_string()],
+        enclave_key_sync: true,
+        nitriding_internal_port: Some(8085),
+    };
+
+    let oprf_state = OPRFServer::new(config.clone()).await;
+
+    assert!(oprf_state
+        .instances
+        .get("main")
+        .unwrap()
+        .read()
+        .await
+        .is_none());
+    assert!(oprf_state
+        .instances
+        .get("secondary")
+        .unwrap()
+        .read()
+        .await
+        .is_none());
+    assert!(!oprf_state.is_leader.initialized());
+
+    let mock_ppoprfs = config
+        .instance_names
+        .iter()
+        .map(|instance_name| {
+            (
+                instance_name,
+                ppoprf::ppoprf::Server::new((EPOCH..EPOCH * 2).collect()).unwrap(),
+            )
+        })
+        .collect::<Vec<_>>();
+    let mock_keys = mock_ppoprfs
+        .iter()
+        .map(|(instance_name, server)| {
+            (
+                instance_name.to_string(),
+                KeyInfoRef {
+                    key_state: server.get_private_key(),
+                    epoch: EPOCH,
+                },
+            )
+        })
+        .collect::<OPRFKeysRef>();
+
+    let mock_keys_bytes = bincode::serialize(&mock_keys).expect("Failed to serialize mock keys");
+
+    let app = crate::app(&config, oprf_state.clone());
+
+    let request = test_request(
+        "/enclave/state",
+        Some(mock_keys_bytes.into()),
+        Some(Method::PUT),
+    );
+    let response = app.oneshot(request).await.unwrap();
+
+    assert_eq!(response.status(), StatusCode::OK);
+
+    assert_eq!(oprf_state.is_leader.get(), Some(&false));
+
+    for (instance_name, key_info) in mock_keys.iter() {
+        let instance = oprf_state.instances.get(instance_name).unwrap();
+        let instance_guard = instance.read().await;
+        let instance = instance_guard.as_ref().unwrap();
+
+        assert_eq!(instance.epoch, key_info.epoch);
+        assert_eq!(instance.server.get_private_key(), key_info.key_state);
+    }
+}
+
+#[tokio::test]
+async fn test_leader_updates_keys_with_nitriding() {
+    let config = crate::Config {
+        listen: "127.0.0.1:8085".to_string(),
+        epoch_durations: vec!["1s".into()],
+        first_epoch: EPOCH,
+        last_epoch: EPOCH + 2,
+        epoch_base_time: None,
+        increase_nofile_limit: false,
+        prometheus_listen: None,
+        instance_names: vec!["main".to_string()],
+        enclave_key_sync: true,
+        nitriding_internal_port: Some(8087),
+    };
+
+    let (mock_server_handle, mut body_rx) = start_mock_nitriding_server(8087).await;
+
+    let oprf_state = OPRFServer::new(config.clone()).await;
+
+    let app = crate::app(&config, oprf_state.clone());
+
+    let request = test_request("/enclave/state", None, None);
+    app.oneshot(request).await.unwrap();
+
+    assert!(body_rx.is_empty());
+
+    sleep(Duration::from_secs(1)).await;
+
+    let updated_body = body_rx.recv().await.unwrap();
+    let updated_keys: OPRFKeys = bincode::deserialize(&updated_body).unwrap();
+
+    assert_eq!(updated_keys.len(), 1);
+
+    for (instance_name, key_info) in updated_keys {
+        let instance = oprf_state.instances.get(&instance_name).unwrap();
+        let instance_guard = instance.read().await;
+        let instance = instance_guard.as_ref().unwrap();
+
+        assert_eq!(instance.epoch, key_info.epoch);
+        assert_eq!(
+            instance.server.get_private_key(),
+            key_info.key_state.as_ref()
+        );
+    }
+
+    mock_server_handle.abort();
+    mock_server_handle.await.ok();
+}
+
+async fn start_mock_nitriding_server(
+    port: u16,
+) -> (JoinHandle<()>, mpsc::UnboundedReceiver<Bytes>) {
+    let (body_tx, body_rx) = mpsc::unbounded_channel();
+
+    let app = Router::new()
+        .route("/enclave/state", put(nitriding_put_state_handler))
+        .with_state(body_tx);
+
+    let handle = tokio::spawn(async move {
+        let listener = TcpListener::bind(format!("127.0.0.1:{port}"))
+            .await
+            .unwrap();
+        axum::serve(listener, app).await.unwrap();
+    });
+
+    (handle, body_rx)
+}
+
+async fn nitriding_put_state_handler(
+    State(body_tx): State<mpsc::UnboundedSender<Bytes>>,
+    body: Bytes,
+) {
+    body_tx.send(body).unwrap();
+}
diff --git a/src/util.rs b/src/util.rs
index 1dc5eae..90b5faf 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,5 +1,6 @@
 use std::collections::HashSet;
 
+use reqwest::{Client, Method};
 use time::{format_description::well_known::Rfc3339, OffsetDateTime};
 
 /// Parse a timestamp given as a config option
@@ -15,3 +16,23 @@ pub fn assert_unique_names(instance_names: &[String]) {
         "all instance names must be unique"
     );
 }
+
+pub fn format_rfc3339(date: &OffsetDateTime) -> String {
+    date.format(&Rfc3339)
+        .expect("well-known timestamp format should always succeed")
+}
+
+pub async fn send_private_keys_to_nitriding(
+    nitriding_internal_port: u16,
+    private_key_bincode: Vec<u8>,
+) -> Result<(), reqwest::Error> {
+    let client = Client::new();
+    let request = client
+        .request(
+            Method::PUT,
+            format!("http://127.0.0.1:{nitriding_internal_port}/enclave/state"),
+        )
+        .body(private_key_bincode)
+        .build()?;
+    client.execute(request).await.map(|_| ())
+}
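-- 
Reviewer note, illustrative only (not part of the patch): a minimal sketch of
the round trip that /enclave/state performs on the key material. The leader
exports its per-instance private key state as bincode, nitriding delivers the
bytes to each worker, and the worker installs them over a freshly constructed
server. It assumes only the ppoprf 0.4.0 API already used above (`Server::new`,
`get_private_key`, `set_private_key`); the epoch range 0..=7 and the function
name are arbitrary example values.

    use ppoprf::ppoprf::Server;

    fn key_sync_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let epochs: Vec<u8> = (0u8..=7).collect();

        // Leader side (GET /enclave/state): generate a key and export it,
        // as OPRFServer::get_private_keys() does for each instance.
        let leader = Server::new(epochs.clone())?;
        let blob = bincode::serialize(&leader.get_private_key())?;

        // Worker side (PUT /enclave/state): build an instance without
        // puncturing past epochs (the synced key already carries puncture
        // state), then overwrite its key, as set_private_keys() does.
        let mut worker = Server::new(epochs)?;
        worker.set_private_key(bincode::deserialize(&blob)?);

        // Both enclaves now evaluate the PPOPRF under the same key.
        assert_eq!(worker.get_private_key(), leader.get_private_key());
        Ok(())
    }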