diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0af6c5b..9e9412a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,6 +50,19 @@ jobs: files: lcov.info fail_on_error: true + examples: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@v1 + with: + profile: minimal + toolchain: stable + target: x86_64-unknown-linux-gnu + override: true + - name: Run `dining_cryptographers` + run: cargo run dining_cryptographers + # This is supposed to factor out possible bugs introduced by Rust compiler changes # and dependency changes, making the results more reproducible. stable-test: diff --git a/CHANGELOG.md b/CHANGELOG.md index d0eb065..f4b6477 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - `session::tokio::run_session()` and `par_run_session()` take an additional `cancellation` argument to support external loop cancellation. ([#100]) +- `Round` now uses associated types for messages, payloads, and artifacts instead of boxed types. ([#117]) +- Protocol errors and evidence verification are now defined for each round separately. `ProtocolError` is an associated type of `Round`. ([#117]) +- `misbehave` combinator is reworked ino `extend`. It now works by defining typed extensions for a specific `Round` type. ([#117]) ### Fixed @@ -17,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#100]: https://github.com/entropyxyz/manul/pull/100 +[#117]: https://github.com/entropyxyz/manul/pull/117 [#119]: https://github.com/entropyxyz/manul/pull/119 diff --git a/Cargo.lock b/Cargo.lock index ac5403d..6448fc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -730,6 +730,15 @@ dependencies = [ "serde", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -757,6 +766,16 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + [[package]] name = "rand_core" version = "0.6.4" @@ -1378,3 +1397,23 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zerocopy" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/examples/dining_cryptographers.rs b/examples/dining_cryptographers.rs index 4493e03..3bc5cc8 100644 --- a/examples/dining_cryptographers.rs +++ b/examples/dining_cryptographers.rs @@ -55,9 +55,9 @@ use std::collections::{BTreeMap, BTreeSet}; use manul::{ dev::{run_sync, BinaryFormat, TestHasher, TestSignature, TestSigner, TestVerifier}, protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EchoRoundParticipation, - EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors, NormalBroadcast, Payload, - Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round, RoundId, TransitionInfo, + BoxedRound, CommunicationInfo, EchoRoundParticipation, EntryPoint, FinalizeOutcome, LocalError, NoArtifact, + NoMessage, NoProtocolErrors, Protocol, ProtocolMessage, ReceiveError, Round, RoundId, RoundInfo, + TransitionInfo, }, session::SessionParameters, }; @@ -73,35 +73,18 @@ use tracing::{debug, info, trace}; #[derive(Debug)] pub struct DiningCryptographersProtocol; -impl Protocol for DiningCryptographersProtocol { +impl Protocol for DiningCryptographersProtocol { // XOR/¬XOR of the two bits of each of the three diners (one is their own cointoss, the other shared with their // neighbour). type Result = (bool, bool, bool); + type SharedData = (); - type ProtocolError = NoProtocolErrors; - - fn verify_direct_message_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - Ok(()) - } - - fn verify_echo_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - Ok(()) - } - - fn verify_normal_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { - Ok(()) + fn round_info(round_id: &RoundId) -> Option> { + match round_id { + _ if round_id == 1 => Some(RoundInfo::new::()), + _ if round_id == 2 => Some(RoundInfo::new::()), + _ => None, + } } } @@ -125,6 +108,14 @@ pub struct Round2 { impl Round for Round1 { type Protocol = DiningCryptographersProtocol; + type ProtocolError = NoProtocolErrors; + + type DirectMessage = Round1Message; + type EchoBroadcast = NoMessage; + type NormalBroadcast = NoMessage; + + type Payload = bool; + type Artifact = (); // Used to define the possible paths to and from this round. This protocol is very simple, it's simply Round 1 -> // Round 2, so we can use the "linear" utility method to set this up. @@ -157,44 +148,37 @@ impl Round for Round1 { // This is called when this diner prepares to share a random bit with their neighbour. fn make_direct_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, + _rng: &mut impl CryptoRngCore, destination: &DinerId, - ) -> Result<(DirectMessage, Option), LocalError> { + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { info!( "[Round1, make_direct_message] from {:?} to {destination:?}", self.diner_id ); - let msg = Round1Message { toss: self.own_toss }; - let dm = DirectMessage::new(format, msg)?; - - Ok((dm, None)) + Ok((Round1Message { toss: self.own_toss }, ())) } // This is called when this diner receives a bit from their neighbour. fn receive_message( &self, - format: &BoxedFormat, from: &DinerId, - message: ProtocolMessage, - ) -> Result> { - let dm = message.direct_message.deserialize::(format)?; + message: ProtocolMessage, + ) -> Result> { + let dm = message.direct_message; debug!( "[Round1, receive_message] {:?} was dm'd by {from:?}: {dm:?}", self.diner_id ); - let payload = Payload::new(dm.toss); - Ok(payload) + Ok(dm.toss) } // At the end of round 1 we construct the next one, Round 2, and return a [`FinalizeOutcome::AnotherRound`]. fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - _artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { - let payloads = downcast_payloads::(payloads)?; debug!("[Round1, finalize] {:?} sees payloads: {payloads:?}", self.diner_id); let neighbour_toss = *payloads @@ -206,7 +190,7 @@ impl Round for Round1 { "[Round1, finalize] {:?} is finalizing to Round 2. Own cointoss: {}, neighbour cointoss: {neighbour_toss}", self.diner_id, self.own_toss ); - Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round2 { + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new(Round2 { diner_id: self.diner_id, own_toss: self.own_toss, neighbour_toss, @@ -217,6 +201,14 @@ impl Round for Round1 { impl Round for Round2 { type Protocol = DiningCryptographersProtocol; + type ProtocolError = NoProtocolErrors; + + type DirectMessage = NoMessage; + type EchoBroadcast = NoMessage; + type NormalBroadcast = Round2Message; + + type Payload = bool; + type Artifact = NoArtifact; // This round is the last in the protocol so we can terminate here. fn transition_info(&self) -> TransitionInfo { @@ -247,11 +239,7 @@ impl Round for Round2 { } // Implementing this method means that Round 2 will make a broadcast (without echoes). - fn make_normal_broadcast( - &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { + fn make_normal_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { debug!( "[Round2, make_normal_broadcast] {:?} broadcasts to everyone else", self.diner_id @@ -262,9 +250,7 @@ impl Round for Round2 { } else { self.own_toss ^ self.neighbour_toss }; - let msg = Round2Message { reveal }; - let bcast = NormalBroadcast::new(format, msg)?; - Ok(bcast) + Ok(Round2Message { reveal }) } // Called once for each diner as messages are delivered to it. Here we deserialize the message using the configured @@ -272,16 +258,14 @@ impl Round for Round2 { // method below. fn receive_message( &self, - format: &BoxedFormat, from: &DinerId, - message: ProtocolMessage, - ) -> Result> { + message: ProtocolMessage, + ) -> Result> { debug!("[Round2, receive_message] from {from:?} to {:?}", self.diner_id); - let bcast = message.normal_broadcast.deserialize::(format)?; + let bcast = message.normal_broadcast; trace!("[Round2, receive_message] message (deserialized bcast): {:?}", bcast); // The payload is kept and delivered in the `finalize` method. - let payload = Payload::new(bcast.reveal); - Ok(payload) + Ok(bcast.reveal) } // The `finalize` method has access to all the [`Payload`]s that were sent to this diner. This protocol does not use @@ -289,10 +273,10 @@ impl Round for Round2 { // This is the last round in the protocol, so we return a [`FinalizeOutcome::Result`] with the result of the // protocol from this participant's point of view. fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - _artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { // XOR/¬XOR the two bits of this diner, depending on whether they paid or not. let mut own_reveal = self.own_toss ^ self.neighbour_toss; @@ -301,8 +285,7 @@ impl Round for Round2 { } // Extract the payloads from the other participants so we can produce a [`Protocol::Result`]. In this case it is // a tuple of 3 booleans. - let payloads_d = downcast_payloads::(payloads)?; - let bits = payloads_d.values().cloned().collect::>(); + let bits = payloads.into_values().collect::>(); Ok(FinalizeOutcome::Result((bits[0], bits[1], own_reveal))) } } @@ -337,7 +320,7 @@ impl EntryPoint for DiningEntryPoint { // Each `EntryPoint` creates one `Session`. fn make_round( self, - rng: &mut dyn CryptoRngCore, + rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], id: &DinerId, ) -> Result, LocalError> { @@ -351,7 +334,7 @@ impl EntryPoint for DiningEntryPoint { "[DiningEntryPoint, make_round] diner {id:?} tossed: {:?} (paid? {paid})", round.own_toss ); - let round = BoxedRound::new_dynamic(round); + let round = BoxedRound::new(round); Ok(round) } } @@ -376,13 +359,6 @@ impl SessionParameters for DiningSessionParams { type WireFormat = BinaryFormat; } -// Just a utility method to help us convert a [`Payload`] to, for example, a `bool`. -fn downcast_payloads(map: BTreeMap) -> Result, LocalError> { - map.into_iter() - .map(|(id, payload)| payload.downcast::().map(|p| (id, p))) - .collect() -} - fn main() { tracing_subscriber::fmt::init(); info!("Dining Cryptographers Protocol Example"); diff --git a/examples/src/lib.rs b/examples/src/lib.rs index 37b869c..bafe0f8 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -4,4 +4,4 @@ pub mod simple; pub mod simple_chain; #[cfg(test)] -mod simple_malicious; +mod simple_test; diff --git a/examples/src/simple.rs b/examples/src/simple.rs index 8517f7b..77bf85e 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -2,10 +2,9 @@ use alloc::collections::{BTreeMap, BTreeSet}; use core::fmt::Debug; use manul::protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, - LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage, - ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId, - TransitionInfo, + BoxedRound, CommunicationInfo, EntryPoint, EvidenceError, EvidenceMessages, FinalizeOutcome, LocalError, NoMessage, + PartyId, Protocol, ProtocolError, ProtocolMessage, ReceiveError, RequiredMessageParts, RequiredMessages, Round, + RoundId, RoundInfo, TransitionInfo, }; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -14,101 +13,69 @@ use tracing::debug; #[derive(Debug)] pub struct SimpleProtocol; -#[derive(displaydoc::Display, Debug, Clone, Serialize, Deserialize)] -/// An example error. -pub enum SimpleProtocolError { - /// Invalid position in Round 1. - Round1InvalidPosition, - /// Invalid position in Round 2. - Round2InvalidPosition, -} +#[derive(displaydoc::Display, Debug, Clone, Copy, Serialize, Deserialize)] +pub(crate) struct Round1ProtocolError; -impl ProtocolError for SimpleProtocolError { - type AssociatedData = (); - - fn required_messages(&self) -> RequiredMessages { - match self { - Self::Round1InvalidPosition => RequiredMessages::new(RequiredMessageParts::direct_message(), None, None), - Self::Round2InvalidPosition => RequiredMessages::new( - RequiredMessageParts::direct_message(), - Some([(1.into(), RequiredMessageParts::direct_message())].into()), - Some([1.into()].into()), - ), - } +impl ProtocolError for Round1ProtocolError { + type Round = Round1; + fn required_messages(&self, _round_id: &RoundId) -> RequiredMessages { + RequiredMessages::new(RequiredMessageParts::direct_message(), None, None) } - - fn verify_messages_constitute_error( + fn verify_evidence( &self, - format: &BoxedFormat, - _guilty_party: &Id, + _round_id: &RoundId, + _from: &Id, _shared_randomness: &[u8], - _associated_data: &Self::AssociatedData, - message: ProtocolMessage, - _previous_messages: BTreeMap, - combined_echos: BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - match self { - SimpleProtocolError::Round1InvalidPosition => { - let _message = message.direct_message.deserialize::(format)?; - // Message contents would be checked here - Ok(()) - } - SimpleProtocolError::Round2InvalidPosition => { - let _r1_message = message.direct_message.deserialize::(format)?; - let r1_echos_serialized = combined_echos - .get(&1.into()) - .ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?; - - // Deserialize the echos - let _r1_echos = r1_echos_serialized - .iter() - .map(|(_id, echo)| echo.deserialize::(format)) - .collect::, _>>()?; - - // Message contents would be checked here - Ok(()) - } - } + _shared_data: &<>::Protocol as Protocol>::SharedData, + messages: EvidenceMessages, + ) -> std::result::Result<(), EvidenceError> { + let _message: Round1Message = messages.direct_message()?; + // Message contents would be checked here + Ok(()) + } + fn description(&self) -> std::string::String { + "Invalid position".into() } } -impl Protocol for SimpleProtocol { - type Result = u8; - type ProtocolError = SimpleProtocolError; - - fn verify_direct_message_is_invalid( - format: &BoxedFormat, - round_id: &RoundId, - message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - match round_id { - r if r == &1 => message.verify_is_not::(format), - r if r == &2 => message.verify_is_not::(format), - _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), - } +#[derive(displaydoc::Display, Debug, Clone, Copy, Serialize, Deserialize)] +pub(crate) struct Round2ProtocolError; + +impl ProtocolError for Round2ProtocolError { + type Round = Round2; + fn required_messages(&self, _round_id: &RoundId) -> RequiredMessages { + RequiredMessages::new( + RequiredMessageParts::direct_message(), + Some([(1.into(), RequiredMessageParts::direct_message())].into()), + Some([1.into()].into()), + ) } - - fn verify_echo_broadcast_is_invalid( - format: &BoxedFormat, - round_id: &RoundId, - message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - match round_id { - r if r == &1 => message.verify_is_not::(format), - r if r == &2 => message.verify_is_some(), - _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), - } + fn verify_evidence( + &self, + _round_id: &RoundId, + _from: &Id, + _shared_randomness: &[u8], + _shared_data: &<>::Protocol as Protocol>::SharedData, + messages: EvidenceMessages, + ) -> std::result::Result<(), EvidenceError> { + let _r2_message: Round2Message = messages.direct_message()?; + let _r1_echos: BTreeMap = messages.combined_echos::>(1)?; + // Message contents would be checked here + Ok(()) + } + fn description(&self) -> std::string::String { + "Invalid position".into() } +} - fn verify_normal_broadcast_is_invalid( - format: &BoxedFormat, - round_id: &RoundId, - message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { +impl Protocol for SimpleProtocol { + type Result = u8; + type SharedData = (); + fn round_info(round_id: &RoundId) -> Option> { match round_id { - r if r == &1 => message.verify_is_not::(format), - r if r == &2 => message.verify_is_some(), - _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + _ if round_id == 1 => Some(RoundInfo::new::>()), + _ if round_id == 2 => Some(RoundInfo::new::>()), + _ => None, } } } @@ -121,7 +88,7 @@ pub(crate) struct Context { } #[derive(Debug)] -pub struct Round1 { +pub(crate) struct Round1 { pub(crate) context: Context, } @@ -132,17 +99,17 @@ pub(crate) struct Round1Message { } #[derive(Serialize, Deserialize)] -struct Round1Echo { +pub(crate) struct Round1Echo { my_position: u8, } #[derive(Serialize, Deserialize)] -struct Round1Broadcast { +pub(crate) struct Round1Broadcast { x: u8, my_position: u8, } -struct Round1Payload { +pub(crate) struct Round1Payload { x: u8, } @@ -166,7 +133,7 @@ impl EntryPoint for SimpleProtocolEntryPoint { fn make_round( self, - _rng: &mut dyn CryptoRngCore, + _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], id: &Id, ) -> Result, LocalError> { @@ -182,7 +149,7 @@ impl EntryPoint for SimpleProtocolEntryPoint { let mut ids = self.all_ids; ids.remove(id); - Ok(BoxedRound::new_dynamic(Round1 { + Ok(BoxedRound::new(Round1 { context: Context { id: id.clone(), other_ids: ids, @@ -194,6 +161,7 @@ impl EntryPoint for SimpleProtocolEntryPoint { impl Round for Round1 { type Protocol = SimpleProtocol; + type ProtocolError = Round1ProtocolError; fn transition_info(&self) -> TransitionInfo { TransitionInfo::new_linear(1) @@ -203,77 +171,60 @@ impl Round for Round1 { CommunicationInfo::regular(&self.context.other_ids) } - fn make_normal_broadcast( - &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { - debug!("{:?}: making normal broadcast", self.context.id); + type NormalBroadcast = Round1Broadcast; + type EchoBroadcast = Round1Echo; + type DirectMessage = Round1Message; - let message = Round1Broadcast { + type Payload = Round1Payload; + type Artifact = (); + + fn make_normal_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { + debug!("{:?}: making normal broadcast", self.context.id); + Ok(Round1Broadcast { x: 0, my_position: self.context.ids_to_positions[&self.context.id], - }; - - NormalBroadcast::new(format, message) + }) } - fn make_echo_broadcast( - &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { + fn make_echo_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { debug!("{:?}: making echo broadcast", self.context.id); - - let message = Round1Echo { + Ok(Round1Echo { my_position: self.context.ids_to_positions[&self.context.id], - }; - - EchoBroadcast::new(format, message) + }) } fn make_direct_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, + _rng: &mut impl CryptoRngCore, destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { debug!("{:?}: making direct message for {:?}", self.context.id, destination); - let message = Round1Message { my_position: self.context.ids_to_positions[&self.context.id], your_position: self.context.ids_to_positions[destination], }; - let dm = DirectMessage::new(format, message)?; - Ok((dm, None)) + Ok((message, ())) } fn receive_message( &self, - format: &BoxedFormat, from: &Id, - message: ProtocolMessage, - ) -> Result> { + message: ProtocolMessage, + ) -> Result> { debug!("{:?}: receiving message from {:?}", self.context.id, from); - - let _echo = message.echo_broadcast.deserialize::(format)?; - let _normal = message.normal_broadcast.deserialize::(format)?; - let message = message.direct_message.deserialize::(format)?; - - debug!("{:?}: received message: {:?}", self.context.id, message); + let message = message.direct_message; if self.context.ids_to_positions[&self.context.id] != message.your_position { - return Err(ReceiveError::protocol(SimpleProtocolError::Round1InvalidPosition)); + return Err(ReceiveError::Protocol(Round1ProtocolError)); } - - Ok(Payload::new(Round1Payload { x: message.my_position })) + Ok(Round1Payload { x: message.my_position }) } fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - _artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { debug!( "{:?}: finalizing with messages from {:?}", @@ -281,14 +232,10 @@ impl Round for Round1 { payloads.keys().cloned().collect::>() ); - let typed_payloads = payloads - .into_values() - .map(|payload| payload.downcast::()) - .collect::, _>>()?; - let sum = self.context.ids_to_positions[&self.context.id] - + typed_payloads.iter().map(|payload| payload.x).sum::(); + let sum = + self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::(); - let round2 = BoxedRound::new_dynamic(Round2 { + let round2 = BoxedRound::new(Round2 { round1_sum: sum, context: self.context, }); @@ -310,6 +257,7 @@ pub(crate) struct Round2Message { impl Round for Round2 { type Protocol = SimpleProtocol; + type ProtocolError = Round2ProtocolError; fn transition_info(&self) -> TransitionInfo { TransitionInfo::new_linear_terminating(2) @@ -319,49 +267,50 @@ impl Round for Round2 { CommunicationInfo::regular(&self.context.other_ids) } + type DirectMessage = Round2Message; + type EchoBroadcast = NoMessage; + type NormalBroadcast = NoMessage; + + type Payload = Round1Payload; + type Artifact = (); + fn make_direct_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, + _rng: &mut impl CryptoRngCore, destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { debug!("{:?}: making direct message for {:?}", self.context.id, destination); let message = Round2Message { my_position: self.context.ids_to_positions[&self.context.id], your_position: self.context.ids_to_positions[destination], }; - let dm = DirectMessage::new(format, message)?; - Ok((dm, None)) + Ok((message, ())) } fn receive_message( &self, - format: &BoxedFormat, from: &Id, - message: ProtocolMessage, - ) -> Result> { + message: ProtocolMessage, + ) -> Result> { debug!("{:?}: receiving message from {:?}", self.context.id, from); - message.echo_broadcast.assert_is_none()?; - message.normal_broadcast.assert_is_none()?; - - let message = message.direct_message.deserialize::(format)?; + let message = message.direct_message; debug!("{:?}: received message: {:?}", self.context.id, message); if self.context.ids_to_positions[&self.context.id] != message.your_position { - return Err(ReceiveError::protocol(SimpleProtocolError::Round2InvalidPosition)); + return Err(ReceiveError::Protocol(Round2ProtocolError)); } - Ok(Payload::new(Round1Payload { x: message.my_position })) + Ok(Round1Payload { x: message.my_position }) } fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - _artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { debug!( "{:?}: finalizing with messages from {:?}", @@ -369,12 +318,8 @@ impl Round for Round2 { payloads.keys().cloned().collect::>() ); - let typed_payloads = payloads - .into_values() - .map(|payload| payload.downcast::()) - .collect::, _>>()?; - let sum = self.context.ids_to_positions[&self.context.id] - + typed_payloads.iter().map(|payload| payload.x).sum::(); + let sum = + self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::(); Ok(FinalizeOutcome::Result(sum + self.round1_sum)) } diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs index d38857d..f84611f 100644 --- a/examples/src/simple_chain.rs +++ b/examples/src/simple_chain.rs @@ -15,7 +15,7 @@ pub struct DoubleSimpleProtocol; impl ChainedMarker for DoubleSimpleProtocol {} -impl ChainedProtocol for DoubleSimpleProtocol { +impl ChainedProtocol for DoubleSimpleProtocol { type Protocol1 = SimpleProtocol; type Protocol2 = SimpleProtocol; } diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs deleted file mode 100644 index 90e82d1..0000000 --- a/examples/src/simple_malicious.rs +++ /dev/null @@ -1,187 +0,0 @@ -use alloc::collections::BTreeSet; -use core::fmt::Debug; - -use manul::{ - combinators::misbehave::{Misbehaving, MisbehavingEntryPoint}, - dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, - protocol::{ - Artifact, BoxedFormat, BoxedRound, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart, - }, - signature::Keypair, -}; -use rand_core::{CryptoRngCore, OsRng}; -use test_log::test; - -use crate::simple::{Round1, Round1Message, Round2, Round2Message, SimpleProtocolEntryPoint}; - -#[derive(Debug, Clone, Copy)] -enum Behavior { - SerializedGarbage, - AttributableFailure, - AttributableFailureRound2, -} - -struct MaliciousLogic; - -impl Misbehaving for MaliciousLogic { - type EntryPoint = SimpleProtocolEntryPoint; - - fn modify_direct_message( - _rng: &mut dyn CryptoRngCore, - round: &BoxedRound>::Protocol>, - behavior: &Behavior, - format: &BoxedFormat, - _destination: &Id, - direct_message: DirectMessage, - artifact: Option, - ) -> Result<(DirectMessage, Option), LocalError> { - let dm = if round.id() == 1 { - match behavior { - Behavior::SerializedGarbage => DirectMessage::new(format, [99u8])?, - Behavior::AttributableFailure => { - let round1 = round.downcast_ref::>()?; - let message = Round1Message { - my_position: round1.context.ids_to_positions[&round1.context.id], - your_position: round1.context.ids_to_positions[&round1.context.id], - }; - DirectMessage::new(format, message)? - } - _ => direct_message, - } - } else if round.id() == 2 { - match behavior { - Behavior::AttributableFailureRound2 => { - let round2 = round.downcast_ref::>()?; - let message = Round2Message { - my_position: round2.context.ids_to_positions[&round2.context.id], - your_position: round2.context.ids_to_positions[&round2.context.id], - }; - DirectMessage::new(format, message)? - } - _ => direct_message, - } - } else { - direct_message - }; - Ok((dm, artifact)) - } -} - -type MaliciousEntryPoint = MisbehavingEntryPoint; - -#[test] -fn serialized_garbage() { - let signers = (0..3).map(TestSigner::new).collect::>(); - let all_ids = signers - .iter() - .map(|signer| signer.verifying_key()) - .collect::>(); - - let entry_points = signers - .iter() - .enumerate() - .map(|(idx, signer)| { - let behavior = if idx == 0 { - Some(Behavior::SerializedGarbage) - } else { - None - }; - - let entry_point = MaliciousEntryPoint::new(SimpleProtocolEntryPoint::new(all_ids.clone()), behavior); - (*signer, entry_point) - }) - .collect::>(); - - let mut reports = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) - .unwrap() - .reports; - - let v0 = signers[0].verifying_key(); - let v1 = signers[1].verifying_key(); - let v2 = signers[2].verifying_key(); - - let _report0 = reports.remove(&v0).unwrap(); - let report1 = reports.remove(&v1).unwrap(); - let report2 = reports.remove(&v2).unwrap(); - - assert!(report1.provable_errors[&v0].verify(&()).is_ok()); - assert!(report2.provable_errors[&v0].verify(&()).is_ok()); -} - -#[test] -fn attributable_failure() { - let signers = (0..3).map(TestSigner::new).collect::>(); - let all_ids = signers - .iter() - .map(|signer| signer.verifying_key()) - .collect::>(); - - let entry_points = signers - .iter() - .enumerate() - .map(|(idx, signer)| { - let behavior = if idx == 0 { - Some(Behavior::AttributableFailure) - } else { - None - }; - - let entry_point = MaliciousEntryPoint::new(SimpleProtocolEntryPoint::new(all_ids.clone()), behavior); - (*signer, entry_point) - }) - .collect::>(); - - let mut reports = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) - .unwrap() - .reports; - - let v0 = signers[0].verifying_key(); - let v1 = signers[1].verifying_key(); - let v2 = signers[2].verifying_key(); - - let _report0 = reports.remove(&v0).unwrap(); - let report1 = reports.remove(&v1).unwrap(); - let report2 = reports.remove(&v2).unwrap(); - - assert!(report1.provable_errors[&v0].verify(&()).is_ok()); - assert!(report2.provable_errors[&v0].verify(&()).is_ok()); -} - -#[test] -fn attributable_failure_round2() { - let signers = (0..3).map(TestSigner::new).collect::>(); - let all_ids = signers - .iter() - .map(|signer| signer.verifying_key()) - .collect::>(); - - let entry_points = signers - .iter() - .enumerate() - .map(|(idx, signer)| { - let behavior = if idx == 0 { - Some(Behavior::AttributableFailureRound2) - } else { - None - }; - - let entry_point = MaliciousEntryPoint::new(SimpleProtocolEntryPoint::new(all_ids.clone()), behavior); - (*signer, entry_point) - }) - .collect::>(); - - let mut reports = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) - .unwrap() - .reports; - - let v0 = signers[0].verifying_key(); - let v1 = signers[1].verifying_key(); - let v2 = signers[2].verifying_key(); - - let _report0 = reports.remove(&v0).unwrap(); - let report1 = reports.remove(&v1).unwrap(); - let report2 = reports.remove(&v2).unwrap(); - - assert!(report1.provable_errors[&v0].verify(&()).is_ok()); - assert!(report2.provable_errors[&v0].verify(&()).is_ok()); -} diff --git a/examples/src/simple_test.rs b/examples/src/simple_test.rs new file mode 100644 index 0000000..17fd4a9 --- /dev/null +++ b/examples/src/simple_test.rs @@ -0,0 +1,138 @@ +use alloc::collections::BTreeSet; + +use manul::{ + combinators::extend::{Extendable, Extension}, + dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, + protocol::{LocalError, PartyId}, + signature::Keypair, +}; +use rand_core::{CryptoRngCore, OsRng}; +use test_log::test; + +use crate::simple::{Round1, Round1Message, Round2, Round2Message, SimpleProtocolEntryPoint}; + +#[derive(Debug, Clone)] +struct Round1InvalidDirectMessage; + +impl Extension for Round1InvalidDirectMessage +where + Id: PartyId, +{ + type Round = Round1; + + fn make_direct_message( + &self, + _rng: &mut impl CryptoRngCore, + round: &Self::Round, + _destination: &Id, + ) -> Result<(Round1Message, ()), LocalError> { + Ok(( + Round1Message { + my_position: round.context.ids_to_positions[&round.context.id], + your_position: round.context.ids_to_positions[&round.context.id], + }, + (), + )) + } +} + +#[test] +fn round1_attributable_failure() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + + let entry_points = signers + .iter() + .enumerate() + .map(|(idx, signer)| { + let entry_point = SimpleProtocolEntryPoint::new(all_ids.clone()); + let mut entry_point = Extendable::new(entry_point); + if idx == 0 { + entry_point.extend(Round1InvalidDirectMessage); + } + + (*signer, entry_point) + }) + .collect::>(); + + let mut reports = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .reports; + + let v0 = signers[0].verifying_key(); + let v1 = signers[1].verifying_key(); + let v2 = signers[2].verifying_key(); + + let _report0 = reports.remove(&v0).unwrap(); + let report1 = reports.remove(&v1).unwrap(); + let report2 = reports.remove(&v2).unwrap(); + + assert!(report1.provable_errors[&v0].verify(&()).is_ok()); + assert!(report2.provable_errors[&v0].verify(&()).is_ok()); +} + +#[derive(Debug, Clone)] +struct Round2InvalidDirectMessage; + +impl Extension for Round2InvalidDirectMessage +where + Id: PartyId, +{ + type Round = Round2; + + fn make_direct_message( + &self, + _rng: &mut impl CryptoRngCore, + round: &Self::Round, + _destination: &Id, + ) -> Result<(Round2Message, ()), LocalError> { + Ok(( + Round2Message { + my_position: round.context.ids_to_positions[&round.context.id], + your_position: round.context.ids_to_positions[&round.context.id], + }, + (), + )) + } +} + +#[test] +fn round2_attributable_failure() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + + let entry_points = signers + .iter() + .enumerate() + .map(|(idx, signer)| { + let entry_point = SimpleProtocolEntryPoint::new(all_ids.clone()); + let mut entry_point = Extendable::new(entry_point); + if idx == 0 { + entry_point.extend(Round2InvalidDirectMessage); + } + + (*signer, entry_point) + }) + .collect::>(); + + let mut reports = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .reports; + + let v0 = signers[0].verifying_key(); + let v1 = signers[1].verifying_key(); + let v2 = signers[2].verifying_key(); + + let _report0 = reports.remove(&v0).unwrap(); + let report1 = reports.remove(&v1).unwrap(); + let report2 = reports.remove(&v2).unwrap(); + + assert!(report1.provable_errors[&v0].verify(&()).is_ok()); + assert!(report2.provable_errors[&v0].verify(&()).is_ok()); +} diff --git a/manul/GUIDE.md b/manul/GUIDE.md index abb5689..47da724 100644 --- a/manul/GUIDE.md +++ b/manul/GUIDE.md @@ -18,8 +18,7 @@ pub struct MyProtocol; impl Protocol for MyProtocol { type Result = SomeResultType; - type Error = SomeErrorType; - // ... other required trait methods ... + // ... other required trait methods and types ... } ``` @@ -27,8 +26,6 @@ Key aspects: - **[`type Result`]**: This associated type defines the final output of a successful protocol execution. In the case of the Dining Cryptographers Problem, it's a tuple of bools representing each cryptographer's perspective on the outcome. - **Error Handling (Advanced)**: In more realistic protocols, the [`Protocol`] trait is where you would define error types and misbehavior reporting; when not needed, there's a [`NoProtocolErrors`] convenience type. -- **Message Validation (Advanced)**: The methods [`verify_direct_message_is_invalid`] -, [`verify_echo_broadcast_is_invalid`], and [`verify_normal_broadcast_is_invalid`] are used for validating message contents during evidence verification in more complex scenarios. ## 2. Define Your Rounds ([`Round`]) @@ -136,9 +133,6 @@ This function takes a vector of `(Signer, EntryPoint)` pairs (one for each parti [`Protocol`]: crate::protocol::Protocol [`type Result`]: crate::protocol::Protocol::Result [`NoProtocolErrors`]: crate::protocol::NoProtocolErrors -[`verify_direct_message_is_invalid`]: crate::protocol::Protocol::verify_direct_message_is_invalid -[`verify_echo_broadcast_is_invalid`]: crate::protocol::Protocol::verify_echo_broadcast_is_invalid -[`verify_normal_broadcast_is_invalid`]: crate::protocol::Protocol::verify_normal_broadcast_is_invalid] [`Round`]: crate::protocol::Round [`type Protocol`]: crate::protocol::Round::Protocol [`transition_info()`]: crate::protocol::Round::transition_info @@ -153,7 +147,7 @@ This function takes a vector of `(Signer, EntryPoint)` pairs (one for each parti [`receive_message()`]: crate::protocol::Round::receive_message [`ProtocolMessage`]: crate::protocol::ProtocolMessage [`finalize()`]: crate::protocol::Round::finalize -[`Artifact`]: crate::protocol::Artifact +[`Artifact`]: crate::protocol::Round::Artifact [`FinalizeOutcome::AnotherRound`]: crate::protocol::FinalizeOutcome::AnotherRound [`FinalizeOutcome::Result`]: crate::protocol::FinalizeOutcome::Result [`EntryPoint`]: crate::protocol::EntryPoint @@ -161,7 +155,7 @@ This function takes a vector of `(Signer, EntryPoint)` pairs (one for each parti [`entry_round_id()`]: crate::protocol::EntryPoint::entry_round_id [`make_round()`]: crate::protocol::EntryPoint::make_round [`SessionParameters`]: crate::session::SessionParameters -[`Payload`]: crate::protocol::Payload +[`Payload`]: crate::protocol::Round::Payload [`run_sync`]: crate::session::run_sync] [`BinaryFormat`]: crate::dev::BinaryFormat [`HumanReadableFormat`]: crate::dev::HumanReadableFormat diff --git a/manul/benches/async_session.rs b/manul/benches/async_session.rs index 18aa142..a2f87e6 100644 --- a/manul/benches/async_session.rs +++ b/manul/benches/async_session.rs @@ -8,9 +8,8 @@ use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ dev::{tokio::run_async, BinaryFormat, TestSessionParams, TestSigner}, protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, - FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors, NormalBroadcast, PartyId, Payload, - Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round, RoundId, TransitionInfo, + BoxedRound, CommunicationInfo, EntryPoint, FinalizeOutcome, LocalError, NoMessage, NoProtocolErrors, PartyId, + Protocol, ProtocolMessage, ReceiveError, Round, RoundId, RoundInfo, TransitionInfo, }, signature::Keypair, }; @@ -32,29 +31,9 @@ pub struct EmptyProtocol; impl Protocol for EmptyProtocol { type Result = (); - type ProtocolError = NoProtocolErrors; + type SharedData = (); - fn verify_direct_message_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_echo_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_normal_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { + fn round_info(_round_id: &RoundId) -> Option> { unimplemented!() } } @@ -65,6 +44,12 @@ struct EmptyRound { inputs: Inputs, } +#[derive(Debug)] +struct EmptyRoundWithEcho { + round_counter: u8, + inputs: Inputs, +} + #[derive(Debug, Clone)] struct Inputs { rounds_num: u8, @@ -91,20 +76,36 @@ impl EntryPoint for Inputs { fn make_round( self, - _rng: &mut dyn CryptoRngCore, + _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], _id: &Id, ) -> Result, LocalError> { - Ok(BoxedRound::new_dynamic(EmptyRound { - round_counter: 1, - inputs: self, - })) + if self.echo { + Ok(BoxedRound::new(EmptyRoundWithEcho { + round_counter: 1, + inputs: self, + })) + } else { + Ok(BoxedRound::new(EmptyRound { + round_counter: 1, + inputs: self, + })) + } } } impl Round for EmptyRound { type Protocol = EmptyProtocol; + type DirectMessage = Round1DirectMessage; + type EchoBroadcast = Round1EchoBroadcast; + type NormalBroadcast = NoMessage; + + type Artifact = Round1Artifact; + type Payload = Round1Payload; + + type ProtocolError = NoProtocolErrors; + fn transition_info(&self) -> TransitionInfo { if self.inputs.rounds_num == self.round_counter { TransitionInfo::new_linear_terminating(self.round_counter) @@ -117,64 +118,96 @@ impl Round for EmptyRound { CommunicationInfo::regular(&self.inputs.other_ids) } - fn make_echo_broadcast( + fn make_direct_message( + &self, + _rng: &mut impl CryptoRngCore, + _destination: &Id, + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { + Ok((Round1DirectMessage(do_work(self.round_counter + 2)), Round1Artifact)) + } + + fn receive_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { - if self.inputs.echo { - EchoBroadcast::new(format, Round1EchoBroadcast) + _from: &Id, + message: ProtocolMessage, + ) -> Result> { + assert!(message.direct_message.0 == do_work(self.round_counter + 2)); + Ok(Round1Payload) + } + + fn finalize( + self, + _rng: &mut impl CryptoRngCore, + _payloads: BTreeMap, + _artifacts: BTreeMap, + ) -> Result, LocalError> { + if self.round_counter == self.inputs.rounds_num { + Ok(FinalizeOutcome::Result(())) } else { - Ok(EchoBroadcast::none()) + let round = BoxedRound::new(EmptyRound { + round_counter: self.round_counter + 1, + inputs: self.inputs, + }); + Ok(FinalizeOutcome::AnotherRound(round)) + } + } +} + +impl Round for EmptyRoundWithEcho { + type Protocol = EmptyProtocol; + + type DirectMessage = Round1DirectMessage; + type EchoBroadcast = Round1EchoBroadcast; + type NormalBroadcast = NoMessage; + + type Artifact = Round1Artifact; + type Payload = Round1Payload; + + type ProtocolError = NoProtocolErrors; + + fn transition_info(&self) -> TransitionInfo { + if self.inputs.rounds_num == self.round_counter { + TransitionInfo::new_linear_terminating(self.round_counter) + } else { + TransitionInfo::new_linear(self.round_counter) } } + fn communication_info(&self) -> CommunicationInfo { + CommunicationInfo::regular(&self.inputs.other_ids) + } + + fn make_echo_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { + Ok(Round1EchoBroadcast) + } + fn make_direct_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, + _rng: &mut impl CryptoRngCore, _destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - let dm = DirectMessage::new(format, Round1DirectMessage(do_work(self.round_counter + 2)))?; - let artifact = Artifact::new(Round1Artifact); - Ok((dm, Some(artifact))) + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { + Ok((Round1DirectMessage(do_work(self.round_counter + 2)), Round1Artifact)) } fn receive_message( &self, - format: &BoxedFormat, _from: &Id, - message: ProtocolMessage, - ) -> Result> { - //std::thread::sleep(std::time::Duration::from_secs_f64(0.001)); - if self.inputs.echo { - let _echo_broadcast = message.echo_broadcast.deserialize::(format)?; - } else { - message.echo_broadcast.assert_is_none()?; - } - message.normal_broadcast.assert_is_none()?; - let direct_message = message.direct_message.deserialize::(format)?; - assert!(direct_message.0 == do_work(self.round_counter + 2)); - Ok(Payload::new(Round1Payload)) + message: ProtocolMessage, + ) -> Result> { + assert!(message.direct_message.0 == do_work(self.round_counter + 2)); + Ok(Round1Payload) } fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + _payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { - for payload in payloads.into_values() { - let _payload = payload.downcast::()?; - } - for artifact in artifacts.into_values() { - let _artifact = artifact.downcast::()?; - } - if self.round_counter == self.inputs.rounds_num { Ok(FinalizeOutcome::Result(())) } else { - let round = BoxedRound::new_dynamic(EmptyRound { + let round = BoxedRound::new(EmptyRound { round_counter: self.round_counter + 1, inputs: self.inputs, }); diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 4607fc7..405bdee 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -7,9 +7,8 @@ use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, - FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors, NormalBroadcast, PartyId, Payload, - Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round, RoundId, TransitionInfo, + BoxedRound, CommunicationInfo, EntryPoint, FinalizeOutcome, LocalError, NoMessage, NoProtocolErrors, PartyId, + Protocol, ProtocolMessage, ReceiveError, Round, RoundId, RoundInfo, TransitionInfo, }, signature::Keypair, }; @@ -21,29 +20,8 @@ pub struct EmptyProtocol; impl Protocol for EmptyProtocol { type Result = (); - type ProtocolError = NoProtocolErrors; - - fn verify_direct_message_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_echo_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_normal_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { + type SharedData = (); + fn round_info(_round_id: &RoundId) -> Option> { unimplemented!() } } @@ -54,6 +32,12 @@ struct EmptyRound { inputs: Inputs, } +#[derive(Debug)] +struct EmptyRoundWithEcho { + round_counter: u8, + inputs: Inputs, +} + #[derive(Debug, Clone)] struct Inputs { rounds_num: u8, @@ -80,20 +64,34 @@ impl EntryPoint for Inputs { fn make_round( self, - _rng: &mut dyn CryptoRngCore, + _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], _id: &Id, ) -> Result, LocalError> { - Ok(BoxedRound::new_dynamic(EmptyRound { - round_counter: 1, - inputs: self, - })) + if self.echo { + Ok(BoxedRound::new(EmptyRoundWithEcho { + round_counter: 1, + inputs: self, + })) + } else { + Ok(BoxedRound::new(EmptyRound { + round_counter: 1, + inputs: self, + })) + } } } impl Round for EmptyRound { type Protocol = EmptyProtocol; + type ProtocolError = NoProtocolErrors; + type EchoBroadcast = Round1EchoBroadcast; + type NormalBroadcast = NoMessage; + type DirectMessage = Round1DirectMessage; + type Payload = Round1Payload; + type Artifact = Round1Artifact; + fn transition_info(&self) -> TransitionInfo { if self.inputs.rounds_num == self.round_counter { TransitionInfo::new_linear_terminating(self.round_counter) @@ -106,62 +104,92 @@ impl Round for EmptyRound { CommunicationInfo::regular(&self.inputs.other_ids) } - fn make_echo_broadcast( + fn make_direct_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { - if self.inputs.echo { - EchoBroadcast::new(format, Round1EchoBroadcast) + _rng: &mut impl CryptoRngCore, + _destination: &Id, + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { + Ok((Round1DirectMessage, Round1Artifact)) + } + + fn receive_message( + &self, + _from: &Id, + _message: ProtocolMessage, + ) -> Result> { + Ok(Round1Payload) + } + + fn finalize( + self, + _rng: &mut impl CryptoRngCore, + _payloads: BTreeMap, + _artifacts: BTreeMap, + ) -> Result, LocalError> { + if self.round_counter == self.inputs.rounds_num { + Ok(FinalizeOutcome::Result(())) } else { - Ok(EchoBroadcast::none()) + let round = BoxedRound::new(EmptyRound { + round_counter: self.round_counter + 1, + inputs: self.inputs, + }); + Ok(FinalizeOutcome::AnotherRound(round)) + } + } +} + +impl Round for EmptyRoundWithEcho { + type Protocol = EmptyProtocol; + + type ProtocolError = NoProtocolErrors; + type EchoBroadcast = Round1EchoBroadcast; + type NormalBroadcast = NoMessage; + type DirectMessage = Round1DirectMessage; + type Payload = Round1Payload; + type Artifact = Round1Artifact; + + fn transition_info(&self) -> TransitionInfo { + if self.inputs.rounds_num == self.round_counter { + TransitionInfo::new_linear_terminating(self.round_counter) + } else { + TransitionInfo::new_linear(self.round_counter) } } + fn communication_info(&self) -> CommunicationInfo { + CommunicationInfo::regular(&self.inputs.other_ids) + } + + fn make_echo_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { + Ok(Round1EchoBroadcast) + } + fn make_direct_message( &self, - _rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, + _rng: &mut impl CryptoRngCore, _destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - let dm = DirectMessage::new(format, Round1DirectMessage)?; - let artifact = Artifact::new(Round1Artifact); - Ok((dm, Some(artifact))) + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { + Ok((Round1DirectMessage, Round1Artifact)) } fn receive_message( &self, - format: &BoxedFormat, _from: &Id, - message: ProtocolMessage, - ) -> Result> { - if self.inputs.echo { - let _echo_broadcast = message.echo_broadcast.deserialize::(format)?; - } else { - message.echo_broadcast.assert_is_none()?; - } - message.normal_broadcast.assert_is_none()?; - let _direct_message = message.direct_message.deserialize::(format)?; - Ok(Payload::new(Round1Payload)) + _message: ProtocolMessage, + ) -> Result> { + Ok(Round1Payload) } fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + _payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { - for payload in payloads.into_values() { - let _payload = payload.downcast::()?; - } - for artifact in artifacts.into_values() { - let _artifact = artifact.downcast::()?; - } - if self.round_counter == self.inputs.rounds_num { Ok(FinalizeOutcome::Result(())) } else { - let round = BoxedRound::new_dynamic(EmptyRound { + let round = BoxedRound::new(EmptyRound { round_counter: self.round_counter + 1, inputs: self.inputs, }); diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs index 48f012b..81bfb4a 100644 --- a/manul/src/combinators.rs +++ b/manul/src/combinators.rs @@ -1,4 +1,4 @@ //! Combinators operating on protocols. pub mod chain; -pub mod misbehave; +pub mod extend; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs index 38e7379..4204fa3 100644 --- a/manul/src/combinators/chain.rs +++ b/manul/src/combinators/chain.rs @@ -45,19 +45,21 @@ Usage: 5. Implement the marker trait [`ChainedMarker`] for this type. Same as with the protocol, this is needed to disambiguate different generic blanket implementations. -6. [`ChainedAssociatedData`] is the structure used to supply associated data +6. [`ChainedSharedData`] is the structure used to supply shared data when verifying evidence from the chained protocol. + Contains shared data for both protocols. */ -use alloc::{boxed::Box, collections::BTreeMap}; -use core::fmt::{self, Debug}; +use alloc::{boxed::Box, collections::BTreeMap, format}; +use core::fmt::Debug; use rand_core::CryptoRngCore; use crate::protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, - LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage, - ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, TransitionInfo, + Artifact, BoxedFormat, BoxedReceiveError, BoxedRng, BoxedRound, CommunicationInfo, DirectMessage, + DynProtocolMessage, DynRound, DynRoundInfo, EchoBroadcast, EntryPoint, EvidenceError, EvidenceProtocolMessage, + FinalizeOutcome, GroupNum, LocalError, NormalBroadcast, PartyId, Payload, Protocol, RoundId, RoundInfo, + SerializedProtocolError, TransitionInfo, }; /// A marker trait that is used to disambiguate blanket trait implementations for [`Protocol`] and [`EntryPoint`]. @@ -72,175 +74,177 @@ pub trait ChainedProtocol: 'static + Debug { type Protocol2: Protocol; } -/// The protocol error type for the chained protocol. -#[derive_where::derive_where(Debug, Clone, Serialize, Deserialize)] -pub enum ChainedProtocolError +/// Associated data for verification of malicious behavior evidence in the chained protocol. +#[derive_where::derive_where(Debug)] +pub struct ChainedSharedData where C: ChainedProtocol, { - /// A protocol error from the first protocol. - Protocol1(>::ProtocolError), - /// A protocol error from the second protocol. - Protocol2(>::ProtocolError), + /// Associated data for the errors in the first protocol. + pub protocol1: >::SharedData, + /// Associated data for the errors in the second protocol. + pub protocol2: >::SharedData, } -impl fmt::Display for ChainedProtocolError -where - C: ChainedProtocol, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - Self::Protocol1(err) => write!(f, "Protocol 1: {err}"), - Self::Protocol2(err) => write!(f, "Protocol 2: {err}"), - } +fn ungroup(expected_group: GroupNum, round_id: &RoundId) -> Result { + let (group, round_id) = round_id.split_group()?; + if group != expected_group { + return Err(LocalError::new(format!( + "Expected round ID from group {expected_group}, got round {round_id} from a different group: {group}" + ))); } + Ok(round_id) } -impl ChainedProtocolError -where - C: ChainedProtocol, -{ - fn from_protocol1(err: >::ProtocolError) -> Self { - Self::Protocol1(err) - } - - fn from_protocol2(err: >::ProtocolError) -> Self { - Self::Protocol2(err) - } +fn ungroup_map(expected_group: GroupNum, grouped: BTreeMap) -> Result, LocalError> { + grouped + .into_iter() + .map(|(round_id, value)| ungroup(expected_group, &round_id).map(|round_id| (round_id, value))) + .collect() } -/// Associated data for verification of malicious behavior evidence in the chained protocol. #[derive_where::derive_where(Debug)] -pub struct ChainedAssociatedData -where - C: ChainedProtocol, -{ - /// Associated data for the errors in the first protocol. - pub protocol1: <>::ProtocolError as ProtocolError>::AssociatedData, - /// Associated data for the errors in the second protocol. - pub protocol2: <>::ProtocolError as ProtocolError>::AssociatedData, -} +struct RoundInfo1 + ChainedMarker>(RoundInfo); -impl ProtocolError for ChainedProtocolError +impl DynRoundInfo for RoundInfo1 where - C: ChainedProtocol, + P: ChainedProtocol + ChainedMarker, { - type AssociatedData = ChainedAssociatedData; + type Protocol = P; - fn required_messages(&self) -> RequiredMessages { - let (protocol_num, required_messages) = match self { - Self::Protocol1(err) => (1, err.required_messages()), - Self::Protocol2(err) => (2, err.required_messages()), - }; + fn verify_direct_message_is_invalid( + &self, + format: &BoxedFormat, + message: &DirectMessage, + ) -> Result<(), EvidenceError> { + self.0.as_ref().verify_direct_message_is_invalid(format, message) + } - let previous_rounds = required_messages.previous_rounds.map(|previous_rounds| { - previous_rounds - .into_iter() - .map(|(round_id, required)| (round_id.group_under(protocol_num), required)) - .collect() - }); - - let combined_echos = required_messages.combined_echos.map(|combined_echos| { - combined_echos - .into_iter() - .map(|round_id| round_id.group_under(protocol_num)) - .collect() - }); - - RequiredMessages { - this_round: required_messages.this_round, - previous_rounds, - combined_echos, - } + fn verify_echo_broadcast_is_invalid( + &self, + format: &BoxedFormat, + message: &EchoBroadcast, + ) -> Result<(), EvidenceError> { + self.0.as_ref().verify_echo_broadcast_is_invalid(format, message) + } + + fn verify_normal_broadcast_is_invalid( + &self, + format: &BoxedFormat, + message: &NormalBroadcast, + ) -> Result<(), EvidenceError> { + self.0.as_ref().verify_normal_broadcast_is_invalid(format, message) } - #[allow(clippy::too_many_arguments)] - fn verify_messages_constitute_error( + fn verify_evidence( &self, + round_id: &RoundId, format: &BoxedFormat, + error: &SerializedProtocolError, guilty_party: &Id, shared_randomness: &[u8], - associated_data: &Self::AssociatedData, - message: ProtocolMessage, - previous_messages: BTreeMap, + shared_data: &>::SharedData, + message: EvidenceProtocolMessage, + previous_messages: BTreeMap, combined_echos: BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - let previous_messages = previous_messages - .into_iter() - .map(|(round_id, message)| round_id.split_group().map(|(_group_num, round_id)| (round_id, message))) - .collect::, _>>()?; - let combined_echos = combined_echos - .into_iter() - .map(|(round_id, message)| round_id.split_group().map(|(_group_num, round_id)| (round_id, message))) - .collect::, _>>()?; - - match self { - Self::Protocol1(err) => err.verify_messages_constitute_error( - format, - guilty_party, - shared_randomness, - &associated_data.protocol1, - message, - previous_messages, - combined_echos, - ), - Self::Protocol2(err) => err.verify_messages_constitute_error( - format, - guilty_party, - shared_randomness, - &associated_data.protocol2, - message, - previous_messages, - combined_echos, - ), - } + ) -> Result<(), EvidenceError> { + let round_id = ungroup(1, round_id)?; + let previous_messages = ungroup_map(1, previous_messages)?; + let combined_echos = ungroup_map(1, combined_echos)?; + self.0.as_ref().verify_evidence( + &round_id, + format, + error, + guilty_party, + shared_randomness, + &shared_data.protocol1, + message, + previous_messages, + combined_echos, + ) } } -impl Protocol for C +#[derive_where::derive_where(Debug)] +struct RoundInfo2 + ChainedMarker>(RoundInfo); + +impl DynRoundInfo for RoundInfo2 where - Id: 'static, - C: ChainedProtocol + ChainedMarker, + P: ChainedProtocol + ChainedMarker, { - type Result = >::Result; - type ProtocolError = ChainedProtocolError; + type Protocol = P; fn verify_direct_message_is_invalid( + &self, format: &BoxedFormat, - round_id: &RoundId, message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - let (group, round_id) = round_id.split_group()?; - if group == 1 { - C::Protocol1::verify_direct_message_is_invalid(format, &round_id, message) - } else { - C::Protocol2::verify_direct_message_is_invalid(format, &round_id, message) - } + ) -> Result<(), EvidenceError> { + self.0.as_ref().verify_direct_message_is_invalid(format, message) } fn verify_echo_broadcast_is_invalid( + &self, format: &BoxedFormat, - round_id: &RoundId, message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - let (group, round_id) = round_id.split_group()?; - if group == 1 { - C::Protocol1::verify_echo_broadcast_is_invalid(format, &round_id, message) - } else { - C::Protocol2::verify_echo_broadcast_is_invalid(format, &round_id, message) - } + ) -> Result<(), EvidenceError> { + self.0.as_ref().verify_echo_broadcast_is_invalid(format, message) } fn verify_normal_broadcast_is_invalid( + &self, format: &BoxedFormat, - round_id: &RoundId, message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { - let (group, round_id) = round_id.split_group()?; + ) -> Result<(), EvidenceError> { + self.0.as_ref().verify_normal_broadcast_is_invalid(format, message) + } + + fn verify_evidence( + &self, + round_id: &RoundId, + format: &BoxedFormat, + error: &SerializedProtocolError, + guilty_party: &Id, + shared_randomness: &[u8], + shared_data: &>::SharedData, + message: EvidenceProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, + ) -> Result<(), EvidenceError> { + let round_id = ungroup(2, round_id)?; + let previous_messages = ungroup_map(2, previous_messages)?; + let combined_echos = ungroup_map(2, combined_echos)?; + self.0.as_ref().verify_evidence( + &round_id, + format, + error, + guilty_party, + shared_randomness, + &shared_data.protocol2, + message, + previous_messages, + combined_echos, + ) + } +} + +impl Protocol for C +where + Id: 'static, + C: ChainedProtocol + ChainedMarker, +{ + type Result = >::Result; + type SharedData = ChainedSharedData; + + fn round_info(round_id: &RoundId) -> Option> { + let (group, round_id) = round_id.split_group().ok()?; if group == 1 { - C::Protocol1::verify_normal_broadcast_is_invalid(format, &round_id, message) + let round_info = C::Protocol1::round_info(&round_id)?; + Some(RoundInfo::new_obj(RoundInfo1(round_info))) + } else if group == 2 { + let round_info = C::Protocol2::round_info(&round_id)?; + Some(RoundInfo::new_obj(RoundInfo2(round_info))) } else { - C::Protocol2::verify_normal_broadcast_is_invalid(format, &round_id, message) + None } } } @@ -288,7 +292,7 @@ where fn make_round( self, - rng: &mut dyn CryptoRngCore, + rng: &mut impl CryptoRngCore, shared_randomness: &[u8], id: &Id, ) -> Result, LocalError> { @@ -330,7 +334,7 @@ where Protocol2(BoxedRound>::Protocol2>), } -impl Round for ChainedRound +impl DynRound for ChainedRound where Id: PartyId, T: ChainedJoin, @@ -363,7 +367,7 @@ where rng: &mut dyn CryptoRngCore, format: &BoxedFormat, destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { + ) -> Result<(DirectMessage, Artifact), LocalError> { match &self.state { ChainState::Protocol1 { round, .. } => round.as_ref().make_direct_message(rng, format, destination), ChainState::Protocol2(round) => round.as_ref().make_direct_message(rng, format, destination), @@ -396,17 +400,17 @@ where &self, format: &BoxedFormat, from: &Id, - message: ProtocolMessage, - ) -> Result> { + message: DynProtocolMessage, + ) -> Result> { match &self.state { - ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message(format, from, message) { - Ok(payload) => Ok(payload), - Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), - }, - ChainState::Protocol2(round) => match round.as_ref().receive_message(format, from, message) { - Ok(payload) => Ok(payload), - Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), - }, + ChainState::Protocol1 { round, .. } => round + .as_ref() + .receive_message(format, from, message) + .map_err(|error| error.group_under(1)), + ChainState::Protocol2(round) => round + .as_ref() + .receive_message(format, from, message) + .map_err(|error| error.group_under(2)), } } @@ -422,10 +426,10 @@ where round, transition, shared_randomness, - } => match round.into_boxed().finalize(rng, payloads, artifacts)? { + } => match round.into_inner().finalize(rng, payloads, artifacts)? { FinalizeOutcome::Result(result) => { let entry_point2 = transition.make_entry_point2(result); - let round = entry_point2.make_round(rng, &shared_randomness, &id)?; + let round = entry_point2.make_round(&mut BoxedRng(rng), &shared_randomness, &id)?; let chained_round = ChainedRound:: { state: ChainState::Protocol2(round), }; @@ -443,7 +447,7 @@ where Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(chained_round))) } }, - ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts)? { + ChainState::Protocol2(round) => match round.into_inner().finalize(rng, payloads, artifacts)? { FinalizeOutcome::Result(result) => Ok(FinalizeOutcome::Result(result)), FinalizeOutcome::AnotherRound(round) => { let chained_round = ChainedRound:: { diff --git a/manul/src/combinators/extend.rs b/manul/src/combinators/extend.rs new file mode 100644 index 0000000..d10f6f6 --- /dev/null +++ b/manul/src/combinators/extend.rs @@ -0,0 +1,385 @@ +/*! +This module contains tools to extend or override methods of a [`Round`] in a protocol. + +Usage: + +1. Implement [`Extension`] for an object (which may be an empty struct or contain some data). + +2. Wrap an [`EntryPoint`] of a protocol in an [`Extendable`]. + +3. Add extensions to it via [`Extendable::extend`] or [`Extendable::with_extension`]. + +4. Use the [`Extendable`] object as the new entry point. + The extension will be activated for every round whose type is equal to [`Extension::Round`]. +*/ + +use alloc::{boxed::Box, collections::BTreeMap, string::String}; +use core::{any::TypeId, fmt::Debug}; + +use rand_core::CryptoRngCore; + +use crate::protocol::{ + Artifact, BoxedFormat, BoxedReceiveError, BoxedRound, BoxedTypedRound, CommunicationInfo, DirectMessage, + DynProtocolMessage, DynRound, EchoBroadcast, EntryPoint, EvidenceError, EvidenceMessages, FinalizeOutcome, + LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage, ReceiveError, + RequiredMessages, Round, RoundId, TransitionInfo, +}; + +/// An extension to a round, allowing one to extend or override its methods. +pub trait Extension: 'static + Debug + Send + Sync + Clone { + /// The round type to which the extension is applied. + type Round: Round; + + /// Called instead of [`Round::make_normal_broadcast`]. + /// + /// The default implementation calls [`Round::make_normal_broadcast`]. + fn make_normal_broadcast( + &self, + rng: &mut impl CryptoRngCore, + round: &Self::Round, + ) -> Result<>::NormalBroadcast, LocalError> { + round.make_normal_broadcast(rng) + } + + /// Called instead of [`Round::make_echo_broadcast`]. + /// + /// The default implementation calls [`Round::make_echo_broadcast`]. + fn make_echo_broadcast( + &self, + rng: &mut impl CryptoRngCore, + round: &Self::Round, + ) -> Result<>::EchoBroadcast, LocalError> { + round.make_echo_broadcast(rng) + } + + /// Called instead of [`Round::make_direct_message`]. + /// + /// The default implementation calls [`Round::make_direct_message`]. + #[allow(clippy::type_complexity)] + fn make_direct_message( + &self, + rng: &mut impl CryptoRngCore, + round: &Self::Round, + destination: &Id, + ) -> Result< + ( + >::DirectMessage, + >::Artifact, + ), + LocalError, + > { + round.make_direct_message(rng, destination) + } + + /// Called instead of [`Round::finalize`]. + /// + /// The default implementation calls [`Round::finalize`]. + fn extend_finalize( + &self, + rng: &mut impl CryptoRngCore, + round: Self::Round, + payloads: BTreeMap>::Payload>, + artifacts: BTreeMap>::Artifact>, + ) -> Result>::Protocol>, LocalError> { + round.finalize(rng, payloads, artifacts) + } +} + +#[derive_where::derive_where(Debug, Clone, Serialize, Deserialize)] +struct ExtendedProtocolError>(>::ProtocolError); + +impl> ProtocolError for ExtendedProtocolError { + type Round = ExtendedRound; + fn required_messages(&self, round_id: &RoundId) -> RequiredMessages { + self.0.required_messages(round_id) + } + fn verify_evidence( + &self, + round_id: &RoundId, + from: &Id, + shared_randomness: &[u8], + shared_data: &<>::Protocol as Protocol>::SharedData, + messages: EvidenceMessages<'_, Id, Self::Round>, + ) -> Result<(), EvidenceError> { + let messages = messages.into_round::(); + self.0 + .verify_evidence(round_id, from, shared_randomness, shared_data, messages) + } + fn description(&self) -> String { + self.0.description() + } +} + +#[allow(clippy::type_complexity)] +#[derive(Debug)] +struct ExtendedRound> { + round: Ext::Round, + /// The extension active for the current round type. + extension: Ext, + /// A mapping between round types and extensions. + /// During protocol execution, this map is checked and if the current round type has an extension defined, + /// use it to extend the round. Otherwise fall through to the "normal" round. + /// + /// It is saved here since we have no access to external context from a round, + /// so we have to pass this mapping from round to round during finalization. + extensions: BTreeMap>::Protocol>>>, +} + +impl Round for ExtendedRound +where + Id: PartyId, + Ext: Extension, +{ + type Protocol = >::Protocol; + type ProtocolError = ExtendedProtocolError; + + type DirectMessage = >::DirectMessage; + type NormalBroadcast = >::NormalBroadcast; + type EchoBroadcast = >::EchoBroadcast; + + type Payload = >::Payload; + type Artifact = >::Artifact; + + fn transition_info(&self) -> TransitionInfo { + self.round.transition_info() + } + + fn communication_info(&self) -> CommunicationInfo { + self.round.communication_info() + } + + fn receive_message( + &self, + from: &Id, + message: ProtocolMessage, + ) -> Result> { + self.round + .receive_message( + from, + ProtocolMessage { + echo_broadcast: message.echo_broadcast, + normal_broadcast: message.normal_broadcast, + direct_message: message.direct_message, + }, + ) + .map_err(|error| error.map::(ExtendedProtocolError)) + } + + fn make_normal_broadcast(&self, rng: &mut impl CryptoRngCore) -> Result { + self.extension.make_normal_broadcast(rng, &self.round) + } + + fn make_echo_broadcast(&self, rng: &mut impl CryptoRngCore) -> Result { + self.extension.make_echo_broadcast(rng, &self.round) + } + + fn make_direct_message( + &self, + rng: &mut impl CryptoRngCore, + destination: &Id, + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { + self.extension.make_direct_message(rng, &self.round, destination) + } + + fn finalize( + self, + rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, LocalError> { + let outcome = self.extension.extend_finalize(rng, self.round, payloads, artifacts)?; + Ok(match outcome { + FinalizeOutcome::Result(result) => FinalizeOutcome::Result(result), + FinalizeOutcome::AnotherRound(round) => FinalizeOutcome::AnotherRound(wrap_round(round, self.extensions)?), + }) + } +} + +pub(crate) trait DynExtension>: 'static + Debug + Send + Sync { + fn clone_boxed(&self) -> Box>; + + fn extend_round( + self: Box, + round: BoxedRound, + extensions: BTreeMap>>, + ) -> Option>; +} + +#[derive(Debug, Clone)] +struct ExtensionWrapper(Ext); + +impl ExtensionWrapper { + fn new(extension: Ext) -> Self { + Self(extension) + } +} + +impl DynExtension>::Protocol> for ExtensionWrapper +where + Id: PartyId, + Ext: Extension, +{ + fn clone_boxed(&self) -> Box>::Protocol>> { + Box::new(ExtensionWrapper(self.0.clone())) + } + + fn extend_round( + self: Box, + round: BoxedRound>::Protocol>, + extensions: BTreeMap>::Protocol>>>, + ) -> Option>::Protocol>> { + let typed_round = round.into_typed().ok()?.downcast::().ok()?; + let extended_round = ExtendedRound:: { + round: typed_round, + extension: (*self).0, + extensions, + }; + Some(BoxedRound::new(extended_round)) + } +} + +#[derive_where::derive_where(Debug)] +struct PassthroughRound> { + round: BoxedRound, + extensions: BTreeMap>>, +} + +impl DynRound for PassthroughRound +where + Id: PartyId, + P: Protocol, +{ + type Protocol = P; + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, LocalError> { + let outcome = self.round.into_inner().finalize(rng, payloads, artifacts)?; + Ok(match outcome { + FinalizeOutcome::Result(result) => FinalizeOutcome::Result(result), + FinalizeOutcome::AnotherRound(round) => FinalizeOutcome::AnotherRound(wrap_round(round, self.extensions)?), + }) + } + + fn transition_info(&self) -> TransitionInfo { + self.round.as_ref().transition_info() + } + + fn communication_info(&self) -> CommunicationInfo { + self.round.as_ref().communication_info() + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + destination: &Id, + ) -> Result<(DirectMessage, Artifact), LocalError> { + self.round.as_ref().make_direct_message(rng, format, destination) + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + ) -> Result { + self.round.as_ref().make_echo_broadcast(rng, format) + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + ) -> Result { + self.round.as_ref().make_normal_broadcast(rng, format) + } + + fn receive_message( + &self, + format: &BoxedFormat, + from: &Id, + message: DynProtocolMessage, + ) -> Result> { + self.round.as_ref().receive_message(format, from, message) + } +} + +fn wrap_round>( + round: BoxedRound, + extensions: BTreeMap>>, +) -> Result, LocalError> { + if let Some(extension) = extensions.get(&round.as_typed()?.type_id()) { + let extension: Box> = extension.clone_boxed(); + // This will only panic if the fetched element was previously added to `extensions` with a wrong key. + Ok(extension + .extend_round(round, extensions) + .expect("Extension's associated `Round` has a correct type")) + } else { + Ok(BoxedRound::new_dynamic(PassthroughRound { round, extensions })) + } +} + +/// A wrapper for a protocol's [`EntryPoint`], allowing registering [`Extension`] implementors +/// to extend or override [`Round`] methods. +#[derive(Debug)] +pub struct Extendable> { + entry_point: EP, + extensions: BTreeMap>>, +} + +impl Extendable +where + Id: PartyId, + EP: EntryPoint, +{ + /// Wraps an entry point making it extendable. + pub fn new(entry_point: EP) -> Self { + Self { + entry_point, + extensions: BTreeMap::new(), + } + } + + /// Registers an extension and returns the updated entry point. + pub fn with_extension>(self, extension: Ext) -> Self + where + Ext::Round: Round, + { + let mut entry_point = self; + entry_point.extend(extension); + entry_point + } + + /// Registers an extension. + pub fn extend>(&mut self, extension: Ext) + where + Ext::Round: Round, + { + let type_id = BoxedTypedRound::::type_id_for::(); + self.extensions + .insert(type_id, Box::new(ExtensionWrapper::new(extension))); + } +} + +impl EntryPoint for Extendable +where + Id: PartyId, + EP: EntryPoint, +{ + type Protocol = >::Protocol; + fn entry_round_id() -> RoundId { + EP::entry_round_id() + } + fn make_round( + self, + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: &Id, + ) -> Result, LocalError> { + let round = self.entry_point.make_round(rng, shared_randomness, id)?; + wrap_round(round, self.extensions) + } +} diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs deleted file mode 100644 index a578d21..0000000 --- a/manul/src/combinators/misbehave.rs +++ /dev/null @@ -1,318 +0,0 @@ -/*! -A combinator allowing one to intercept outgoing messages from a round, and replace or modify them. - -Usage: - -1. Define a behavior type, subject to [`Behavior`] bounds. - This will represent the possible actions the override may perform. - -2. Implement [`Misbehaving`] for a type of your choice. Usually it will be a ZST. - You will need to specify the entry point for the unmodified protocol, - and some of `modify_*` methods (the blanket implementations simply pass through the original messages). - -3. The `modify_*` methods can be called from any round, use [`BoxedRound::id`](`crate::protocol::BoxedRound::id`) - on the `round` argument to determine which round it is. - -4. In the `modify_*` methods, you can get the original typed message using the provided `deserializer` argument, - and create a new one using the `serializer`. - -5. You can get access to the typed `Round` object by using - [`BoxedRound::downcast_ref`](`crate::protocol::BoxedRound::downcast_ref`). - -6. Use [`MisbehavingEntryPoint`] parametrized by `Id`, the behavior type from step 1, and the type from step 2 - as the entry point of the new protocol. -*/ - -use alloc::{boxed::Box, collections::BTreeMap}; -use core::fmt::Debug; - -use rand_core::CryptoRngCore; - -use crate::protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, - LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessage, ReceiveError, Round, RoundId, - TransitionInfo, -}; - -/// A trait describing required properties for a behavior type. -pub trait Behavior: 'static + Debug + Send + Sync {} - -impl Behavior for T {} - -/// A trait defining a sequence of misbehaving rounds modifying or replacing the messages sent by some existing ones. -/// -/// Override one or more optional methods to modify the specific messages. -pub trait Misbehaving: 'static -where - Id: PartyId, - B: Behavior, -{ - /// The entry point of the wrapped rounds. - type EntryPoint: Debug + EntryPoint; - - /// Called after [`Round::make_echo_broadcast`](`crate::protocol::Round::make_echo_broadcast`) - /// and may modify its result. - /// - /// The default implementation passes through the original message. - #[allow(unused_variables)] - fn modify_echo_broadcast( - rng: &mut dyn CryptoRngCore, - round: &BoxedRound>::Protocol>, - behavior: &B, - format: &BoxedFormat, - echo_broadcast: EchoBroadcast, - ) -> Result { - Ok(echo_broadcast) - } - - /// Called after [`Round::make_normal_broadcast`](`crate::protocol::Round::make_normal_broadcast`) - /// and may modify its result. - /// - /// The default implementation passes through the original message. - #[allow(unused_variables)] - fn modify_normal_broadcast( - rng: &mut dyn CryptoRngCore, - round: &BoxedRound>::Protocol>, - behavior: &B, - format: &BoxedFormat, - normal_broadcast: NormalBroadcast, - ) -> Result { - Ok(normal_broadcast) - } - - /// Called after [`Round::make_direct_message`](`crate::protocol::Round::make_direct_message`) - /// and may modify its result. - /// - /// The default implementation passes through the original message. - #[allow(unused_variables, clippy::too_many_arguments)] - fn modify_direct_message( - rng: &mut dyn CryptoRngCore, - round: &BoxedRound>::Protocol>, - behavior: &B, - format: &BoxedFormat, - destination: &Id, - direct_message: DirectMessage, - artifact: Option, - ) -> Result<(DirectMessage, Option), LocalError> { - Ok((direct_message, artifact)) - } - - /// Called before [`Round::finalize`](`crate::protocol::Round::finalize`) - /// and may override its result. - /// - /// Return [`FinalizeOverride::UseDefault`] to use the existing `finalize()` - /// (the default behavior is to do that passing through `round`, `payloads`, and `artifacts`). - /// Otherwise finalize manually and return [`FinalizeOverride::Override`], - /// in which case the existing `finalize()` will not be called. - #[allow(unused_variables)] - fn override_finalize( - rng: &mut dyn CryptoRngCore, - round: BoxedRound>::Protocol>, - behavior: &B, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result>::Protocol>, LocalError> { - Ok(FinalizeOverride::UseDefault { - round, - payloads, - artifacts, - }) - } -} - -/// Possible return values for [`Misbehaving::override_finalize`]. -#[derive(Debug)] -pub enum FinalizeOverride> { - /// Use the existing [`Round::finalize`](`crate::protocol::Round::finalize`) with the given arguments. - UseDefault { - /// The round object to pass to `finalize()`. - round: BoxedRound, - /// The payloads map to pass to `finalize()`. - payloads: BTreeMap, - /// The artifacts map to pass to `finalize()`. - artifacts: BTreeMap, - }, - /// Finalize manually; the existing [`Round::finalize`](`crate::protocol::Round::finalize`) will not be called. - Override(FinalizeOutcome), -} - -/// The new entry point for the misbehaving rounds. -/// -/// Use as an entry point to run the session, with your ID, the behavior `B` and the misbehavior definition `M` set. -#[derive_where::derive_where(Debug)] -pub struct MisbehavingEntryPoint -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - entry_point: M::EntryPoint, - behavior: Option, -} - -impl MisbehavingEntryPoint -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - /// Creates an entry point for the misbehaving protocol using an entry point for the inner protocol. - pub fn new(entry_point: M::EntryPoint, behavior: Option) -> Self { - Self { entry_point, behavior } - } -} - -impl EntryPoint for MisbehavingEntryPoint -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - type Protocol = >::Protocol; - - fn entry_round_id() -> RoundId { - M::EntryPoint::entry_round_id() - } - - fn make_round( - self, - rng: &mut dyn CryptoRngCore, - shared_randomness: &[u8], - id: &Id, - ) -> Result, LocalError> { - let round = self.entry_point.make_round(rng, shared_randomness, id)?; - Ok(BoxedRound::new_dynamic(MisbehavingRound:: { - round, - behavior: self.behavior, - })) - } -} - -#[derive_where::derive_where(Debug)] -struct MisbehavingRound -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - round: BoxedRound>::Protocol>, - behavior: Option, -} - -impl MisbehavingRound -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - /// Wraps the outcome of the underlying Round into the MisbehavingRound structure. - fn map_outcome( - outcome: FinalizeOutcome>::Protocol>, - behavior: Option, - ) -> FinalizeOutcome>::Protocol> { - match outcome { - FinalizeOutcome::Result(result) => FinalizeOutcome::Result(result), - FinalizeOutcome::AnotherRound(round) => { - FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Self { round, behavior })) - } - } - } -} - -impl Round for MisbehavingRound -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - type Protocol = >::Protocol; - - fn transition_info(&self) -> TransitionInfo { - self.round.as_ref().transition_info() - } - - fn communication_info(&self) -> CommunicationInfo { - self.round.as_ref().communication_info() - } - - fn make_direct_message( - &self, - rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - let (direct_message, artifact) = self.round.as_ref().make_direct_message(rng, format, destination)?; - if let Some(behavior) = self.behavior.as_ref() { - M::modify_direct_message( - rng, - &self.round, - behavior, - format, - destination, - direct_message, - artifact, - ) - } else { - Ok((direct_message, artifact)) - } - } - - fn make_echo_broadcast( - &self, - rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { - let echo_broadcast = self.round.as_ref().make_echo_broadcast(rng, format)?; - if let Some(behavior) = self.behavior.as_ref() { - M::modify_echo_broadcast(rng, &self.round, behavior, format, echo_broadcast) - } else { - Ok(echo_broadcast) - } - } - - fn make_normal_broadcast( - &self, - rng: &mut dyn CryptoRngCore, - format: &BoxedFormat, - ) -> Result { - let normal_broadcast = self.round.as_ref().make_normal_broadcast(rng, format)?; - if let Some(behavior) = self.behavior.as_ref() { - M::modify_normal_broadcast(rng, &self.round, behavior, format, normal_broadcast) - } else { - Ok(normal_broadcast) - } - } - - fn receive_message( - &self, - format: &BoxedFormat, - from: &Id, - message: ProtocolMessage, - ) -> Result> { - self.round.as_ref().receive_message(format, from, message) - } - - fn finalize( - self: Box, - rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result, LocalError> { - let (round, payloads, artifacts) = if let Some(behavior) = self.behavior.as_ref() { - let result = M::override_finalize(rng, self.round, behavior, payloads, artifacts)?; - match result { - FinalizeOverride::UseDefault { - round, - payloads, - artifacts, - } => (round, payloads, artifacts), - FinalizeOverride::Override(outcome) => return Ok(Self::map_outcome(outcome, self.behavior)), - } - } else { - (self.round, payloads, artifacts) - }; - - let outcome = round.into_boxed().finalize(rng, payloads, artifacts)?; - Ok(Self::map_outcome(outcome, self.behavior)) - } -} diff --git a/manul/src/dev/run_sync.rs b/manul/src/dev/run_sync.rs index 0b4db10..5fd9c8f 100644 --- a/manul/src/dev/run_sync.rs +++ b/manul/src/dev/run_sync.rs @@ -1,4 +1,4 @@ -use alloc::{collections::BTreeMap, format, string::String, vec::Vec}; +use alloc::{boxed::Box, collections::BTreeMap, format, string::String, vec::Vec}; use rand::Rng; use rand_core::CryptoRngCore; @@ -15,7 +15,7 @@ use crate::{ enum State, SP: SessionParameters> { InProgress { - session: Session, + session: Box>, accum: RoundAccumulator, }, Finished(SessionReport), @@ -126,7 +126,7 @@ where session: new_session, cached_messages, } => { - session = new_session; + session = *new_session; accum = session.make_accumulator(); for message in cached_messages { @@ -147,7 +147,10 @@ where session.verifier(), session.round_id() ); - break State::InProgress { session, accum }; + break State::InProgress { + session: Box::new(session), + accum, + }; } CanFinalize::Never => { trace!( @@ -209,17 +212,15 @@ where states.insert(verifier, state); } - let messages_len = messages.len(); loop { // Pick a random message and deliver it let message = messages.pop(rng); debug!( - "Delivering message from {:?} to {:?} ({}/{})", + "Delivering message from {:?} to {:?} ({} more in the queue)", message.from, message.to, - messages_len - messages.len(), - messages_len + messages.len(), ); let state = states.remove(&message.to); if state.is_none() { @@ -239,7 +240,7 @@ where session.add_processed_message(&mut accum, processed)?; } - let (new_state, new_messages) = propagate(rng, session, accum)?; + let (new_state, new_messages) = propagate(rng, *session, accum)?; messages.extend(new_messages); new_state } else { diff --git a/manul/src/dev/session_parameters.rs b/manul/src/dev/session_parameters.rs index 16fc2ae..b26c8b9 100644 --- a/manul/src/dev/session_parameters.rs +++ b/manul/src/dev/session_parameters.rs @@ -100,7 +100,7 @@ impl digest::OutputSizeUser for TestHasher { /// An implementation of [`SessionParameters`] using the testing signer/verifier types. #[derive(Debug, Clone, Copy)] -pub struct TestSessionParams(core::marker::PhantomData); +pub struct TestSessionParams(core::marker::PhantomData); impl SessionParameters for TestSessionParams { type Signer = TestSigner; diff --git a/manul/src/dev/tokio.rs b/manul/src/dev/tokio.rs index e96da80..944456e 100644 --- a/manul/src/dev/tokio.rs +++ b/manul/src/dev/tokio.rs @@ -83,7 +83,6 @@ where EP: EntryPoint, SP: SessionParameters, SP::Signer: Send + Sync, - >::ProtocolError: Send + Sync, >::Result: Send, { let num_parties = entry_points.len(); diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 975fc15..f2bc7ea 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -11,25 +11,38 @@ to be executed by a [`Session`](`crate::session::Session`). For more details, see the documentation of the mentioned traits. */ -mod boxed_format; -mod boxed_round; +mod dyn_evidence; +mod dyn_round; mod errors; +mod evidence; mod message; +mod rng; mod round; mod round_id; +mod round_info; +mod wire_format; -pub use boxed_format::BoxedFormat; -pub use boxed_round::BoxedRound; -pub use errors::{ - DeserializationError, DirectMessageError, EchoBroadcastError, LocalError, MessageValidationError, - NormalBroadcastError, ProtocolValidationError, ReceiveError, RemoteError, +pub use dyn_round::BoxedRound; +pub use errors::{LocalError, ReceiveError, RemoteError}; +pub use evidence::{ + EvidenceError, EvidenceMessages, NoProtocolErrors, ProtocolError, RequiredMessageParts, RequiredMessages, }; -pub use message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessage, ProtocolMessagePart}; pub use round::{ - Artifact, CommunicationInfo, EchoRoundParticipation, EntryPoint, FinalizeOutcome, NoProtocolErrors, PartyId, - Payload, Protocol, ProtocolError, RequiredMessageParts, RequiredMessages, Round, + CommunicationInfo, EchoRoundParticipation, EntryPoint, FinalizeOutcome, NoArtifact, NoMessage, PartyId, Protocol, + ProtocolMessage, Round, }; -pub use round_id::{RoundId, TransitionInfo}; +pub use round_id::{RoundId, RoundNum, TransitionInfo}; +pub use round_info::RoundInfo; -pub(crate) use errors::ReceiveErrorType; -pub(crate) use message::ProtocolMessagePartHashable; +pub(crate) use dyn_evidence::{BoxedProtocolError, SerializedProtocolError}; +pub(crate) use dyn_round::{Artifact, BoxedReceiveError, BoxedTypedRound, DynRound, Payload}; +pub(crate) use evidence::EvidenceProtocolMessage; +pub(crate) use message::{ + DirectMessage, DirectMessageError, DynProtocolMessage, EchoBroadcast, EchoBroadcastError, NormalBroadcast, + NormalBroadcastError, ProtocolMessagePart, ProtocolMessagePartHashable, +}; +pub(crate) use rng::BoxedRng; +pub(crate) use round::NoType; +pub(crate) use round_id::GroupNum; +pub(crate) use round_info::DynRoundInfo; +pub(crate) use wire_format::BoxedFormat; diff --git a/manul/src/protocol/boxed_round.rs b/manul/src/protocol/boxed_round.rs deleted file mode 100644 index 989064f..0000000 --- a/manul/src/protocol/boxed_round.rs +++ /dev/null @@ -1,74 +0,0 @@ -use alloc::{boxed::Box, format}; - -use super::{ - errors::LocalError, - round::{PartyId, Protocol, Round}, - round_id::RoundId, -}; - -/// A wrapped new round that may be returned by [`Round::finalize`] -/// or [`EntryPoint::make_round`](`crate::protocol::EntryPoint::make_round`). -#[derive_where::derive_where(Debug)] -pub struct BoxedRound>(Box>); - -impl> BoxedRound { - /// Wraps an object implementing the dynamic round trait ([`Round`](`crate::protocol::Round`)). - pub fn new_dynamic>(round: R) -> Self { - Self(Box::new(round)) - } - - pub(crate) fn as_ref(&self) -> &dyn Round { - self.0.as_ref() - } - - pub(crate) fn into_boxed(self) -> Box> { - self.0 - } - - fn boxed_type_is(&self) -> bool { - core::any::TypeId::of::() == self.0.get_type_id() - } - - /// Attempts to extract an object of a concrete type, preserving the original on failure. - pub fn try_downcast>(self) -> Result { - if self.boxed_type_is::() { - // Safety: This is safe since we just checked that we are casting to the correct type. - let boxed_downcast = unsafe { Box::::from_raw(Box::into_raw(self.0) as *mut T) }; - Ok(*boxed_downcast) - } else { - Err(self) - } - } - - /// Attempts to extract an object of a concrete type. - /// - /// Fails if the wrapped type is not `T`. - pub fn downcast>(self) -> Result { - self.try_downcast() - .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) - } - - /// Attempts to provide a reference to an object of a concrete type. - /// - /// Fails if the wrapped type is not `T`. - pub fn downcast_ref>(&self) -> Result<&T, LocalError> { - if self.boxed_type_is::() { - let ptr: *const dyn Round = self.0.as_ref(); - // Safety: This is safe since we just checked that we are casting to the correct type. - Ok(unsafe { &*(ptr as *const T) }) - } else { - Err(LocalError::new(format!( - "Failed to downcast into type {}", - core::any::type_name::() - ))) - } - } - - /// Returns the round's ID. - pub fn id(&self) -> RoundId { - // This constructs a new `TransitionInfo` object, so calling this method inside `Session` - // has mild performance drawbacks. - // This is mostly exposed for the sake of users writing `Misbehave` impls for testing. - self.0.transition_info().id() - } -} diff --git a/manul/src/protocol/dyn_evidence.rs b/manul/src/protocol/dyn_evidence.rs new file mode 100644 index 0000000..d0689ed --- /dev/null +++ b/manul/src/protocol/dyn_evidence.rs @@ -0,0 +1,76 @@ +use alloc::{boxed::Box, string::String}; +use core::fmt::Debug; + +use serde::{Deserialize, Serialize}; +use serde_encoded_bytes::{Base64, SliceLike}; + +use super::{ + errors::LocalError, + evidence::{ProtocolError, RequiredMessages}, + round::Round, + round_id::{GroupNum, RoundId}, + wire_format::BoxedFormat, +}; +use crate::session::DeserializationError; + +pub(crate) trait DynProtocolError: Debug { + fn description(&self) -> String; + fn serialize(self: Box, format: &BoxedFormat) -> Result; +} + +impl> DynProtocolError for T { + fn description(&self) -> String { + self.description() + } + + fn serialize(self: Box, format: &BoxedFormat) -> Result { + format.serialize(*self).map(SerializedProtocolError) + } +} + +#[derive(Debug)] +pub(crate) struct BoxedProtocolError { + required_messages: RequiredMessages, + error: Box + Send + Sync>, +} + +impl BoxedProtocolError { + pub fn new>(error: R::ProtocolError, round_id: &RoundId) -> Self { + let required_messages = error.required_messages(round_id); + Self { + required_messages, + error: Box::new(error), + } + } + + pub fn as_ref(&self) -> &dyn DynProtocolError { + self.error.as_ref() + } + + pub fn into_boxed(self) -> Box> { + self.error + } + + pub fn group_under(self, group_num: GroupNum) -> Self { + Self { + required_messages: self.required_messages.group_under(group_num), + error: self.error, + } + } + + pub fn required_messages(&self) -> &RequiredMessages { + &self.required_messages + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct SerializedProtocolError(#[serde(with = "SliceLike::")] Box<[u8]>); + +impl SerializedProtocolError { + pub fn deserialize>( + &self, + format: &BoxedFormat, + ) -> Result { + format.deserialize::(&self.0) + } +} diff --git a/manul/src/protocol/dyn_round.rs b/manul/src/protocol/dyn_round.rs new file mode 100644 index 0000000..09ca498 --- /dev/null +++ b/manul/src/protocol/dyn_round.rs @@ -0,0 +1,395 @@ +use alloc::{boxed::Box, collections::BTreeMap, format}; +use core::{ + any::{Any, TypeId}, + fmt::Debug, +}; + +use rand_core::CryptoRngCore; + +use super::{ + dyn_evidence::BoxedProtocolError, + errors::{LocalError, ReceiveError, RemoteError}, + message::{ + DirectMessage, DirectMessageError, DynProtocolMessage, EchoBroadcast, EchoBroadcastError, NormalBroadcast, + NormalBroadcastError, ProtocolMessagePart, + }, + rng::BoxedRng, + round::{CommunicationInfo, FinalizeOutcome, NoMessage, NoType, PartyId, Protocol, ProtocolMessage, Round}, + round_id::{GroupNum, RoundId, TransitionInfo}, + wire_format::BoxedFormat, +}; +use crate::{session::EchoRoundError, utils::DynTypeId}; + +#[derive(Debug)] +pub(crate) struct Payload(pub Box); + +impl Payload { + /// Creates a new payload. + pub fn new(payload: T) -> Self { + Self(Box::new(payload)) + } + + /// Creates an empty payload. + /// + /// Use it in [`Round::receive_message`] if it does not need to create payloads. + pub fn empty() -> Self { + Self::new(()) + } + + /// Attempts to downcast back to the concrete type. + pub fn downcast(self) -> Result { + Ok(*(self.0.downcast::().map_err(|_| { + LocalError::new(format!( + "Failed to downcast Payload into {}", + core::any::type_name::() + )) + })?)) + } +} + +#[derive(Debug)] +pub(crate) struct Artifact(pub Box); + +impl Artifact { + /// Creates a new artifact. + pub fn new(artifact: T) -> Self { + Self(Box::new(artifact)) + } + + /// Attempts to downcast back to the concrete type. + pub fn downcast(self) -> Result { + Ok(*(self.0.downcast::().map_err(|_| { + LocalError::new(format!( + "Failed to downcast Artifact into {}", + core::any::type_name::() + )) + })?)) + } +} + +pub(crate) trait DynRound: 'static + Debug + Send + Sync + DynTypeId { + type Protocol: Protocol; + + fn transition_info(&self) -> TransitionInfo; + + fn communication_info(&self) -> CommunicationInfo; + + fn make_direct_message( + &self, + #[allow(unused_variables)] rng: &mut dyn CryptoRngCore, + #[allow(unused_variables)] format: &BoxedFormat, + #[allow(unused_variables)] destination: &Id, + ) -> Result<(DirectMessage, Artifact), LocalError>; + + fn make_echo_broadcast( + &self, + #[allow(unused_variables)] rng: &mut dyn CryptoRngCore, + #[allow(unused_variables)] format: &BoxedFormat, + ) -> Result; + + fn make_normal_broadcast( + &self, + #[allow(unused_variables)] rng: &mut dyn CryptoRngCore, + #[allow(unused_variables)] format: &BoxedFormat, + ) -> Result; + + fn receive_message( + &self, + format: &BoxedFormat, + from: &Id, + message: DynProtocolMessage, + ) -> Result>; + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, LocalError>; +} + +#[derive(Debug)] +struct RoundWrapper { + round: R, +} + +impl RoundWrapper { + pub fn new(round: R) -> Self { + Self { round } + } + + pub fn into_inner(self) -> R { + self.round + } +} + +impl DynRound for RoundWrapper +where + Id: PartyId, + R: Round, +{ + type Protocol = >::Protocol; + + fn transition_info(&self) -> TransitionInfo { + self.round.transition_info() + } + + fn communication_info(&self) -> CommunicationInfo { + self.round.communication_info() + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + destination: &Id, + ) -> Result<(DirectMessage, Artifact), LocalError> { + let (direct_message, artifact) = self.round.make_direct_message(&mut BoxedRng(rng), destination)?; + Ok((DirectMessage::new(format, direct_message)?, Artifact::new(artifact))) + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + ) -> Result { + let echo_broadcast = self.round.make_echo_broadcast(&mut BoxedRng(rng))?; + EchoBroadcast::new(format, echo_broadcast) + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + ) -> Result { + let normal_broadcast = self.round.make_normal_broadcast(&mut BoxedRng(rng))?; + NormalBroadcast::new(format, normal_broadcast) + } + + fn receive_message( + &self, + format: &BoxedFormat, + from: &Id, + message: DynProtocolMessage, + ) -> Result> { + let direct_message = if let Some(direct_message) = NoMessage::new_if_equals::() { + message.direct_message.assert_is_none()?; + direct_message + } else { + message.direct_message.deserialize::(format)? + }; + + let echo_broadcast = if let Some(echo_broadcast) = NoMessage::new_if_equals::() { + message.echo_broadcast.assert_is_none()?; + echo_broadcast + } else { + message.echo_broadcast.deserialize::(format)? + }; + + let normal_broadcast = if let Some(normal_broadcast) = NoMessage::new_if_equals::() { + message.normal_broadcast.assert_is_none()?; + normal_broadcast + } else { + message.normal_broadcast.deserialize::(format)? + }; + + let payload = self + .round + .receive_message( + from, + ProtocolMessage { + direct_message, + echo_broadcast, + normal_broadcast, + }, + ) + .map_err(|error| BoxedReceiveError::new(error, &self.transition_info().id))?; + + Ok(Payload::new(payload)) + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, LocalError> { + let payloads = payloads + .into_iter() + .map(|(id, payload)| payload.downcast::().map(|payload| (id, payload))) + .collect::, _>>()?; + let artifacts = artifacts + .into_iter() + .map(|(id, artifact)| artifact.downcast::().map(|artifact| (id, artifact))) + .collect::, _>>()?; + + self.round.finalize(&mut BoxedRng(rng), payloads, artifacts) + } +} + +/// A wrapped new round that may be returned by [`Round::finalize`] +/// or [`EntryPoint::make_round`](`crate::protocol::EntryPoint::make_round`). +#[derive_where::derive_where(Debug)] +pub struct BoxedRound>(BoxedRoundEnum); + +#[derive_where::derive_where(Debug)] +enum BoxedRoundEnum> { + Dynamic(Box>), + Typed(BoxedTypedRound), +} + +impl> BoxedRound { + /// Wraps an object implementing the typed round trait ([`Round`](`crate::protocol::Round`)). + pub fn new>(round: R) -> Self { + Self(BoxedRoundEnum::Typed(BoxedTypedRound::new(round))) + } + + pub(crate) fn new_dynamic>(round: R) -> Self { + Self(BoxedRoundEnum::Dynamic(Box::new(round))) + } + + pub(crate) fn as_ref(&self) -> &dyn DynRound { + match &self.0 { + BoxedRoundEnum::Dynamic(boxed) => boxed.as_ref(), + BoxedRoundEnum::Typed(boxed) => boxed.as_ref(), + } + } + + pub(crate) fn into_inner(self) -> Box> { + match self.0 { + BoxedRoundEnum::Dynamic(boxed) => boxed, + BoxedRoundEnum::Typed(boxed) => boxed.into_inner(), + } + } + + pub(crate) fn as_typed(&self) -> Result<&BoxedTypedRound, LocalError> { + match &self.0 { + BoxedRoundEnum::Dynamic(_boxed) => { + Err(LocalError::new("Attempted to use a boxed dynamic round as a typed one")) + } + BoxedRoundEnum::Typed(boxed) => Ok(boxed), + } + } + + pub(crate) fn into_typed(self) -> Result, LocalError> { + match self.0 { + BoxedRoundEnum::Dynamic(_boxed) => { + Err(LocalError::new("Attempted to use a boxed dynamic round as a typed one")) + } + BoxedRoundEnum::Typed(boxed) => Ok(boxed), + } + } +} + +#[derive_where::derive_where(Debug)] +pub(crate) struct BoxedTypedRound>(Box>); + +impl> BoxedTypedRound { + pub fn new>(round: R) -> Self { + Self(Box::new(RoundWrapper::new(round))) + } + + pub(crate) fn as_ref(&self) -> &dyn DynRound { + self.0.as_ref() + } + + pub(crate) fn into_inner(self) -> Box> { + self.0 + } + + /// Returns the type ID of the encapsulated `Round` implementor. + pub(crate) fn type_id(&self) -> TypeId { + self.0.as_ref().get_type_id() + } + + /// Returns the type ID that [`type_id`] would return for an object created with [`new()`] + /// given a round of type `R`. + pub(crate) fn type_id_for>() -> TypeId { + TypeId::of::>() + } + + /// Attempts to extract an object of a concrete type, preserving the original on failure. + pub(crate) fn try_downcast>(self) -> Result { + if self.type_id() == TypeId::of::>() { + // Safety: This is safe since we just checked that we are casting to the correct type. + let boxed_downcast = + unsafe { Box::>::from_raw(Box::into_raw(self.0) as *mut RoundWrapper) }; + Ok((*boxed_downcast).into_inner()) + } else { + Err(self) + } + } + + /// Attempts to extract an object of a concrete type. + /// + /// Fails if the wrapped type is not `T`. + pub(crate) fn downcast>(self) -> Result { + self.try_downcast() + .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) + } +} + +#[derive(Debug)] +pub(crate) enum BoxedReceiveError { + Local(LocalError), + /// The given direct message cannot be deserialized. + InvalidDirectMessage(DirectMessageError), + /// The given echo broadcast cannot be deserialized. + InvalidEchoBroadcast(EchoBroadcastError), + /// The given normal broadcast cannot be deserialized. + InvalidNormalBroadcast(NormalBroadcastError), + /// A provable protocol error associated with the round. + Protocol(BoxedProtocolError), + /// An unprovable error. + Unprovable(RemoteError), + /// An error during an echo round. + Echo(Box>), +} + +impl BoxedReceiveError { + pub(crate) fn new>(error: ReceiveError, round_id: &RoundId) -> Self { + match error { + ReceiveError::Local(error) => Self::Local(error), + ReceiveError::Unprovable(error) => Self::Unprovable(error), + ReceiveError::Protocol(error) => Self::Protocol(BoxedProtocolError::new::(error, round_id)), + } + } + + pub(crate) fn group_under(self, group_num: GroupNum) -> Self { + if let Self::Protocol(error) = self { + Self::Protocol(error.group_under(group_num)) + } else { + self + } + } +} + +impl From for BoxedReceiveError { + fn from(error: LocalError) -> Self { + BoxedReceiveError::Local(error) + } +} + +impl From> for BoxedReceiveError { + fn from(error: BoxedProtocolError) -> Self { + BoxedReceiveError::Protocol(error) + } +} + +impl From for BoxedReceiveError { + fn from(error: DirectMessageError) -> Self { + BoxedReceiveError::InvalidDirectMessage(error) + } +} + +impl From for BoxedReceiveError { + fn from(error: EchoBroadcastError) -> Self { + BoxedReceiveError::InvalidEchoBroadcast(error) + } +} + +impl From for BoxedReceiveError { + fn from(error: NormalBroadcastError) -> Self { + BoxedReceiveError::InvalidNormalBroadcast(error) + } +} diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index 6a34d89..ad612ca 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -1,8 +1,7 @@ -use alloc::{boxed::Box, format, string::String}; +use alloc::string::String; use core::fmt::Debug; -use super::round::Protocol; -use crate::session::EchoRoundError; +use super::round::Round; /// An error indicating a local problem, most likely a misuse of the API or a bug in the code. #[derive(displaydoc::Display, Debug, Clone)] @@ -30,221 +29,46 @@ impl RemoteError { /// An error that can be returned from [`Round::receive_message`](`super::Round::receive_message`). #[derive(Debug)] -pub struct ReceiveError>(pub(crate) ReceiveErrorType); - -#[derive(Debug)] -pub(crate) enum ReceiveErrorType> { +pub enum ReceiveError + ?Sized> { /// A local error, indicating an implemenation bug or a misuse by the upper layer. Local(LocalError), - /// The given direct message cannot be deserialized. - InvalidDirectMessage(DirectMessageError), - /// The given echo broadcast cannot be deserialized. - InvalidEchoBroadcast(EchoBroadcastError), - /// The given normal broadcast cannot be deserialized. - InvalidNormalBroadcast(NormalBroadcastError), /// A provable error occurred. - Protocol(P::ProtocolError), + Protocol(R::ProtocolError), /// An unprovable error occurred. Unprovable(RemoteError), - // Note that this variant should not be instantiated by the user (a protocol author), - // so this whole enum is crate-private and the variants are created - // via constructors and From impls. - /// An echo round error occurred. - Echo(Box>), -} - -impl> ReceiveError { - /// A local error, indicating an implemenation bug or a misuse by the upper layer. - pub fn local(message: impl Into) -> Self { - Self(ReceiveErrorType::Local(LocalError::new(message.into()))) - } - - /// An unprovable error occurred. - pub fn unprovable(message: impl Into) -> Self { - Self(ReceiveErrorType::Unprovable(RemoteError::new(message.into()))) - } - - /// A provable error occurred. - pub fn protocol(error: P::ProtocolError) -> Self { - Self(ReceiveErrorType::Protocol(error)) - } - - /// Maps the error to a different protocol, given the mapping function for protocol errors. - pub(crate) fn map(self, f: F) -> ReceiveError - where - F: Fn(P::ProtocolError) -> T::ProtocolError, - T: Protocol, - { - ReceiveError(self.0.map::(f)) - } } -impl> ReceiveErrorType { - pub(crate) fn map(self, f: F) -> ReceiveErrorType +impl ReceiveError +where + R: Round, +{ + pub(crate) fn map(self, f: F) -> ReceiveError where - F: Fn(P::ProtocolError) -> T::ProtocolError, - T: Protocol, + F: Fn(R::ProtocolError) -> NR::ProtocolError, + NR: Round, { match self { - Self::Local(err) => ReceiveErrorType::Local(err), - Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), - Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), - Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), - Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), - Self::Echo(err) => ReceiveErrorType::Echo(err), - Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + Self::Local(err) => ReceiveError::Local(err), + Self::Unprovable(err) => ReceiveError::Unprovable(err), + Self::Protocol(err) => ReceiveError::Protocol(f(err)), } } } -impl From for ReceiveError +impl From for ReceiveError where - P: Protocol, + R: Round, { fn from(error: LocalError) -> Self { - Self(ReceiveErrorType::Local(error)) + Self::Local(error) } } -impl From for ReceiveError +impl From for ReceiveError where - P: Protocol, + R: Round, { fn from(error: RemoteError) -> Self { - Self(ReceiveErrorType::Unprovable(error)) - } -} - -impl From> for ReceiveError -where - P: Protocol, -{ - fn from(error: EchoRoundError) -> Self { - Self(ReceiveErrorType::Echo(Box::new(error))) - } -} - -impl From for ReceiveError -where - P: Protocol, -{ - fn from(error: DirectMessageError) -> Self { - Self(ReceiveErrorType::InvalidDirectMessage(error)) - } -} - -impl From for ReceiveError -where - P: Protocol, -{ - fn from(error: EchoBroadcastError) -> Self { - Self(ReceiveErrorType::InvalidEchoBroadcast(error)) - } -} - -impl From for ReceiveError -where - P: Protocol, -{ - fn from(error: NormalBroadcastError) -> Self { - Self(ReceiveErrorType::InvalidNormalBroadcast(error)) - } -} - -/// An error that can occur during the validation of an evidence of an invalid message. -#[derive(Debug, Clone)] -pub enum MessageValidationError { - /// Indicates a local problem, usually a bug in the library code. - Local(LocalError), - /// Indicates a problem with the evidence, for example the given round not sending such messages, - /// or the message actually deserializing successfully. - InvalidEvidence(String), -} - -/// An error that can be returned during deserialization error. -#[derive(displaydoc::Display, Debug, Clone)] -#[displaydoc("Deserialization error: {0}")] -pub struct DeserializationError(String); - -impl DeserializationError { - /// Creates a new deserialization error. - pub fn new(message: impl Into) -> Self { - Self(message.into()) - } -} - -impl From for MessageValidationError { - fn from(error: LocalError) -> Self { - Self::Local(error) - } -} - -/// An error that can occur during the validation of an evidence of a protocol error. -#[derive(Debug, Clone)] -pub enum ProtocolValidationError { - /// Indicates a local problem, usually a bug in the library code. - Local(LocalError), - /// Indicates a problem with the evidence, for example missing messages, - /// or messages that cannot be deserialized. - InvalidEvidence(String), -} - -// If fail to deserialize a message when validating the evidence -// it means that the evidence is invalid - a deserialization error would have been -// processed separately, generating its own evidence. -impl From for ProtocolValidationError { - fn from(error: DirectMessageError) -> Self { - Self::InvalidEvidence(format!("Failed to deserialize direct message: {error:?}")) - } -} - -impl From for ProtocolValidationError { - fn from(error: EchoBroadcastError) -> Self { - Self::InvalidEvidence(format!("Failed to deserialize echo broadcast: {error:?}")) - } -} - -impl From for ProtocolValidationError { - fn from(error: NormalBroadcastError) -> Self { - Self::InvalidEvidence(format!("Failed to deserialize normal broadcast: {error:?}")) - } -} - -impl From for ProtocolValidationError { - fn from(error: LocalError) -> Self { - Self::Local(error) - } -} - -/// An error during deserialization of a direct message. -#[derive(displaydoc::Display, Debug, Clone)] -#[displaydoc("Direct message error: {0}")] -pub struct DirectMessageError(String); - -impl From for DirectMessageError { - fn from(message: String) -> Self { - Self(message) - } -} - -/// An error during deserialization of an echo broadcast. -#[derive(displaydoc::Display, Debug, Clone)] -#[displaydoc("Echo broadcast error: {0}")] -pub struct EchoBroadcastError(String); - -impl From for EchoBroadcastError { - fn from(message: String) -> Self { - Self(message) - } -} - -/// An error during deserialization of a normal broadcast. -#[derive(displaydoc::Display, Debug, Clone)] -#[displaydoc("Normal broadcast error: {0}")] -pub struct NormalBroadcastError(String); - -impl From for NormalBroadcastError { - fn from(message: String) -> Self { - Self(message) + Self::Unprovable(error) } } diff --git a/manul/src/protocol/evidence.rs b/manul/src/protocol/evidence.rs new file mode 100644 index 0000000..305f4c8 --- /dev/null +++ b/manul/src/protocol/evidence.rs @@ -0,0 +1,400 @@ +use alloc::{ + collections::{BTreeMap, BTreeSet}, + format, + string::String, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use serde::{Deserialize, Serialize}; + +use super::{ + errors::LocalError, + message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}, + round::{PartyId, Protocol, Round}, + round_id::{GroupNum, RoundId, RoundNum}, + wire_format::BoxedFormat, +}; + +/// Describes provable errors triggered by an incoming message during protocol execution. +/// +/// Provable here means that we can create an evidence object entirely of messages signed by some party, +/// which, in combination, prove the party's malicious actions. +pub trait ProtocolError: 'static + Debug + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> { + /// The round where the described errors occur. + type Round: Round; + + /// A short description of the error, for logging purposes. + fn description(&self) -> String; + + /// Specifies the messages of the guilty party that need to be stored as the evidence + /// to prove its malicious behavior. + fn required_messages(&self, round_id: &RoundId) -> RequiredMessages; + + /// Returns `Ok(())` if the attached messages indeed prove that a malicious action happened. + /// + /// The signatures and metadata of the messages will be checked by the calling code, + /// the responsibility of this method is just to check the message contents. + /// + /// `messages` gives access to messages stored as the evidence; these include the parts + /// of the message that triggered the error, and possibly earlier messages + /// (if requested by [`ProtocolError::required_messages`]). + fn verify_evidence( + &self, + round_id: &RoundId, + from: &Id, + shared_randomness: &[u8], + shared_data: &<>::Protocol as Protocol>::SharedData, + messages: EvidenceMessages<'_, Id, Self::Round>, + ) -> Result<(), EvidenceError>; +} + +#[derive(Debug)] +pub(crate) struct EvidenceProtocolMessage { + pub(crate) direct_message: Option, + pub(crate) normal_broadcast: Option, + pub(crate) echo_broadcast: Option, +} + +/// The messages from the guilty party collected as an evidence of a provable error. +/// +/// The contents depend on what was requested by [`ProtocolError::required_messages`]. +#[derive(Debug)] +pub struct EvidenceMessages<'a, Id, R: Round> { + message: EvidenceProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, + format: &'a BoxedFormat, + phantom: PhantomData R>, +} + +impl<'a, Id, R> EvidenceMessages<'a, Id, R> +where + R: Round, +{ + pub(crate) fn new( + format: &'a BoxedFormat, + message: EvidenceProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, + ) -> Self { + Self { + format, + message, + previous_messages, + combined_echos, + phantom: PhantomData, + } + } +} + +impl<'a, Id, R> EvidenceMessages<'a, Id, R> +where + Id: PartyId, + R: Round, +{ + /// Returns a stored echo broadcast from a previous round. + pub fn previous_echo_broadcast>( + &self, + round_num: RoundNum, + ) -> Result { + // TODO: we can check here that the RoundInfo corresponding to `round_num` is of a correct type. + let message_parts = self.previous_messages.get(&RoundId::new(round_num)).ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Message parts for round {round_num} are not included in the evidence" + )) + })?; + message_parts + .echo_broadcast + .as_ref() + .ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Echo broadcast for round {round_num} is not included in the evidence" + )) + })? + .deserialize::(self.format) + .map_err(|error| { + EvidenceError::InvalidEvidence(format!( + "Failed to deserialize an echo broadcast for round {round_num}: {error}", + )) + }) + } + + /// Returns a stored normal broadcast from a previous round. + pub fn previous_normal_broadcast>( + &self, + round_num: RoundNum, + ) -> Result { + // TODO: we can check here that the RoundInfo corresponding to `round_num` is of a correct type. + let message_parts = self.previous_messages.get(&RoundId::new(round_num)).ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Message parts for round {round_num} are not included in the evidence" + )) + })?; + message_parts + .normal_broadcast + .as_ref() + .ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Normal broadcast for round {round_num} is not included in the evidence" + )) + })? + .deserialize::(self.format) + .map_err(|error| { + EvidenceError::InvalidEvidence(format!( + "Failed to deserialize a normal broadcast for round {round_num}: {error}", + )) + }) + } + + /// Returns a stored direct message from a previous round. + pub fn previous_direct_message>( + &self, + round_num: RoundNum, + ) -> Result { + // TODO: we can check here that the RoundInfo corresponding to `round_num` is of a correct type. + let message_parts = self.previous_messages.get(&RoundId::new(round_num)).ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Message parts for round {round_num} are not included in the evidence" + )) + })?; + message_parts + .direct_message + .as_ref() + .ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Direct message for round {round_num} is not included in the evidence" + )) + })? + .deserialize::(self.format) + .map_err(|error| { + EvidenceError::InvalidEvidence(format!( + "Failed to deserialize a normal broadcast for round {round_num}: {error}", + )) + }) + } + + /// Returns a map with echoed broadcasts from a previous round. + pub fn combined_echos>( + &self, + round_num: RoundNum, + ) -> Result, EvidenceError> { + let combined_echos = self + .combined_echos + .get(&RoundId::new(round_num)) + .ok_or_else(|| EvidenceError::InvalidEvidence(format!("Combined echos for round {round_num} not found")))?; + combined_echos + .iter() + .map(|(id, echo_broadcast)| { + echo_broadcast + .deserialize::(self.format) + .map_err(|error| { + EvidenceError::InvalidEvidence(format!( + "Failed to deserialize a direct message for round {round_num}: {error}", + )) + }) + .map(|echo_broadcast| (id.clone(), echo_broadcast)) + }) + .collect() + } + + /// Returns the stored direct message from the round that triggered the error. + pub fn direct_message(&self) -> Result { + self.message + .direct_message + .as_ref() + .ok_or_else(|| EvidenceError::InvalidEvidence("Direct message is not included in the evidence".into()))? + .deserialize::(self.format) + .map_err(|err| EvidenceError::InvalidEvidence(format!("Error deserializing direct message: {err}"))) + } + + /// Returns the stored echo broadcast from the round that triggered the error. + pub fn echo_broadcast(&self) -> Result { + self.message + .echo_broadcast + .as_ref() + .ok_or_else(|| EvidenceError::InvalidEvidence("Echo broadcast is not included in the evidence".into()))? + .deserialize::(self.format) + .map_err(|err| EvidenceError::InvalidEvidence(format!("Error deserializing echo broadcast: {err}"))) + } + + /// Returns the stored normal broadcast from the round that triggered the error. + pub fn normal_broadcast(&self) -> Result { + self.message + .normal_broadcast + .as_ref() + .ok_or_else(|| EvidenceError::InvalidEvidence("Normal broadcast is not included in the evidence".into()))? + .deserialize::(self.format) + .map_err(|err| EvidenceError::InvalidEvidence(format!("Error deserializing normal broadcast: {err}"))) + } + + pub(crate) fn into_round(self) -> EvidenceMessages<'a, Id, NR> + where + NR: Round< + Id, + EchoBroadcast = R::EchoBroadcast, + NormalBroadcast = R::NormalBroadcast, + DirectMessage = R::DirectMessage, + >, + { + EvidenceMessages:: { + message: self.message, + previous_messages: self.previous_messages, + combined_echos: self.combined_echos, + format: self.format, + phantom: PhantomData, + } + } +} + +/// A placeholder for [`Round::ProtocolError`] for the rounds that do not generate errors. +#[derive_where::derive_where(Clone)] +#[derive(Debug, Copy, Serialize, Deserialize)] +pub struct NoProtocolErrors(PhantomData R>); + +impl ProtocolError for NoProtocolErrors +where + Id: PartyId, + R: Round, +{ + type Round = R; + fn description(&self) -> String { + panic!("Methods of `NoProtocolErrors` should not be called during normal operation.") + } + fn required_messages(&self, _round_id: &RoundId) -> RequiredMessages { + panic!("Methods of `NoProtocolErrors` should not be called during normal operation.") + } + fn verify_evidence( + &self, + _round_id: &RoundId, + _from: &Id, + _shared_randomness: &[u8], + _shared_data: &<>::Protocol as Protocol>::SharedData, + _messages: EvidenceMessages<'_, Id, Self::Round>, + ) -> Result<(), EvidenceError> { + panic!("Methods of `NoProtocolErrors` should not be called during normal operation.") + } +} + +/// Declares which parts of the message from a round have to be stored to serve as the evidence of malicious behavior. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RequiredMessageParts { + pub(crate) echo_broadcast: bool, + pub(crate) normal_broadcast: bool, + pub(crate) direct_message: bool, +} + +impl RequiredMessageParts { + fn new(echo_broadcast: bool, normal_broadcast: bool, direct_message: bool) -> Self { + // We must require at least one part, otherwise this struct doesn't need to be created. + debug_assert!(echo_broadcast || normal_broadcast || direct_message); + Self { + echo_broadcast, + normal_broadcast, + direct_message, + } + } + + /// Store echo broadcast + pub fn echo_broadcast() -> Self { + Self::new(true, false, false) + } + + /// Store normal broadcast + pub fn normal_broadcast() -> Self { + Self::new(false, true, false) + } + + /// Store direct message + pub fn direct_message() -> Self { + Self::new(false, false, true) + } + + /// Store echo broadcast in addition to what is already stored. + pub fn and_echo_broadcast(&self) -> Self { + Self::new(true, self.normal_broadcast, self.direct_message) + } + + /// Store normal broadcast in addition to what is already stored. + pub fn and_normal_broadcast(&self) -> Self { + Self::new(self.echo_broadcast, true, self.direct_message) + } + + /// Store direct message in addition to what is already stored. + pub fn and_direct_message(&self) -> Self { + Self::new(self.echo_broadcast, self.normal_broadcast, true) + } +} + +/// Declares which messages from this and previous rounds +/// have to be stored to serve as the evidence of malicious behavior. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RequiredMessages { + pub(crate) this_round: RequiredMessageParts, + pub(crate) previous_rounds: Option>, + pub(crate) combined_echos: Option>, +} + +impl RequiredMessages { + /// The general case constructor. + /// + /// `this_round` specifies the message parts to be stored from the message that triggered the error. + /// + /// `previous_rounds` specifies, optionally, if any message parts from the previous rounds need to be included. + /// + /// `combined_echos` specifies, optionally, if any echoed broadcasts need to be included. + /// The combined echos are echo broadcasts sent by a party during the echo round, + /// where it bundles all the received broadcasts and sends them back to everyone. + /// That is, they will include the echo broadcasts from all other nodes signed by the guilty party. + pub fn new( + this_round: RequiredMessageParts, + previous_rounds: Option>, + combined_echos: Option>, + ) -> Self { + Self { + this_round, + previous_rounds, + combined_echos, + } + } + + pub(crate) fn group_under(self, group_num: GroupNum) -> Self { + let previous_rounds = self.previous_rounds.map(|previous_rounds| { + previous_rounds + .into_iter() + .map(|(round_id, required)| (round_id.group_under(group_num), required)) + .collect() + }); + + let combined_echos = self.combined_echos.map(|combined_echos| { + combined_echos + .into_iter() + .map(|round_id| round_id.group_under(group_num)) + .collect() + }); + + RequiredMessages { + this_round: self.this_round, + previous_rounds, + combined_echos, + } + } +} + +/// An error that can occur during the validation of an evidence of a protocol error. +#[derive(Debug, Clone)] +pub enum EvidenceError { + /// Indicates a local problem, usually a bug in the library code. + Local(LocalError), + /// The evidence is improperly constructed + /// + /// This can indicate many things, such as: messages missing, invalid signatures, invalid messages, + /// the messages not actually proving the malicious behavior. + /// See the attached description for details. + InvalidEvidence(String), +} + +impl From for EvidenceError { + fn from(error: LocalError) -> Self { + Self::Local(error) + } +} diff --git a/manul/src/protocol/message.rs b/manul/src/protocol/message.rs index 8c9e9f8..09f3cf1 100644 --- a/manul/src/protocol/message.rs +++ b/manul/src/protocol/message.rs @@ -1,40 +1,29 @@ -use alloc::string::{String, ToString}; +use alloc::{ + boxed::Box, + string::{String, ToString}, +}; use digest::Digest; use serde::{Deserialize, Serialize}; +use serde_encoded_bytes::{Base64, SliceLike}; -use super::{ - errors::{DirectMessageError, EchoBroadcastError, LocalError, MessageValidationError, NormalBroadcastError}, - BoxedFormat, -}; - -mod private { - use alloc::boxed::Box; - use serde::{Deserialize, Serialize}; - use serde_encoded_bytes::{Base64, SliceLike}; +use super::wire_format::BoxedFormat; +use crate::protocol::{EvidenceError, LocalError, NoMessage, NoType}; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] - pub struct MessagePayload(#[serde(with = "SliceLike::")] pub Box<[u8]>); - - impl AsRef<[u8]> for MessagePayload { - fn as_ref(&self) -> &[u8] { - &self.0 - } - } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct MessagePayload(#[serde(with = "SliceLike::")] pub Box<[u8]>); - pub trait ProtocolMessageWrapper: Sized { - fn new_inner(maybe_message: Option) -> Self; - fn maybe_message(&self) -> &Option; +impl AsRef<[u8]> for MessagePayload { + fn as_ref(&self) -> &[u8] { + &self.0 } } -use private::{MessagePayload, ProtocolMessageWrapper}; - /// A serialized part of the protocol message. /// /// These would usually be generated separately by the round, but delivered together to /// [`Round::receive_message`](`crate::protocol::Round::receive_message`). -pub trait ProtocolMessagePart: ProtocolMessageWrapper { +pub(crate) trait ProtocolMessagePart: Sized { /// The error specific to deserializing this message. /// /// Used to distinguish which deserialization failed in @@ -42,6 +31,14 @@ pub trait ProtocolMessagePart: ProtocolMessageWrapper { /// and store the corresponding message in the evidence. type Error: From; + // Alternatively, we could not use an `Option`, but instead just serialize `NoMessage`. + // Since it produces the same serialization as any other empty type, and the user may use one of those + // as a message part type, there would be a possibility of a false positive on deserialization. + // So it's safer to make `NoMessage` a special case of empty message. + fn new_inner(maybe_message: Option) -> Self; + + fn maybe_message(&self) -> &Option; + /// Creates an empty message. /// /// Use in case the round does not send a message of this type. @@ -54,8 +51,11 @@ pub trait ProtocolMessagePart: ProtocolMessageWrapper { where T: 'static + Serialize, { - let payload = MessagePayload(format.serialize(message)?); - Ok(Self::new_inner(Some(payload))) + Ok(if NoMessage::equals::() { + Self::none() + } else { + Self::new_inner(Some(MessagePayload(format.serialize(message)?))) + }) } /// Returns `true` if this is an empty message. @@ -79,11 +79,11 @@ pub trait ProtocolMessagePart: ProtocolMessageWrapper { /// This is intended to be used in the implementations of /// [`Protocol::verify_direct_message_is_invalid`](`crate::protocol::Protocol::verify_direct_message_is_invalid`) or /// [`Protocol::verify_echo_broadcast_is_invalid`](`crate::protocol::Protocol::verify_echo_broadcast_is_invalid`). - fn verify_is_not<'de, T: Deserialize<'de>>(&'de self, format: &BoxedFormat) -> Result<(), MessageValidationError> { + fn verify_is_not<'de, T: 'static + Deserialize<'de>>(&'de self, format: &BoxedFormat) -> Result<(), EvidenceError> { if self.deserialize::(format).is_err() { Ok(()) } else { - Err(MessageValidationError::InvalidEvidence( + Err(EvidenceError::InvalidEvidence( "Message deserialized successfully, as expected by the protocol".into(), )) } @@ -94,11 +94,11 @@ pub trait ProtocolMessagePart: ProtocolMessageWrapper { /// This is intended to be used in the implementations of /// [`Protocol::verify_direct_message_is_invalid`](`crate::protocol::Protocol::verify_direct_message_is_invalid`) or /// [`Protocol::verify_echo_broadcast_is_invalid`](`crate::protocol::Protocol::verify_echo_broadcast_is_invalid`). - fn verify_is_some(&self) -> Result<(), MessageValidationError> { + fn verify_is_some(&self) -> Result<(), EvidenceError> { if self.maybe_message().is_some() { Ok(()) } else { - Err(MessageValidationError::InvalidEvidence( + Err(EvidenceError::InvalidEvidence( "The payload is `None`, as expected by the protocol".into(), )) } @@ -107,13 +107,18 @@ pub trait ProtocolMessagePart: ProtocolMessageWrapper { /// Deserializes the message into `T`. fn deserialize<'de, T>(&'de self, format: &BoxedFormat) -> Result where - T: Deserialize<'de>, + T: 'static + Deserialize<'de>, { - let payload = self - .maybe_message() - .as_ref() - .ok_or_else(|| "The payload is `None` and cannot be deserialized".into())?; - format.deserialize(&payload.0).map_err(|err| err.to_string().into()) + match (self.maybe_message().as_ref(), NoMessage::new_if_equals::()) { + (Some(payload), None) => format.deserialize(&payload.0).map_err(|err| err.to_string().into()), + (None, Some(no_message)) => Ok(no_message), + (Some(_payload), Some(_no_message)) => Err("Got a non-empty payload when no message part was expected" + .to_string() + .into()), + (None, None) => Err("Got an empty payload when a message part was expected" + .to_string() + .into()), + } } } @@ -155,9 +160,11 @@ impl ProtocolMessagePartHashable for T {} /// A serialized direct message. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct DirectMessage(Option); +pub(crate) struct DirectMessage(Option); + +impl ProtocolMessagePart for DirectMessage { + type Error = DirectMessageError; -impl ProtocolMessageWrapper for DirectMessage { fn new_inner(maybe_message: Option) -> Self { Self(maybe_message) } @@ -171,15 +178,13 @@ impl HasPartKind for DirectMessage { const KIND: PartKind = PartKind::DirectMessage; } -impl ProtocolMessagePart for DirectMessage { - type Error = DirectMessageError; -} - /// A serialized echo broadcast. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct EchoBroadcast(Option); +pub(crate) struct EchoBroadcast(Option); + +impl ProtocolMessagePart for EchoBroadcast { + type Error = EchoBroadcastError; -impl ProtocolMessageWrapper for EchoBroadcast { fn new_inner(maybe_message: Option) -> Self { Self(maybe_message) } @@ -193,15 +198,13 @@ impl HasPartKind for EchoBroadcast { const KIND: PartKind = PartKind::EchoBroadcast; } -impl ProtocolMessagePart for EchoBroadcast { - type Error = EchoBroadcastError; -} - /// A serialized regular (non-echo) broadcast. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct NormalBroadcast(Option); +pub(crate) struct NormalBroadcast(Option); + +impl ProtocolMessagePart for NormalBroadcast { + type Error = NormalBroadcastError; -impl ProtocolMessageWrapper for NormalBroadcast { fn new_inner(maybe_message: Option) -> Self { Self(maybe_message) } @@ -215,13 +218,9 @@ impl HasPartKind for NormalBroadcast { const KIND: PartKind = PartKind::NormalBroadcast; } -impl ProtocolMessagePart for NormalBroadcast { - type Error = NormalBroadcastError; -} - /// A bundle containing the message parts for one round. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ProtocolMessage { +pub(crate) struct DynProtocolMessage { /// The echo-broadcased message part. pub echo_broadcast: EchoBroadcast, /// The message part broadcasted without additional verification. @@ -229,3 +228,36 @@ pub struct ProtocolMessage { /// The message part sent directly to one node. pub direct_message: DirectMessage, } + +/// An error during deserialization of a direct message. +#[derive(displaydoc::Display, Debug, Clone)] +#[displaydoc("Direct message error: {0}")] +pub(crate) struct DirectMessageError(String); + +impl From for DirectMessageError { + fn from(message: String) -> Self { + Self(message) + } +} + +/// An error during deserialization of an echo broadcast. +#[derive(displaydoc::Display, Debug, Clone)] +#[displaydoc("Echo broadcast error: {0}")] +pub(crate) struct EchoBroadcastError(String); + +impl From for EchoBroadcastError { + fn from(message: String) -> Self { + Self(message) + } +} + +/// An error during deserialization of a normal broadcast. +#[derive(displaydoc::Display, Debug, Clone)] +#[displaydoc("Normal broadcast error: {0}")] +pub(crate) struct NormalBroadcastError(String); + +impl From for NormalBroadcastError { + fn from(message: String) -> Self { + Self(message) + } +} diff --git a/manul/src/protocol/rng.rs b/manul/src/protocol/rng.rs new file mode 100644 index 0000000..ed835fb --- /dev/null +++ b/manul/src/protocol/rng.rs @@ -0,0 +1,23 @@ +use rand_core::CryptoRngCore; + +/// Since object-safe trait methods cannot take `impl CryptoRngCore` arguments, +/// this structure wraps the dynamic object and exposes a `CryptoRngCore` interface, +/// to be passed to statically typed round methods. +pub(crate) struct BoxedRng<'a>(pub(crate) &'a mut dyn CryptoRngCore); + +impl rand_core::CryptoRng for BoxedRng<'_> {} + +impl rand_core::RngCore for BoxedRng<'_> { + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest) + } + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.0.try_fill_bytes(dest) + } +} diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 4bba256..084663c 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -1,24 +1,248 @@ use alloc::{ boxed::Box, collections::{BTreeMap, BTreeSet}, - format, -}; -use core::{ - any::Any, - fmt::{Debug, Display}, }; +use core::{any::TypeId, fmt::Debug, marker::PhantomData}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use super::{ - boxed_format::BoxedFormat, - boxed_round::BoxedRound, - errors::{LocalError, MessageValidationError, ProtocolValidationError, ReceiveError}, - message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessage, ProtocolMessagePart}, + dyn_round::BoxedRound, + errors::{LocalError, ReceiveError}, + evidence::ProtocolError, round_id::{RoundId, TransitionInfo}, + round_info::RoundInfo, }; +pub(crate) trait NoType: 'static + Sized { + fn new() -> Self; + + fn equals() -> bool { + TypeId::of::() == TypeId::of::() + } + + fn new_if_equals() -> Option { + if Self::equals::() { + let boxed = Box::new(Self::new()); + // SAFETY: can cast since we checked that T == NoMessage + let boxed_downcast = unsafe { Box::::from_raw(Box::into_raw(boxed) as *mut T) }; + Some(*boxed_downcast) + } else { + None + } + } +} + +/// A placeholder type for [`Round::DirectMessage`], [`Round::NormalBroadcast`], and [`Round::EchoBroadcast`] +/// indicating that the round does not send corresponding message parts. +// `PhantomData` is here to make it un-constructable by an external user. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct NoMessage(PhantomData<()>); + +impl NoType for NoMessage { + fn new() -> Self { + Self(PhantomData) + } +} + +/// A placeholder type for [`Round::DirectMessage`], [`Round::NormalBroadcast`], and [`Round::EchoBroadcast`] +/// indicating that the round does not send corresponding message parts. +// `PhantomData` is here to make it un-constructable by an external user. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct NoArtifact(PhantomData<()>); + +impl NoType for NoArtifact { + fn new() -> Self { + Self(PhantomData) + } +} + +/// A structure encapsulating different parts of a message from a single node. +#[derive(Debug)] +pub struct ProtocolMessage + ?Sized> { + /// The part of the message specific for each destination. + pub direct_message: R::DirectMessage, + /// The part of the message that will be additionally echo-broadcasted to ensure every receiver + /// gets the same data. + pub echo_broadcast: R::EchoBroadcast, + /// The part of the message that will be sent to all destinations. + pub normal_broadcast: R::NormalBroadcast, +} + +/// A type representing a single round of a protocol. +/// +/// The way a round will be used by an external caller: +/// - create messages to send out (by calling [`make_direct_message`](`Self::make_direct_message`), +/// [`make_normal_broadcast`](`Self::make_normal_broadcast`), +/// and [`make_echo_broadcast`](`Self::make_echo_broadcast`)); +/// - process received messages from other nodes (by calling [`receive_message`](`Self::receive_message`)); +/// - attempt to finalize (by calling [`finalize`](`Self::finalize`)) to produce the next round, or return a result. +pub trait Round: 'static + Debug + Send + Sync { + /// The protocol this round is a part of. + type Protocol: Protocol; + + /// The provable error type that can be returned on receiving a message. + /// + /// If this round does not generate errors, [`NoProtocolErrors`](`crate::protocol::NoProtocolErrors`) + /// can be used here. + type ProtocolError: ProtocolError; + + /// Returns the information about the position of this round in the state transition graph. + /// + /// See [`TransitionInfo`] documentation for more details. + fn transition_info(&self) -> TransitionInfo; + + /// Returns the information about the communication this rounds engages in with other nodes. + /// + /// See [`CommunicationInfo`] documentation for more details. + fn communication_info(&self) -> CommunicationInfo; + + /// The part of the message specific for each destination. + /// + /// Set to [`NoMessage`] if the round does not use this part of the message. + type DirectMessage: 'static + Serialize + for<'de> Deserialize<'de>; + + /// The part of the message that will be sent to all destinations. + /// + /// Set to [`NoMessage`] if the round does not use this part of the message. + type NormalBroadcast: 'static + Serialize + for<'de> Deserialize<'de>; + + /// The part of the message that will be additionally echo-broadcasted to ensure every receiver + /// gets the same data. + /// + /// Set to [`NoMessage`] if the round does not use this part of the message. + type EchoBroadcast: 'static + Serialize + for<'de> Deserialize<'de>; + + /// Message payload created in [`Self::receive_message`]. + /// + /// [`Self::Payload`]s are created as the output of processing an incoming message. + /// When a [`Round`] finalizes, all the `Payload`s received during the round are made available + /// and can be used to decide what to do next (next round? return a final result?). + /// Payloads are not sent to other nodes. + type Payload: Send + Sync; + + /// Associated data created alongside a message in [`Self::make_direct_message`]. + /// + /// [`Self::Artifact`]s are local to the participant that created it and are usually containers + /// for intermediary secrets and/or dynamic parameters needed in subsequent stages of the protocol. + /// Artifacts are never sent over the wire; they are made available to [`Self::finalize`] + /// for the participant, delivered in the form of a `BTreeMap` + /// where the key is the destination id of the participant to whom the direct message was sent. + /// + /// Set to [`NoArtifact`] if [`Self::DirectMessage`] is [`NoMessage`]. + type Artifact: 'static + Send + Sync; + + /// Returns the direct message to the given destination and (maybe) an accompanying artifact. + /// + /// In some protocols, when a message to another node is created, there is some associated information + /// that needs to be retained for later (randomness, proofs of knowledge, and so on). + /// These should be put in an [`Self::Artifact`] and will be available + /// at the time of [`finalize`](`Self::finalize`). + /// + /// If this method is not implemented, [`Self::DirectMessage`] must be set to [`NoMessage`], + /// and [`Self::Artifact`] to [`NoArtifact`]. + #[allow(clippy::type_complexity)] + fn make_direct_message( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + #[allow(unused_variables)] destination: &Id, + ) -> Result<(Self::DirectMessage, Self::Artifact), LocalError> { + if let Some(message) = NoMessage::new_if_equals::() { + match NoArtifact::new_if_equals::() { + Some(artifact) => Ok((message, artifact)), + None => Err(LocalError::new( + "If `DirectMessage` is `NoMessage`, `Artifact` must be `NoArtifact`", + )), + } + } else if self.communication_info().message_destinations.is_empty() { + // TODO (#4): this branch could potentially be eliminated + Err(LocalError::new( + "`make_direct_message() called when the round does not send messages - internal error", + )) + } else { + Err(LocalError::new(concat!( + "If `DirectMessage` is not `NoMessage`, and the round sends messages, ", + "`make_direct_message()` must be implemented" + ))) + } + } + + /// Returns the echo broadcast for this round. + /// + /// The execution layer will guarantee that all the destinations are sure they all received the same broadcast. This + /// also means that a message containing the broadcasts from all nodes and signed by each node is available. This is + /// used as part of the evidence of malicious behavior when producing provable offence reports. + /// + /// If this method is not implemented, [`Self::EchoBroadcast`] must be set to [`NoMessage`]. + fn make_echo_broadcast( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + ) -> Result { + if let Some(message) = NoMessage::new_if_equals::() { + Ok(message) + } else if self.communication_info().message_destinations.is_empty() { + // TODO (#4): this branch could potentially be eliminated + Err(LocalError::new( + "`make_echo_broadcast() called when the round does not send messages - internal error", + )) + } else { + Err(LocalError::new(concat!( + "If `EchoBroadcast` is not `NoMessage`, and the round sends messages, ", + "`make_echo_broadcast()` must be implemented" + ))) + } + } + + /// Returns the normal broadcast for this round. + /// + /// Unlike echo broadcasts, normal broadcasts are "send and forget" and delivered to every node defined in + /// [`Self::communication_info`] without any confirmation required by the receiving node. + /// + /// If this method is not implemented, [`Self::NormalBroadcast`] must be set to [`NoMessage`]. + fn make_normal_broadcast( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + ) -> Result { + if let Some(message) = NoMessage::new_if_equals::() { + Ok(message) + } else if self.communication_info().message_destinations.is_empty() { + // TODO (#4): this branch could potentially be eliminated + Err(LocalError::new( + "`make_normal_broadcast() called when the round does not send messages - internal error", + )) + } else { + Err(LocalError::new(concat!( + "If `NormalBroadcast` is not `NoMessage`, and the round sends messages, ", + "`make_normal_broadcast()` must be implemented" + ))) + } + } + + /// Processes a received message and generates the payload that will be used in [`finalize`](`Self::finalize`). The + /// message content can be arbitrarily checked and processed to build the exact payload needed to finalize the + /// round. + /// + /// Note that there is no need to authenticate the message at this point; + /// it has already been done by the execution layer. + fn receive_message( + &self, + from: &Id, + message_parts: ProtocolMessage, + ) -> Result>; + + /// Attempts to finalize the round, producing the next round or the result. + /// + /// `payloads` here are the ones previously generated by [`receive_message`](`Self::receive_message`), and + /// `artifacts` are the ones previously generated by [`make_direct_message`](`Self::make_direct_message`). + fn finalize( + self, + rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, LocalError>; +} + /// Describes what other parties this rounds sends messages to, and what other parties it expects messages from. #[derive(Debug, Clone)] pub struct CommunicationInfo { @@ -45,7 +269,10 @@ pub struct CommunicationInfo { pub echo_round_participation: EchoRoundParticipation, } -impl CommunicationInfo { +impl CommunicationInfo +where + Id: PartyId, +{ /// A regular round that sends messages to all `other_parties`, and expects messages back from them. pub fn regular(other_parties: &BTreeSet) -> Self { Self { @@ -58,7 +285,7 @@ impl CommunicationInfo { /// Possible successful outcomes of [`Round::finalize`]. #[derive(Debug)] -pub enum FinalizeOutcome> { +pub enum FinalizeOutcome> { /// Transition to a new round. AnotherRound(BoxedRound), /// The protocol reached a result. @@ -70,257 +297,12 @@ pub trait Protocol: 'static { /// The successful result of an execution of this protocol. type Result: Debug; - /// An object of this type will be returned when a provable error happens during [`Round::receive_message`]. - type ProtocolError: ProtocolError; - - /// Returns `Ok(())` if the given direct message cannot be deserialized - /// assuming it is a direct message from the round `round_id`. - /// - /// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`] - /// when implementing this. - fn verify_direct_message_is_invalid( - format: &BoxedFormat, - round_id: &RoundId, - message: &DirectMessage, - ) -> Result<(), MessageValidationError>; - - /// Returns `Ok(())` if the given echo broadcast cannot be deserialized - /// assuming it is an echo broadcast from the round `round_id`. - /// - /// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`] - /// when implementing this. - fn verify_echo_broadcast_is_invalid( - format: &BoxedFormat, - round_id: &RoundId, - message: &EchoBroadcast, - ) -> Result<(), MessageValidationError>; - - /// Returns `Ok(())` if the given echo broadcast cannot be deserialized - /// assuming it is an echo broadcast from the round `round_id`. - /// - /// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`] - /// when implementing this. - fn verify_normal_broadcast_is_invalid( - format: &BoxedFormat, - round_id: &RoundId, - message: &NormalBroadcast, - ) -> Result<(), MessageValidationError>; -} - -/// Declares which parts of the message from a round have to be stored to serve as the evidence of malicious behavior. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RequiredMessageParts { - pub(crate) echo_broadcast: bool, - pub(crate) normal_broadcast: bool, - pub(crate) direct_message: bool, -} - -impl RequiredMessageParts { - fn new(echo_broadcast: bool, normal_broadcast: bool, direct_message: bool) -> Self { - // We must require at least one part, otherwise this struct doesn't need to be created. - debug_assert!(echo_broadcast || normal_broadcast || direct_message); - Self { - echo_broadcast, - normal_broadcast, - direct_message, - } - } - - /// Store echo broadcast - pub fn echo_broadcast() -> Self { - Self::new(true, false, false) - } - - /// Store normal broadcast - pub fn normal_broadcast() -> Self { - Self::new(false, true, false) - } - - /// Store direct message - pub fn direct_message() -> Self { - Self::new(false, false, true) - } - - /// Store echo broadcast in addition to what is already stored. - pub fn and_echo_broadcast(&self) -> Self { - Self::new(true, self.normal_broadcast, self.direct_message) - } - - /// Store normal broadcast in addition to what is already stored. - pub fn and_normal_broadcast(&self) -> Self { - Self::new(self.echo_broadcast, true, self.direct_message) - } - - /// Store direct message in addition to what is already stored. - pub fn and_direct_message(&self) -> Self { - Self::new(self.echo_broadcast, self.normal_broadcast, true) - } -} - -/// Declares which messages from this and previous rounds -/// have to be stored to serve as the evidence of malicious behavior. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RequiredMessages { - pub(crate) this_round: RequiredMessageParts, - pub(crate) previous_rounds: Option>, - pub(crate) combined_echos: Option>, -} + /// The subset of public data shared between all participating nodes before the beginning of the protocol + /// (excluding the session ID) that is necessary for evidence verification. + type SharedData: Debug; -impl RequiredMessages { - /// The general case constructor. - /// - /// `this_round` specifies the message parts to be stored from the message that triggered the error. - /// - /// `previous_rounds` specifies, optionally, if any message parts from the previous rounds need to be included. - /// - /// `combined_echos` specifies, optionally, if any echoed broadcasts need to be included. - /// The combined echos are echo broadcasts sent by a party during the echo round, - /// where it bundles all the received broadcasts and sends them back to everyone. - /// That is, they will include the echo broadcasts from all other nodes signed by the guilty party. - pub fn new( - this_round: RequiredMessageParts, - previous_rounds: Option>, - combined_echos: Option>, - ) -> Self { - Self { - this_round, - previous_rounds, - combined_echos, - } - } -} - -/// Describes provable errors originating during protocol execution. -/// -/// Provable here means that we can create an evidence object entirely of messages signed by some party, -/// which, in combination, prove the party's malicious actions. -pub trait ProtocolError: Display + Debug + Clone + Serialize + for<'de> Deserialize<'de> { - /// Additional data that cannot be derived from the node's messages alone - /// and therefore has to be supplied externally during evidence verification. - type AssociatedData: Debug; - - /// Specifies the messages of the guilty party that need to be stored as the evidence - /// to prove its malicious behavior. - fn required_messages(&self) -> RequiredMessages; - - /// Returns `Ok(())` if the attached messages indeed prove that a malicious action happened. - /// - /// The signatures and metadata of the messages will be checked by the calling code, - /// the responsibility of this method is just to check the message contents. - /// - /// `message` contain the message parts that triggered the error - /// during [`Round::receive_message`]. - /// - /// `previous_messages` are message parts from the previous rounds, as requested by - /// [`required_messages`](Self::required_messages). - /// - /// Note that if some message part was not requested by above methods, it will be set to an empty one - /// in the [`ProtocolMessage`], even if it was present originally. - /// - /// `combined_echos` are bundled echos from other parties from the previous rounds, - /// as requested by [`required_messages`](Self::required_messages). - #[allow(clippy::too_many_arguments)] - fn verify_messages_constitute_error( - &self, - format: &BoxedFormat, - guilty_party: &Id, - shared_randomness: &[u8], - associated_data: &Self::AssociatedData, - message: ProtocolMessage, - previous_messages: BTreeMap, - combined_echos: BTreeMap>, - ) -> Result<(), ProtocolValidationError>; -} - -#[derive(displaydoc::Display, Debug, Clone, Copy, Serialize, Deserialize)] -/// A stub type indicating that this protocol does not generate any provable errors. -pub struct NoProtocolErrors; - -impl ProtocolError for NoProtocolErrors { - type AssociatedData = (); - - fn required_messages(&self) -> RequiredMessages { - panic!("Attempt to use an empty error type in an evidence. This is a bug in the protocol implementation.") - } - - fn verify_messages_constitute_error( - &self, - _format: &BoxedFormat, - _guilty_party: &Id, - _shared_randomness: &[u8], - _associated_data: &Self::AssociatedData, - _message: ProtocolMessage, - _previous_messages: BTreeMap, - _combined_echos: BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - panic!("Attempt to use an empty error type in an evidence. This is a bug in the protocol implementation.") - } -} - -/// Message payload created in [`Round::receive_message`]. -/// -/// [`Payload`]s are created as the output of processing an incoming message. When a [`Round`] finalizes, all the -/// `Payload`s received during the round are made available and can be used to decide what to do next (next round? -/// return a final result?). Payloads are not sent to other nodes. -#[derive(Debug)] -pub struct Payload(pub Box); - -impl Payload { - /// Creates a new payload. - /// - /// Would be normally called in [`Round::receive_message`]. - pub fn new(payload: T) -> Self { - Self(Box::new(payload)) - } - - /// Creates an empty payload. - /// - /// Use it in [`Round::receive_message`] if it does not need to create payloads. - pub fn empty() -> Self { - Self::new(()) - } - - /// Attempts to downcast back to the concrete type. - /// - /// Would be normally called in [`Round::finalize`]. - pub fn downcast(self) -> Result { - Ok(*(self.0.downcast::().map_err(|_| { - LocalError::new(format!( - "Failed to downcast Payload into {}", - core::any::type_name::() - )) - })?)) - } -} - -/// Associated data created alongside a message in [`Round::make_direct_message`]. -/// -/// [`Artifact`]s are local to the participant that created it and are usually containers for intermediary secrets -/// and/or dynamic parameters needed in subsequent stages of the protocol. Artifacts are never sent over the wire; they -/// are made available to [`Round::finalize`] for the participant, delivered in the form of a `BTreeMap` where the key -/// is the destination id of the participant to whom the direct message was sent. -#[derive(Debug)] -pub struct Artifact(pub Box); - -impl Artifact { - /// Creates a new artifact. - /// - /// Would be normally called in [`Round::make_direct_message`]. - pub fn new(artifact: T) -> Self { - Self(Box::new(artifact)) - } - - /// Attempts to downcast back to the concrete type. - /// - /// Would be normally called in [`Round::finalize`]. - pub fn downcast(self) -> Result { - Ok(*(self.0.downcast::().map_err(|_| { - LocalError::new(format!( - "Failed to downcast Artifact into {}", - core::any::type_name::() - )) - })?)) - } + /// Returns the round metadata for each round mapped to round IDs. + fn round_info(round_id: &RoundId) -> Option>; } /// A round that initiates a protocol and defines how execution begins. It is the only round that can be created outside @@ -342,7 +324,7 @@ pub trait EntryPoint { /// `id` is the ID of this node. fn make_round( self, - rng: &mut dyn CryptoRngCore, + rng: &mut impl CryptoRngCore, shared_randomness: &[u8], id: &Id, ) -> Result, LocalError>; @@ -371,112 +353,3 @@ pub enum EchoRoundParticipation { echo_targets: BTreeSet, }, } - -mod sealed { - /// A dyn safe trait to get the type's ID. - pub trait DynTypeId: 'static { - /// Returns the type ID of the implementing type. - fn get_type_id(&self) -> core::any::TypeId { - core::any::TypeId::of::() - } - } - - impl DynTypeId for T {} -} - -use sealed::DynTypeId; - -/** -A type representing a single round of a protocol. - -The way a round will be used by an external caller: -- create messages to send out (by calling [`make_direct_message`](`Self::make_direct_message`) - and [`make_echo_broadcast`](`Self::make_echo_broadcast`)); -- process received messages from other nodes (by calling [`receive_message`](`Self::receive_message`)); -- attempt to finalize (by calling [`finalize`](`Self::finalize`)) to produce the next round, or return a result. -*/ -pub trait Round: 'static + Debug + Send + Sync + DynTypeId { - /// The protocol this round is a part of. - type Protocol: Protocol; - - /// Returns the information about the position of this round in the state transition graph. - /// - /// See [`TransitionInfo`] documentation for more details. - fn transition_info(&self) -> TransitionInfo; - - /// Returns the information about the communication this rounds engages in with other nodes. - /// - /// See [`CommunicationInfo`] documentation for more details. - fn communication_info(&self) -> CommunicationInfo; - - /// Returns the direct message to the given destination and (maybe) an accompanying artifact. - /// - /// Return [`DirectMessage::none`] if this round does not send direct messages. - /// - /// In some protocols, when a message to another node is created, there is some associated information - /// that needs to be retained for later (randomness, proofs of knowledge, and so on). - /// These should be put in an [`Artifact`] and will be available at the time of [`finalize`](`Self::finalize`). - fn make_direct_message( - &self, - #[allow(unused_variables)] rng: &mut dyn CryptoRngCore, - #[allow(unused_variables)] format: &BoxedFormat, - #[allow(unused_variables)] destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - Ok((DirectMessage::none(), None)) - } - - /// Returns the echo broadcast for this round. - /// - /// Return [`EchoBroadcast::none`] if this round does not send echo-broadcast messages. - /// This is also the blanket implementation. - /// - /// The execution layer will guarantee that all the destinations are sure they all received the same broadcast. This - /// also means that a message containing the broadcasts from all nodes and signed by each node is available. This is - /// used as part of the evidence of malicious behavior when producing provable offence reports. - fn make_echo_broadcast( - &self, - #[allow(unused_variables)] rng: &mut dyn CryptoRngCore, - #[allow(unused_variables)] format: &BoxedFormat, - ) -> Result { - Ok(EchoBroadcast::none()) - } - - /// Returns the normal broadcast for this round. - /// - /// Return [`NormalBroadcast::none`] if this round does not send normal broadcast messages. - /// This is also the blanket implementation. - /// - /// Unlike echo broadcasts, normal broadcasts are "send and forget" and delivered to every node defined in - /// [`Self::communication_info`] without any confirmation required by the receiving node. - fn make_normal_broadcast( - &self, - #[allow(unused_variables)] rng: &mut dyn CryptoRngCore, - #[allow(unused_variables)] format: &BoxedFormat, - ) -> Result { - Ok(NormalBroadcast::none()) - } - - /// Processes a received message and generates the payload that will be used in [`finalize`](`Self::finalize`). The - /// message content can be arbitrarily checked and processed to build the exact payload needed to finalize the - /// round. - /// - /// Note that there is no need to authenticate the message at this point; - /// it has already been done by the execution layer. - fn receive_message( - &self, - format: &BoxedFormat, - from: &Id, - message: ProtocolMessage, - ) -> Result>; - - /// Attempts to finalize the round, producing the next round or the result. - /// - /// `payloads` here are the ones previously generated by [`receive_message`](`Self::receive_message`), and - /// `artifacts` are the ones previously generated by [`make_direct_message`](`Self::make_direct_message`). - fn finalize( - self: Box, - rng: &mut dyn CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result, LocalError>; -} diff --git a/manul/src/protocol/round_id.rs b/manul/src/protocol/round_id.rs index 1518813..06b569f 100644 --- a/manul/src/protocol/round_id.rs +++ b/manul/src/protocol/round_id.rs @@ -6,22 +6,26 @@ use tinyvec::TinyVec; use super::errors::LocalError; +/// Round number. +pub type RoundNum = u8; + +pub(crate) type GroupNum = u8; + /// A round identifier. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - round_nums: TinyVec<[u8; 4]>, + round: RoundNum, + groups: TinyVec<[GroupNum; 4]>, is_echo: bool, } impl Display for RoundId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Round ")?; - for (i, round_num) in self.round_nums.iter().enumerate().rev() { - write!(f, "{}", round_num)?; - if i != 0 { - write!(f, "-")?; - } + for group_num in self.groups.iter().rev() { + write!(f, "{group_num}-")?; } + write!(f, "{}", self.round)?; if self.is_echo { write!(f, " (echo)")?; } @@ -31,25 +35,29 @@ impl Display for RoundId { impl RoundId { /// Creates a new round identifier. - pub fn new(round_num: u8) -> Self { - let mut round_nums = TinyVec::new(); - round_nums.push(round_num); + pub fn new(round_num: RoundNum) -> Self { Self { - round_nums, + round: round_num, + groups: TinyVec::new(), is_echo: false, } } + pub(crate) fn round_num(&self) -> RoundNum { + self.round + } + /// Prefixes this round ID (possibly already nested) with a group number. /// /// This is supposed to be used internally, e.g. in the chain combinator, /// where we have several protocols joined up, and their round numbers may repeat. /// Grouping allows us to disambiguate them, assigning group 1 to one protocol and group 2 to the other. - pub(crate) fn group_under(&self, round_num: u8) -> Self { - let mut round_nums = self.round_nums.clone(); - round_nums.push(round_num); + pub(crate) fn group_under(&self, group_num: GroupNum) -> Self { + let mut groups = self.groups.clone(); + groups.push(group_num); Self { - round_nums, + round: self.round, + groups, is_echo: self.is_echo, } } @@ -58,14 +66,15 @@ impl RoundId { /// and returns this prefix along with the resulting round ID. /// /// Returns the `Err` variant if the round ID is not nested. - pub(crate) fn split_group(&self) -> Result<(u8, Self), LocalError> { - if self.round_nums.len() == 1 { + pub(crate) fn split_group(&self) -> Result<(GroupNum, Self), LocalError> { + if self.groups.is_empty() { Err(LocalError::new("This round ID is not in a group")) } else { - let mut round_nums = self.round_nums.clone(); - let group = round_nums.pop().expect("vector size greater than 1"); + let mut groups = self.groups.clone(); + let group = groups.pop().expect("vector size greater than 0"); let round_id = Self { - round_nums, + round: self.round, + groups, is_echo: self.is_echo, }; Ok((group, round_id)) @@ -87,7 +96,8 @@ impl RoundId { Err(LocalError::new("This is already an echo round ID")) } else { Ok(Self { - round_nums: self.round_nums.clone(), + round: self.round, + groups: self.groups.clone(), is_echo: true, }) } @@ -103,25 +113,32 @@ impl RoundId { Err(LocalError::new("This is already an non-echo round ID")) } else { Ok(Self { - round_nums: self.round_nums.clone(), + round: self.round, + groups: self.groups.clone(), is_echo: false, }) } } } -impl From for RoundId { - fn from(source: u8) -> Self { +impl From for RoundId { + fn from(source: RoundNum) -> Self { Self::new(source) } } -impl PartialEq for RoundId { - fn eq(&self, rhs: &u8) -> bool { +impl PartialEq for RoundId { + fn eq(&self, rhs: &RoundNum) -> bool { self == &RoundId::new(*rhs) } } +impl PartialEq for &RoundId { + fn eq(&self, rhs: &RoundNum) -> bool { + *self == &RoundId::new(*rhs) + } +} + /// Information about the position of the round in the state transition graph. #[derive(Debug, Clone)] pub struct TransitionInfo { @@ -152,7 +169,7 @@ pub struct TransitionInfo { impl TransitionInfo { /// Nest the round IDs under the given group. Used for combinators. - pub(crate) fn group_under(self, group: u8) -> Self { + pub(crate) fn group_under(self, group: GroupNum) -> Self { Self { id: self.id.group_under(group), parents: self.parents.into_iter().map(|r| r.group_under(group)).collect(), @@ -199,7 +216,7 @@ impl TransitionInfo { /// /// That is, if there are rounds 1, 2, 3, ..., N, where the N-th one returns the result, /// this constructor can be used for rounds 1 to N-1. - pub fn new_linear(round_num: u8) -> Self { + pub fn new_linear(round_num: RoundNum) -> Self { Self { id: RoundId::new(round_num), parents: if round_num > 1 { @@ -218,7 +235,7 @@ impl TransitionInfo { /// /// That is, if there are rounds 1, 2, 3, ..., N, where the N-th one returns the result, /// this constructor can be used for round N. - pub fn new_linear_terminating(round_num: u8) -> Self { + pub fn new_linear_terminating(round_num: RoundNum) -> Self { Self { id: RoundId::new(round_num), parents: if round_num > 1 { @@ -233,7 +250,7 @@ impl TransitionInfo { } /// Returns a new [`TransitionInfo`] with `round_nums` added to the set of children. - pub fn with_children(self, round_nums: BTreeSet) -> Self { + pub fn with_children(self, round_nums: BTreeSet) -> Self { let mut children = self.children; children.extend(round_nums.iter().map(|num| RoundId::new(*num))); Self { @@ -246,7 +263,7 @@ impl TransitionInfo { } /// Returns a new [`TransitionInfo`] with `round_nums` added to the set of siblings. - pub fn with_siblings(self, round_nums: BTreeSet) -> Self { + pub fn with_siblings(self, round_nums: BTreeSet) -> Self { let mut siblings = self.siblings; siblings.extend(round_nums.iter().map(|num| RoundId::new(*num))); Self { diff --git a/manul/src/protocol/round_info.rs b/manul/src/protocol/round_info.rs new file mode 100644 index 0000000..40e5470 --- /dev/null +++ b/manul/src/protocol/round_info.rs @@ -0,0 +1,143 @@ +use alloc::{boxed::Box, collections::BTreeMap, format}; +use core::{fmt::Debug, marker::PhantomData}; + +use super::{ + dyn_evidence::SerializedProtocolError, + evidence::{EvidenceError, EvidenceMessages, EvidenceProtocolMessage, ProtocolError}, + message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}, + round::{NoMessage, NoType, Protocol, Round}, + round_id::RoundId, + wire_format::BoxedFormat, +}; + +pub(crate) trait DynRoundInfo: Debug { + type Protocol: Protocol; + fn verify_direct_message_is_invalid( + &self, + format: &BoxedFormat, + message: &DirectMessage, + ) -> Result<(), EvidenceError>; + fn verify_echo_broadcast_is_invalid( + &self, + format: &BoxedFormat, + message: &EchoBroadcast, + ) -> Result<(), EvidenceError>; + fn verify_normal_broadcast_is_invalid( + &self, + format: &BoxedFormat, + message: &NormalBroadcast, + ) -> Result<(), EvidenceError>; + + #[allow(clippy::too_many_arguments)] + fn verify_evidence( + &self, + round_id: &RoundId, + format: &BoxedFormat, + error: &SerializedProtocolError, + guilty_party: &Id, + shared_randomness: &[u8], + shared_data: &>::SharedData, + message: EvidenceProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, + ) -> Result<(), EvidenceError>; +} + +#[derive(Debug)] +struct RoundInfoObject(PhantomData R>); + +impl DynRoundInfo for RoundInfoObject +where + R: Round, +{ + type Protocol = R::Protocol; + + fn verify_direct_message_is_invalid( + &self, + format: &BoxedFormat, + message: &DirectMessage, + ) -> Result<(), EvidenceError> { + if NoMessage::equals::() { + message.verify_is_some() + } else { + message.verify_is_not::(format) + } + } + + fn verify_echo_broadcast_is_invalid( + &self, + format: &BoxedFormat, + message: &EchoBroadcast, + ) -> Result<(), EvidenceError> { + if NoMessage::equals::() { + message.verify_is_some() + } else { + message.verify_is_not::(format) + } + } + + fn verify_normal_broadcast_is_invalid( + &self, + format: &BoxedFormat, + message: &NormalBroadcast, + ) -> Result<(), EvidenceError> { + if NoMessage::equals::() { + message.verify_is_some() + } else { + message.verify_is_not::(format) + } + } + + fn verify_evidence( + &self, + round_id: &RoundId, + format: &BoxedFormat, + error: &SerializedProtocolError, + guilty_party: &Id, + shared_randomness: &[u8], + shared_data: &>::SharedData, + message: EvidenceProtocolMessage, + previous_messages: BTreeMap, + combined_echos: BTreeMap>, + ) -> Result<(), EvidenceError> { + let error = error.deserialize::(format).map_err(|err| { + EvidenceError::InvalidEvidence(format!( + "Cannot deserialize the error as {}: {err}", + core::any::type_name::() + )) + })?; + let evidence_messages = EvidenceMessages::new(format, message, previous_messages, combined_echos); + error.verify_evidence( + round_id, + guilty_party, + shared_randomness, + shared_data, + evidence_messages, + ) + } +} + +/// Type- and state-independent round metadata. +#[derive_where::derive_where(Debug)] +pub struct RoundInfo + ?Sized>(Box>); + +impl RoundInfo +where + P: Protocol, +{ + /// Creates a new metadata object for a round of type `R`. + pub fn new() -> Self + where + R: Round, + { + Self(Box::new(RoundInfoObject::(PhantomData))) + } + + pub(crate) fn new_obj(round: impl DynRoundInfo + 'static) -> Self { + Self(Box::new(round)) + } + + pub(crate) fn as_ref(&self) -> &dyn DynRoundInfo { + self.0.as_ref() + } +} diff --git a/manul/src/protocol/boxed_format.rs b/manul/src/protocol/wire_format.rs similarity index 61% rename from manul/src/protocol/boxed_format.rs rename to manul/src/protocol/wire_format.rs index b855fcb..084c6b8 100644 --- a/manul/src/protocol/boxed_format.rs +++ b/manul/src/protocol/wire_format.rs @@ -3,18 +3,20 @@ use core::{fmt::Debug, marker::PhantomData}; use serde::{Deserialize, Serialize}; -use super::errors::{DeserializationError, LocalError}; -use crate::session::WireFormat; +use crate::{ + protocol::LocalError, + session::{DeserializationError, WireFormat}, +}; -trait ObjectSafeSerializer: Debug { +trait DynSerializer: Debug { fn serialize(&self, value: Box) -> Result, LocalError>; } // `fn(F)` makes the type `Send` + `Sync` even if `F` isn't. #[derive(Debug)] -struct SerializerWrapper(PhantomData); +struct SerializerObject(PhantomData F>); -impl ObjectSafeSerializer for SerializerWrapper { +impl DynSerializer for SerializerObject { fn serialize(&self, value: Box) -> Result, LocalError> { F::serialize(&value) } @@ -22,13 +24,13 @@ impl ObjectSafeSerializer for SerializerWrapper { // `fn(F)` makes the type `Send` + `Sync` even if `F` isn't. #[derive(Debug)] -struct DeserializerFactoryWrapper(PhantomData); +struct DeserializerFactoryObject(PhantomData F>); -trait ObjectSafeDeserializerFactory: Debug { +trait DynDeserializerFactory: Debug { fn make_erased_deserializer<'de>(&self, bytes: &'de [u8]) -> Box + 'de>; } -impl ObjectSafeDeserializerFactory for DeserializerFactoryWrapper +impl DynDeserializerFactory for DeserializerFactoryObject where F: WireFormat, { @@ -40,21 +42,21 @@ where /// A serializer/deserializer for protocol messages. #[derive(Debug)] -pub struct BoxedFormat { - serializer: Box, - deserializer_factory: Box, +pub(crate) struct BoxedFormat { + serializer: Box, + deserializer_factory: Box, } impl BoxedFormat { - pub(crate) fn new() -> Self { + pub fn new() -> Self { Self { - serializer: Box::new(SerializerWrapper::(PhantomData)), - deserializer_factory: Box::new(DeserializerFactoryWrapper::(PhantomData)), + serializer: Box::new(SerializerObject::(PhantomData)), + deserializer_factory: Box::new(DeserializerFactoryObject::(PhantomData)), } } /// Serializes a `serde`-serializable object. - pub(crate) fn serialize(&self, value: T) -> Result, LocalError> + pub fn serialize(&self, value: T) -> Result, LocalError> where T: 'static + Serialize, { @@ -63,13 +65,13 @@ impl BoxedFormat { } /// Deserializes a `serde`-deserializable object. - pub(crate) fn deserialize<'de, T>(&self, bytes: &'de [u8]) -> Result + pub fn deserialize<'de, T>(&self, bytes: &'de [u8]) -> Result where T: Deserialize<'de>, { let mut deserializer = self.deserializer_factory.make_erased_deserializer(bytes); erased_serde::deserialize::(&mut deserializer) - .map_err(|err| DeserializationError::new(format!("Deserialization error: {err:?}"))) + .map_err(|err| DeserializationError::new::(format!("{err:?}"))) } } diff --git a/manul/src/session.rs b/manul/src/session.rs index acf6bcc..9f206f1 100644 --- a/manul/src/session.rs +++ b/manul/src/session.rs @@ -18,13 +18,13 @@ mod wire_format; #[cfg(feature = "tokio")] pub mod tokio; -pub use crate::protocol::{LocalError, RemoteError}; -pub use evidence::{Evidence, EvidenceError}; +pub use crate::protocol::{EvidenceError, LocalError, RemoteError}; +pub use evidence::Evidence; pub use message::{Message, VerifiedMessage}; pub use session::{ CanFinalize, PreprocessOutcome, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters, }; pub use transcript::{SessionOutcome, SessionReport}; -pub use wire_format::WireFormat; +pub use wire_format::{DeserializationError, WireFormat}; pub(crate) use echo::EchoRoundError; diff --git a/manul/src/session/echo.rs b/manul/src/session/echo.rs index 5930bd4..828a697 100644 --- a/manul/src/session/echo.rs +++ b/manul/src/session/echo.rs @@ -12,21 +12,21 @@ use serde::{Deserialize, Serialize}; use tracing::debug; use super::{ - message::{MessageVerificationError, SignedMessageHash, SignedMessagePart}, + message::{MessageMetadata, MessageVerificationError, SignedMessageHash, SignedMessagePart}, session::{EchoRoundInfo, SessionParameters}, LocalError, }; use crate::{ protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EchoRoundParticipation, - FinalizeOutcome, MessageValidationError, NormalBroadcast, Payload, Protocol, ProtocolMessage, - ProtocolMessagePart, ReceiveError, Round, TransitionInfo, + Artifact, BoxedFormat, BoxedReceiveError, BoxedRound, CommunicationInfo, DirectMessage, DynProtocolMessage, + DynRound, EchoBroadcast, EchoRoundParticipation, EvidenceError, FinalizeOutcome, NoArtifact, NoMessage, NoType, + NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessagePart, RemoteError, TransitionInfo, }, utils::SerializableMap, }; /// An error that can occur on receiving a message during an echo round. -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub(crate) enum EchoRoundError { /// The node who constructed the echoed message pack included an invalid message in it. /// @@ -34,40 +34,106 @@ pub(crate) enum EchoRoundError { /// /// The attached identifier points out the sender for whom the echoed message was invalid, /// to speed up the verification process. - InvalidEcho(Id), + InvalidEcho(InvalidEchoError), /// The originally received message and the one received in the echo pack were both valid, /// but different. /// /// This is the fault of the sender of that specific broadcast. - MismatchedBroadcasts { - guilty_party: Id, - we_received: SignedMessagePart, - echoed_to_us: SignedMessageHash, - }, + MismatchedBroadcasts(MismatchedBroadcastsError), } -impl EchoRoundError { - pub(crate) fn description(&self) -> String { - match self { - Self::InvalidEcho(_) => "Invalid message received among the ones echoed".into(), - Self::MismatchedBroadcasts { .. } => { - "The echoed message is different from the originally received one".into() - } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct InvalidEchoError { + invalid_echo_sender: Id, +} + +impl InvalidEchoError { + pub fn description(&self) -> String { + "Invalid message received among the ones echoed".into() + } + + pub fn verify_evidence>( + &self, + metadata: &MessageMetadata, + message: &EchoRoundMessage, + ) -> Result<(), EvidenceError> { + let invalid_echo = message.message_hashes.get(&self.invalid_echo_sender).ok_or_else(|| { + EvidenceError::InvalidEvidence(format!( + "Did not find {:?} in the attached message", + self.invalid_echo_sender + )) + })?; + + let verified_echo = match invalid_echo.clone().verify::(&self.invalid_echo_sender) { + Ok(echo) => echo, + Err(MessageVerificationError::Local(error)) => return Err(EvidenceError::Local(error)), + // The message was indeed incorrectly signed - fault proven + Err(MessageVerificationError::InvalidSignature) => return Ok(()), + Err(MessageVerificationError::SignatureMismatch) => return Ok(()), + }; + + // `from` sent us a correctly signed message but from another round or another session. + // Provable fault of `from`. + if verified_echo.metadata() != metadata { + return Ok(()); } + + Err(EvidenceError::InvalidEvidence( + "There is nothing wrong with the echoed message".into(), + )) } } #[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct EchoRoundMessage { +pub(crate) struct MismatchedBroadcastsError { + guilty_party: Id, + we_received: SignedMessagePart, + echoed_to_us: SignedMessageHash, +} + +impl MismatchedBroadcastsError { + pub fn description(&self) -> String { + "The echoed message is different from the originally received one".into() + } + + pub fn guilty_party(&self) -> &Id { + &self.guilty_party + } + + pub fn verify_evidence>(&self) -> Result<(), EvidenceError> { + let we_received = self + .we_received + .clone() + .verify::(&self.guilty_party) + .map_err(MessageVerificationError::into_evidence_error)?; + let echoed_to_us = self + .echoed_to_us + .clone() + .verify::(&self.guilty_party) + .map_err(MessageVerificationError::into_evidence_error)?; + + if we_received.metadata() == echoed_to_us.metadata() && !echoed_to_us.is_hash_of::(&self.we_received) { + Ok(()) + } else { + Err(EvidenceError::InvalidEvidence( + "The attached messages don't constitute malicious behavior".into(), + )) + } + } +} + +#[derive(Debug, Clone)] +#[derive_where::derive_where(Serialize, Deserialize)] +pub(crate) struct EchoRoundMessage { /// Signatures of echo broadcasts from respective nodes. - pub(super) message_hashes: SerializableMap, + pub(super) message_hashes: SerializableMap, } /// Each protocol round can contain one `EchoRound` with "echo messages" that are sent to all /// participants. The execution layer of the protocol guarantees that all participants have received /// the messages. #[derive_where::derive_where(Debug)] -pub struct EchoRound, SP: SessionParameters> { +pub(super) struct EchoRound, SP: SessionParameters> { verifier: SP::Verifier, echo_broadcasts: BTreeMap>, echo_round_info: EchoRoundInfo, @@ -109,14 +175,20 @@ where } } - // Since the echo round doesn't have its own `Protocol`, these methods live here. + // Since the echo round doesn't have its own static round type, these methods live here. - pub fn verify_direct_message_is_invalid(message: &DirectMessage) -> Result<(), MessageValidationError> { + pub fn verify_direct_message_is_invalid( + _format: &BoxedFormat, + message: &DirectMessage, + ) -> Result<(), EvidenceError> { // We don't send any direct messages in the echo round message.verify_is_some() } - pub fn verify_echo_broadcast_is_invalid(message: &EchoBroadcast) -> Result<(), MessageValidationError> { + pub fn verify_echo_broadcast_is_invalid( + _format: &BoxedFormat, + message: &EchoBroadcast, + ) -> Result<(), EvidenceError> { // We don't send any echo broadcasts in the echo round message.verify_is_some() } @@ -124,12 +196,12 @@ where pub fn verify_normal_broadcast_is_invalid( format: &BoxedFormat, message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { - message.verify_is_not::>(format) + ) -> Result<(), EvidenceError> { + message.verify_is_not::>(format) } } -impl Round for EchoRound +impl DynRound for EchoRound where P: Protocol, SP: SessionParameters, @@ -152,6 +224,26 @@ where self.communication_info.clone() } + fn make_direct_message( + &self, + _rng: &mut dyn CryptoRngCore, + format: &BoxedFormat, + _destination: &SP::Verifier, + ) -> Result<(DirectMessage, Artifact), LocalError> { + Ok(( + DirectMessage::new(format, NoMessage::new())?, + Artifact::new(NoArtifact::new()), + )) + } + + fn make_echo_broadcast( + &self, + _rng: &mut dyn CryptoRngCore, + _format: &BoxedFormat, + ) -> Result { + Ok(EchoBroadcast::none()) + } + fn make_normal_broadcast( &self, _rng: &mut dyn CryptoRngCore, @@ -174,7 +266,7 @@ where .collect::>() .into(); - let message = EchoRoundMessage:: { message_hashes }; + let message = EchoRoundMessage:: { message_hashes }; NormalBroadcast::new(format, message) } @@ -182,14 +274,16 @@ where &self, format: &BoxedFormat, from: &SP::Verifier, - message: ProtocolMessage, - ) -> Result> { + message: DynProtocolMessage, + ) -> Result> { debug!("{:?}: received an echo message from {:?}", self.verifier, from); message.echo_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let message = message.normal_broadcast.deserialize::>(format)?; + let message = message + .normal_broadcast + .deserialize::>(format)?; // Check that the received message contains entries from `expected_echos`. // It is an unprovable fault. @@ -203,18 +297,16 @@ where let missing_keys = expected_keys.difference(&message_keys).collect::>(); if !missing_keys.is_empty() { - return Err(ReceiveError::unprovable(format!( - "Missing echoed messages from: {:?}", - missing_keys - ))); + return Err(BoxedReceiveError::Unprovable(RemoteError::new(format!( + "Missing echoed messages from: {missing_keys:?}", + )))); } let extra_keys = message_keys.difference(&expected_keys).collect::>(); if !extra_keys.is_empty() { - return Err(ReceiveError::unprovable(format!( - "Unexpected echoed messages from: {:?}", - extra_keys - ))); + return Err(BoxedReceiveError::Unprovable(RemoteError::new(format!( + "Unexpected echoed messages from: {extra_keys:?}", + )))); } // Check that every entry is equal to what we received previously (in the main round). @@ -236,29 +328,42 @@ where // This means `from` sent us an incorrectly signed message. // Provable fault of `from`. Err(MessageVerificationError::InvalidSignature) => { - return Err(EchoRoundError::InvalidEcho(sender.clone()).into()) + return Err(BoxedReceiveError::Echo(Box::new(EchoRoundError::InvalidEcho( + InvalidEchoError { + invalid_echo_sender: sender.clone(), + }, + )))) } Err(MessageVerificationError::SignatureMismatch) => { - return Err(EchoRoundError::InvalidEcho(sender.clone()).into()) + return Err(BoxedReceiveError::Echo(Box::new(EchoRoundError::InvalidEcho( + InvalidEchoError { + invalid_echo_sender: sender.clone(), + }, + )))) } }; // `from` sent us a correctly signed message but from another round or another session. // Provable fault of `from`. if verified_echo.metadata() != previously_received_echo.metadata() { - return Err(EchoRoundError::InvalidEcho(sender.clone()).into()); + return Err(BoxedReceiveError::Echo(Box::new(EchoRoundError::InvalidEcho( + InvalidEchoError { + invalid_echo_sender: sender.clone(), + }, + )))); } // `sender` sent us and `from` messages with different payloads, // but with correct signatures and the same metadata. // Provable fault of `sender`. if !verified_echo.is_hash_of::(previously_received_echo) { - return Err(EchoRoundError::MismatchedBroadcasts { - guilty_party: sender.clone(), - we_received: previously_received_echo.clone(), - echoed_to_us: echo.clone(), - } - .into()); + return Err(BoxedReceiveError::Echo(Box::new(EchoRoundError::MismatchedBroadcasts( + MismatchedBroadcastsError { + guilty_party: sender.clone(), + we_received: previously_received_echo.clone(), + echoed_to_us: echo.clone(), + }, + )))); } } @@ -272,7 +377,7 @@ where _artifacts: BTreeMap, ) -> Result, LocalError> { self.main_round - .into_boxed() + .into_inner() .finalize(rng, self.payloads, self.artifacts) } } diff --git a/manul/src/session/evidence.rs b/manul/src/session/evidence.rs index 5bfd148..f697bad 100644 --- a/manul/src/session/evidence.rs +++ b/manul/src/session/evidence.rs @@ -1,80 +1,32 @@ use alloc::{ + boxed::Box, collections::{BTreeMap, BTreeSet}, format, string::{String, ToString}, }; -use core::fmt::Debug; +use core::{fmt::Debug, marker::PhantomData}; use serde::{Deserialize, Serialize}; use super::{ - echo::{EchoRound, EchoRoundError, EchoRoundMessage}, - message::{MessageVerificationError, SignedMessageHash, SignedMessagePart}, + echo::{EchoRound, EchoRoundError, EchoRoundMessage, InvalidEchoError, MismatchedBroadcastsError}, + message::{MessageVerificationError, SignedMessagePart}, session::{SessionId, SessionParameters}, transcript::Transcript, LocalError, }; use crate::{ protocol::{ - BoxedFormat, DirectMessage, DirectMessageError, EchoBroadcast, EchoBroadcastError, MessageValidationError, - NormalBroadcast, NormalBroadcastError, Protocol, ProtocolError, ProtocolMessage, ProtocolMessagePart, - ProtocolMessagePartHashable, ProtocolValidationError, RoundId, + BoxedFormat, BoxedProtocolError, DirectMessage, DirectMessageError, EchoBroadcast, EchoBroadcastError, + EvidenceError, EvidenceProtocolMessage, NormalBroadcast, NormalBroadcastError, PartyId, Protocol, + ProtocolMessagePart, ProtocolMessagePartHashable, RoundId, SerializedProtocolError, }, utils::SerializableMap, }; -/// Possible errors when verifying [`Evidence`] (evidence of malicious behavior). -#[derive(Debug, Clone)] -pub enum EvidenceError { - /// Indicates a runtime problem or a bug in the code. - Local(LocalError), - /// The evidence is improperly constructed - /// - /// This can indicate many things, such as: messages missing, invalid signatures, invalid messages, - /// the messages not actually proving the malicious behavior. - /// See the attached description for details. - InvalidEvidence(String), -} - -impl From for EvidenceError { - fn from(error: MessageVerificationError) -> Self { - match error { - MessageVerificationError::Local(error) => Self::Local(error), - MessageVerificationError::InvalidSignature => Self::InvalidEvidence("Invalid message signature".into()), - MessageVerificationError::SignatureMismatch => { - Self::InvalidEvidence("The signature does not match the payload".into()) - } - } - } -} - -impl From for EvidenceError { - fn from(error: NormalBroadcastError) -> Self { - Self::InvalidEvidence(format!("Failed to deserialize normal broadcast: {:?}", error)) - } -} - -impl From for EvidenceError { - fn from(error: MessageValidationError) -> Self { - match error { - MessageValidationError::Local(error) => Self::Local(error), - MessageValidationError::InvalidEvidence(error) => Self::InvalidEvidence(error), - } - } -} - -impl From for EvidenceError { - fn from(error: ProtocolValidationError) -> Self { - match error { - ProtocolValidationError::Local(error) => Self::Local(error), - ProtocolValidationError::InvalidEvidence(error) => Self::InvalidEvidence(error), - } - } -} - /// A self-contained evidence of malicious behavior by a node. #[derive_where::derive_where(Debug)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct Evidence, SP: SessionParameters> { guilty_party: SP::Verifier, description: String, @@ -86,16 +38,19 @@ where P: Protocol, SP: SessionParameters, { - pub(crate) fn new_protocol_error( + pub(crate) fn new_provable_error( + format: &BoxedFormat, verifier: &SP::Verifier, echo_broadcast: SignedMessagePart, normal_broadcast: SignedMessagePart, direct_message: SignedMessagePart, - error: P::ProtocolError, + error: BoxedProtocolError, transcript: &Transcript, ) -> Result { let required_messages = error.required_messages(); + let round_num = echo_broadcast.metadata().round_id().round_num(); + let echo_broadcast = if required_messages.this_round.echo_broadcast { Some(echo_broadcast) } else { @@ -115,23 +70,23 @@ where let mut echo_broadcasts = BTreeMap::new(); let mut normal_broadcasts = BTreeMap::new(); let mut direct_messages = BTreeMap::new(); - if let Some(previous_rounds) = required_messages.previous_rounds { + if let Some(previous_rounds) = &required_messages.previous_rounds { for (round_id, required) in previous_rounds { if required.echo_broadcast { - echo_broadcasts.insert(round_id.clone(), transcript.get_echo_broadcast(&round_id, verifier)?); + echo_broadcasts.insert(round_id.clone(), transcript.get_echo_broadcast(round_id, verifier)?); } if required.normal_broadcast { - normal_broadcasts.insert(round_id.clone(), transcript.get_normal_broadcast(&round_id, verifier)?); + normal_broadcasts.insert(round_id.clone(), transcript.get_normal_broadcast(round_id, verifier)?); } if required.direct_message { - direct_messages.insert(round_id.clone(), transcript.get_direct_message(&round_id, verifier)?); + direct_messages.insert(round_id.clone(), transcript.get_direct_message(round_id, verifier)?); } } } let mut echo_hashes = BTreeMap::new(); let mut other_echo_broadcasts = BTreeMap::new(); - if let Some(required_combined_echos) = required_messages.combined_echos { + if let Some(required_combined_echos) = &required_messages.combined_echos { for round_id in required_combined_echos { echo_hashes.insert( round_id.clone(), @@ -139,27 +94,31 @@ where ); other_echo_broadcasts.insert( round_id.clone(), - transcript.get_other_echo_broadcasts(&round_id, verifier)?.into(), + transcript.get_other_echo_broadcasts(round_id, verifier)?.into(), ); } } - let description = format!("Protocol error: {error}"); + let description = format!("Protocol error (Round {round_num}): {}", error.as_ref().description()); Ok(Self { guilty_party: verifier.clone(), description, - evidence: EvidenceEnum::Protocol(ProtocolEvidence { - error, - direct_message, - echo_broadcast, - normal_broadcast, - direct_messages: direct_messages.into(), - echo_broadcasts: echo_broadcasts.into(), - normal_broadcasts: normal_broadcasts.into(), - other_echo_broadcasts: other_echo_broadcasts.into(), - echo_hashes: echo_hashes.into(), - }), + evidence: EvidenceEnum::Protocol( + ProtocolEvidence { + error: error.into_boxed().serialize(format)?, + direct_message, + echo_broadcast, + normal_broadcast, + direct_messages: direct_messages.into(), + echo_broadcasts: echo_broadcasts.into(), + normal_broadcasts: normal_broadcasts.into(), + other_echo_broadcasts: other_echo_broadcasts.into(), + echo_hashes: echo_hashes.into(), + phantom: PhantomData, + } + .into(), + ), }) } @@ -168,27 +127,22 @@ where normal_broadcast: SignedMessagePart, error: EchoRoundError, ) -> Result { - let description = format!("Echo round error: {}", error.description()); match error { - EchoRoundError::InvalidEcho(from) => Ok(Self { + EchoRoundError::InvalidEcho(error) => Ok(Self { guilty_party: verifier.clone(), - description, - evidence: EvidenceEnum::InvalidEchoPack(InvalidEchoPackEvidence { + description: error.description(), + evidence: EvidenceEnum::InvalidEcho(InvalidEchoEvidence { normal_broadcast, - invalid_echo_sender: from, + error, }), }), - EchoRoundError::MismatchedBroadcasts { - guilty_party, - we_received, - echoed_to_us, - } => Ok(Self { - guilty_party, - description, - evidence: EvidenceEnum::MismatchedBroadcasts(MismatchedBroadcastsEvidence { - we_received, - echoed_to_us, - }), + EchoRoundError::MismatchedBroadcasts(error) => Ok(Self { + // Note that this is an unusual case: the guilty party is not + // the sender of the message that triggered the error, + // but a different party (specified in the error itself). + guilty_party: error.guilty_party().clone(), + description: error.description(), + evidence: EvidenceEnum::MismatchedBroadcasts(MismatchedBroadcastsEvidence { error }), }), } } @@ -243,98 +197,65 @@ where /// to prove the malicious behavior of [`Self::guilty_party`]. /// /// Returns `Ok(())` if it is the case. - pub fn verify( - &self, - associated_data: &>::AssociatedData, - ) -> Result<(), EvidenceError> { + pub fn verify(&self, shared_data: &P::SharedData) -> Result<(), EvidenceError> { let format = BoxedFormat::new::(); match &self.evidence { - EvidenceEnum::Protocol(evidence) => evidence.verify::(&self.guilty_party, &format, associated_data), + EvidenceEnum::Protocol(evidence) => evidence.verify::(&self.guilty_party, &format, shared_data), EvidenceEnum::InvalidDirectMessage(evidence) => evidence.verify::(&self.guilty_party, &format), EvidenceEnum::InvalidEchoBroadcast(evidence) => evidence.verify::(&self.guilty_party, &format), EvidenceEnum::InvalidNormalBroadcast(evidence) => evidence.verify::(&self.guilty_party, &format), - EvidenceEnum::InvalidEchoPack(evidence) => evidence.verify(&self.guilty_party, &format), - EvidenceEnum::MismatchedBroadcasts(evidence) => evidence.verify::(&self.guilty_party), + EvidenceEnum::InvalidEcho(evidence) => evidence.verify::(&self.guilty_party, &format), + EvidenceEnum::MismatchedBroadcasts(evidence) => evidence.verify::(), } } } #[derive_where::derive_where(Debug)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] enum EvidenceEnum, SP: SessionParameters> { - Protocol(ProtocolEvidence), + Protocol(Box>), InvalidDirectMessage(InvalidDirectMessageEvidence), InvalidEchoBroadcast(InvalidEchoBroadcastEvidence), InvalidNormalBroadcast(InvalidNormalBroadcastEvidence), - InvalidEchoPack(InvalidEchoPackEvidence), - MismatchedBroadcasts(MismatchedBroadcastsEvidence), + InvalidEcho(InvalidEchoEvidence), + MismatchedBroadcasts(MismatchedBroadcastsEvidence), } -#[derive_where::derive_where(Debug)] -#[derive(Clone, Serialize, Deserialize)] -pub struct InvalidEchoPackEvidence { +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InvalidEchoEvidence { normal_broadcast: SignedMessagePart, - invalid_echo_sender: SP::Verifier, + error: InvalidEchoError, } -impl InvalidEchoPackEvidence -where - SP: SessionParameters, -{ - fn verify(&self, verifier: &SP::Verifier, format: &BoxedFormat) -> Result<(), EvidenceError> { - let verified = self.normal_broadcast.clone().verify::(verifier)?; - let deserialized = verified.payload().deserialize::>(format)?; - let invalid_echo = deserialized - .message_hashes - .get(&self.invalid_echo_sender) - .ok_or_else(|| { - EvidenceError::InvalidEvidence(format!( - "Did not find {:?} in the attached message", - self.invalid_echo_sender - )) +impl InvalidEchoEvidence { + fn verify>( + &self, + verifier: &SP::Verifier, + format: &BoxedFormat, + ) -> Result<(), EvidenceError> { + let verified = self + .normal_broadcast + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)?; + let deserialized = verified + .payload() + .deserialize::>(format) + .map_err(|error| { + EvidenceError::InvalidEvidence(format!("Failed to deserialize normal broadcast: {error:?}")) })?; - - let verified_echo = match invalid_echo.clone().verify::(&self.invalid_echo_sender) { - Ok(echo) => echo, - Err(MessageVerificationError::Local(error)) => return Err(EvidenceError::Local(error)), - // The message was indeed incorrectly signed - fault proven - Err(MessageVerificationError::InvalidSignature) => return Ok(()), - Err(MessageVerificationError::SignatureMismatch) => return Ok(()), - }; - - // `from` sent us a correctly signed message but from another round or another session. - // Provable fault of `from`. - if verified_echo.metadata() != self.normal_broadcast.metadata() { - return Ok(()); - } - - Err(EvidenceError::InvalidEvidence( - "There is nothing wrong with the echoed message".into(), - )) + self.error.verify_evidence::(verified.metadata(), &deserialized) } } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct MismatchedBroadcastsEvidence { - we_received: SignedMessagePart, - echoed_to_us: SignedMessageHash, +pub struct MismatchedBroadcastsEvidence { + error: MismatchedBroadcastsError, } -impl MismatchedBroadcastsEvidence { - fn verify(&self, verifier: &SP::Verifier) -> Result<(), EvidenceError> - where - SP: SessionParameters, - { - let we_received = self.we_received.clone().verify::(verifier)?; - let echoed_to_us = self.echoed_to_us.clone().verify::(verifier)?; - - if we_received.metadata() == echoed_to_us.metadata() && !echoed_to_us.is_hash_of::(&self.we_received) { - Ok(()) - } else { - Err(EvidenceError::InvalidEvidence( - "The attached messages don't constitute malicious behavior".into(), - )) - } +impl MismatchedBroadcastsEvidence { + fn verify>(&self) -> Result<(), EvidenceError> { + self.error.verify_evidence::() } } @@ -347,17 +268,20 @@ impl InvalidDirectMessageEvidence { P: Protocol, SP: SessionParameters, { - let verified_direct_message = self.0.clone().verify::(verifier)?; + let verified_direct_message = self + .0 + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)?; let payload = verified_direct_message.payload(); if self.0.metadata().round_id().is_echo() { - Ok(EchoRound::::verify_direct_message_is_invalid(payload)?) + Ok(EchoRound::::verify_direct_message_is_invalid(format, payload)?) } else { - Ok(P::verify_direct_message_is_invalid( - format, - self.0.metadata().round_id(), - payload, - )?) + let round_id = self.0.metadata().round_id(); + let round_info = P::round_info(round_id) + .ok_or_else(|| EvidenceError::InvalidEvidence(format!("{round_id} is not in the protocol")))?; + round_info.as_ref().verify_direct_message_is_invalid(format, payload) } } } @@ -371,17 +295,20 @@ impl InvalidEchoBroadcastEvidence { P: Protocol, SP: SessionParameters, { - let verified_echo_broadcast = self.0.clone().verify::(verifier)?; + let verified_echo_broadcast = self + .0 + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)?; let payload = verified_echo_broadcast.payload(); if self.0.metadata().round_id().is_echo() { - Ok(EchoRound::::verify_echo_broadcast_is_invalid(payload)?) + Ok(EchoRound::::verify_echo_broadcast_is_invalid(format, payload)?) } else { - Ok(P::verify_echo_broadcast_is_invalid( - format, - self.0.metadata().round_id(), - payload, - )?) + let round_id = self.0.metadata().round_id(); + let round_info = P::round_info(round_id) + .ok_or_else(|| EvidenceError::InvalidEvidence(format!("{round_id} is not in the protocol")))?; + round_info.as_ref().verify_echo_broadcast_is_invalid(format, payload) } } } @@ -395,25 +322,28 @@ impl InvalidNormalBroadcastEvidence { P: Protocol, SP: SessionParameters, { - let verified_normal_broadcast = self.0.clone().verify::(verifier)?; + let verified_normal_broadcast = self + .0 + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)?; let payload = verified_normal_broadcast.payload(); if self.0.metadata().round_id().is_echo() { Ok(EchoRound::::verify_normal_broadcast_is_invalid(format, payload)?) } else { - Ok(P::verify_normal_broadcast_is_invalid( - format, - self.0.metadata().round_id(), - payload, - )?) + let round_id = self.0.metadata().round_id(); + let round_info = P::round_info(round_id) + .ok_or_else(|| EvidenceError::InvalidEvidence(format!("{round_id} is not in the protocol")))?; + round_info.as_ref().verify_normal_broadcast_is_invalid(format, payload) } } } #[derive_where::derive_where(Debug)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] struct ProtocolEvidence> { - error: P::ProtocolError, + error: SerializedProtocolError, direct_message: Option>, echo_broadcast: Option>, normal_broadcast: Option>, @@ -422,6 +352,7 @@ struct ProtocolEvidence> { normal_broadcasts: SerializableMap>, other_echo_broadcasts: SerializableMap>>, echo_hashes: SerializableMap>, + phantom: PhantomData P>, } fn verify_message_parts( @@ -435,7 +366,10 @@ where { let mut verified_parts = BTreeMap::new(); for (round_id, message_part) in message_parts.iter() { - let verified = message_part.clone().verify::(verifier)?; + let verified = message_part + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)?; let metadata = verified.metadata(); if metadata.session_id() != expected_session_id || metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( @@ -452,7 +386,7 @@ fn verify_message_part( expected_session_id: &SessionId, expected_round_id: &RoundId, message_part: &Option>, -) -> Result +) -> Result, EvidenceError> where SP: SessionParameters, T: Clone + ProtocolMessagePartHashable, @@ -464,9 +398,15 @@ where "Invalid attached message metadata".into(), )); } - message_part.clone().verify::(verifier)?.into_payload() + Some( + message_part + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)? + .into_payload(), + ) } else { - T::none() + None }; Ok(verified_part) @@ -474,14 +414,14 @@ where impl ProtocolEvidence where - Id: Debug + Clone + Ord, + Id: PartyId, P: Protocol, { fn verify( &self, verifier: &SP::Verifier, format: &BoxedFormat, - associated_data: &>::AssociatedData, + shared_data: &P::SharedData, ) -> Result<(), EvidenceError> where SP: SessionParameters, @@ -525,10 +465,16 @@ where )); } - let verified_echo_hashes = echo_hashes.clone().verify::(verifier)?; + let verified_echo_hashes = echo_hashes + .clone() + .verify::(verifier) + .map_err(MessageVerificationError::into_evidence_error)?; let echo_round_payload = verified_echo_hashes .payload() - .deserialize::>(format)?; + .deserialize::>(format) + .map_err(|error| { + EvidenceError::InvalidEvidence(format!("Failed to deserialize normal broadcast: {error:?}")) + })?; let signed_echo_broadcasts = self .other_echo_broadcasts @@ -542,7 +488,10 @@ where return Err(EvidenceError::InvalidEvidence("Invalid echo hash metadata".into())); } - let verified_echo_hash = echo_hash.clone().verify::(other_verifier)?; + let verified_echo_hash = echo_hash + .clone() + .verify::(other_verifier) + .map_err(MessageVerificationError::into_evidence_error)?; let echo_broadcast = signed_echo_broadcasts.get(other_verifier).ok_or_else(|| { EvidenceError::InvalidEvidence(format!("Missing {round_id} echo broadcast from {other_verifier:?}")) @@ -559,7 +508,10 @@ where )); } - let verified_echo_broadcast = echo_broadcast.clone().verify::(other_verifier)?; + let verified_echo_broadcast = echo_broadcast + .clone() + .verify::(other_verifier) + .map_err(MessageVerificationError::into_evidence_error)?; echo_messages.insert(other_verifier.clone(), verified_echo_broadcast.into_payload()); } @@ -568,7 +520,7 @@ where // Merge message parts - let protocol_message = ProtocolMessage { + let message_parts = EvidenceProtocolMessage { echo_broadcast, normal_broadcast, direct_message, @@ -583,10 +535,10 @@ where let mut previous_messages = BTreeMap::new(); for round_id in all_rounds { - let echo_broadcast = echo_broadcasts.remove(&round_id).unwrap_or(EchoBroadcast::none()); - let normal_broadcast = normal_broadcasts.remove(&round_id).unwrap_or(NormalBroadcast::none()); - let direct_message = direct_messages.remove(&round_id).unwrap_or(DirectMessage::none()); - let protocol_message = ProtocolMessage { + let echo_broadcast = echo_broadcasts.remove(&round_id); + let normal_broadcast = normal_broadcasts.remove(&round_id); + let direct_message = direct_messages.remove(&round_id); + let protocol_message = EvidenceProtocolMessage { echo_broadcast, normal_broadcast, direct_message, @@ -594,14 +546,20 @@ where previous_messages.insert(round_id, protocol_message); } - Ok(self.error.verify_messages_constitute_error( + let round_info = P::round_info(round_id).ok_or_else(|| { + EvidenceError::InvalidEvidence(format!("The round {round_id} is not present in the protocol")) + })?; + + round_info.as_ref().verify_evidence( + round_id, format, + &self.error, verifier, session_id.as_ref(), - associated_data, - protocol_message, + shared_data, + message_parts, previous_messages, combined_echos, - )?) + ) } } diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index 4014eae..2c38edc 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -8,10 +8,12 @@ use signature::{DigestVerifier, RandomizedDigestSigner}; use super::{ session::{SessionId, SessionParameters}, - wire_format::WireFormat, + wire_format::{deserialize, WireFormat}, LocalError, }; -use crate::protocol::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePartHashable, RoundId}; +use crate::protocol::{ + DirectMessage, EchoBroadcast, EvidenceError, NormalBroadcast, ProtocolMessagePartHashable, RoundId, +}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] struct SerializedSignature(#[serde(with = "SliceLike::")] Box<[u8]>); @@ -28,7 +30,7 @@ impl SerializedSignature { where SP: SessionParameters, { - SP::WireFormat::deserialize::(&self.0).map_err(|_| MessageVerificationError::InvalidSignature) + deserialize::(&self.0).map_err(|_| MessageVerificationError::InvalidSignature) } } @@ -41,6 +43,22 @@ pub(crate) enum MessageVerificationError { SignatureMismatch, } +impl MessageVerificationError { + // This is not a `From` implementation since in other contexts (e.g. echo round) we need a different behavior, + // and `From` impl could be accidentally used there leading to errors. + pub fn into_evidence_error(self) -> EvidenceError { + match self { + MessageVerificationError::Local(error) => EvidenceError::Local(error), + MessageVerificationError::InvalidSignature => { + EvidenceError::InvalidEvidence("Invalid message signature".into()) + } + MessageVerificationError::SignatureMismatch => { + EvidenceError::InvalidEvidence("The signature does not match the payload".into()) + } + } + } +} + impl From for MessageVerificationError { fn from(source: LocalError) -> Self { Self::Local(source) @@ -130,7 +148,7 @@ where let digest = message_with_metadata.digest::()?; let signature = signer .try_sign_digest_with_rng(rng, digest) - .map_err(|err| LocalError::new(format!("Failed to sign: {:?}", err)))?; + .map_err(|err| LocalError::new(format!("Failed to sign: {err:?}")))?; Ok(Self { signature: SerializedSignature::new::(signature)?, message_with_metadata, diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 7c77615..4217887 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -23,9 +23,9 @@ use super::{ LocalError, RemoteError, }; use crate::protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EchoRoundParticipation, - EntryPoint, FinalizeOutcome, NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessage, ProtocolMessagePart, - ReceiveError, ReceiveErrorType, RoundId, TransitionInfo, + Artifact, BoxedFormat, BoxedReceiveError, BoxedRound, CommunicationInfo, DirectMessage, DynProtocolMessage, + EchoBroadcast, EchoRoundParticipation, EntryPoint, FinalizeOutcome, NormalBroadcast, PartyId, Payload, Protocol, + ProtocolMessagePart, RoundId, TransitionInfo, }; /// A set of types needed to execute a session. @@ -129,7 +129,7 @@ pub enum RoundOutcome, SP: SessionParameters> { /// Transitioned to another round. AnotherRound { /// The session object for the new round. - session: Session, + session: Box>, /// The messages intended for the new round cached during the previous round. cached_messages: Vec>, }, @@ -166,15 +166,24 @@ where let verifier = signer.verifying_key(); let transition_info = round.as_ref().transition_info(); + let communication_info = round.as_ref().communication_info(); - let echo = round.as_ref().make_echo_broadcast(rng, &format)?; + // TODO (#4): we reuse `EchoBroadcast::none()` (that means `NoMessage` in the typed round) + // to have a second meaning, the node not sending messages at all. + let echo = if communication_info.message_destinations.is_empty() { + EchoBroadcast::none() + } else { + round.as_ref().make_echo_broadcast(rng, &format)? + }; let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, &transition_info.id(), echo)?; - let normal = round.as_ref().make_normal_broadcast(rng, &format)?; + let normal = if communication_info.message_destinations.is_empty() { + NormalBroadcast::none() + } else { + round.as_ref().make_normal_broadcast(rng, &format)? + }; let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, &transition_info.id(), normal)?; - let communication_info = round.as_ref().communication_info(); - let round_sends_echo_broadcast = !echo_broadcast.payload().is_none(); let echo_round_info = match &communication_info.echo_round_participation { EchoRoundParticipation::Default => { @@ -387,8 +396,8 @@ where /// Processes a verified message. /// /// This can be called in a spawned task if it is known to take a long time. - pub fn process_message(&self, message: VerifiedMessage) -> ProcessedMessage { - let protocol_message = ProtocolMessage { + pub fn process_message(&self, message: VerifiedMessage) -> ProcessedMessage { + let protocol_message = DynProtocolMessage { echo_broadcast: message.echo_broadcast().clone(), normal_broadcast: message.normal_broadcast().clone(), direct_message: message.direct_message().clone(), @@ -406,9 +415,9 @@ where pub fn add_processed_message( &self, accum: &mut RoundAccumulator, - processed: ProcessedMessage, + processed: ProcessedMessage, ) -> Result<(), LocalError> { - accum.add_processed_message(&self.transcript, processed) + accum.add_processed_message(&self.format, &self.transcript, processed) } /// Makes an accumulator for a new round. @@ -484,16 +493,16 @@ where accum.payloads, accum.artifacts, )); - let cached_messages = filter_messages(accum.cached, &round.id()); + let cached_messages = filter_messages(accum.cached, &round.as_ref().transition_info().id); let session = Session::new_for_next_round(rng, self.session_id, self.signer, self.format, round, transcript)?; return Ok(RoundOutcome::AnotherRound { - session, + session: session.into(), cached_messages, }); } - match self.round.into_boxed().finalize(rng, accum.payloads, accum.artifacts)? { + match self.round.into_inner().finalize(rng, accum.payloads, accum.artifacts)? { FinalizeOutcome::Result(result) => Ok(RoundOutcome::Finished(SessionReport::new( SessionOutcome::Result(result), transcript, @@ -502,7 +511,7 @@ where let round_id = round.as_ref().transition_info().id(); // Protecting against common bugs if !self.transition_info.children.contains(&round_id) { - return Err(LocalError::new(format!("Unexpected next round id: {:?}", round_id))); + return Err(LocalError::new(format!("Unexpected next round id: {round_id:?}"))); } // These messages could have been cached before @@ -518,7 +527,7 @@ where Session::new_for_next_round(rng, self.session_id, self.signer, self.format, round, transcript)?; Ok(RoundOutcome::AnotherRound { cached_messages, - session, + session: session.into(), }) } } @@ -612,8 +621,7 @@ where fn register_unprovable_error(&mut self, from: &SP::Verifier, error: RemoteError) -> Result<(), LocalError> { if self.unprovable_errors.insert(from.clone(), error).is_some() { Err(LocalError::new(format!( - "An unprovable error for {:?} is already registered", - from + "An unprovable error for {from:?} is already registered", ))) } else { Ok(()) @@ -623,8 +631,7 @@ where fn register_provable_error(&mut self, from: &SP::Verifier, evidence: Evidence) -> Result<(), LocalError> { if self.provable_errors.insert(from.clone(), evidence).is_some() { Err(LocalError::new(format!( - "A provable error for {:?} is already registered", - from + "A provable error for {from:?} is already registered", ))) } else { Ok(()) @@ -645,12 +652,11 @@ where // Add a processed artifact to the accumulator. // Returns an error if the artifact was already present. fn add_artifact(&mut self, processed: ProcessedArtifact) -> Result<(), LocalError> { - let artifact = match processed.artifact { - Some(artifact) => artifact, - None => return Ok(()), - }; - - if self.artifacts.insert(processed.destination.clone(), artifact).is_some() { + if self + .artifacts + .insert(processed.destination.clone(), processed.artifact) + .is_some() + { return Err(LocalError::new(format!( "Artifact for destination {:?} has already been recorded", processed.destination @@ -661,8 +667,9 @@ where fn add_processed_message( &mut self, + format: &BoxedFormat, transcript: &Transcript, - processed: ProcessedMessage, + processed: ProcessedMessage, ) -> Result<(), LocalError> { if self.payloads.contains_key(processed.message.from()) { return Err(LocalError::new(format!( @@ -675,8 +682,7 @@ where if !self.still_have_not_sent_messages.remove(&from) { return Err(LocalError::new(format!( - "Expected {:?} to not be in the list of expected messages", - from + "Expected {from:?} to not be in the list of expected messages", ))); } let error = match processed.processed { @@ -698,44 +704,45 @@ where Err(error) => error, }; - match error.0 { - ReceiveErrorType::InvalidDirectMessage(error) => { + match error { + BoxedReceiveError::InvalidDirectMessage(error) => { let (_echo_broadcast, _normal_broadcast, direct_message) = processed.message.into_parts(); let evidence = Evidence::new_invalid_direct_message(&from, direct_message, error); self.register_provable_error(&from, evidence) } - ReceiveErrorType::InvalidEchoBroadcast(error) => { + BoxedReceiveError::InvalidEchoBroadcast(error) => { let (echo_broadcast, _normal_broadcast, _direct_message) = processed.message.into_parts(); let evidence = Evidence::new_invalid_echo_broadcast(&from, echo_broadcast, error); self.register_provable_error(&from, evidence) } - ReceiveErrorType::InvalidNormalBroadcast(error) => { + BoxedReceiveError::InvalidNormalBroadcast(error) => { let (_echo_broadcast, normal_broadcast, _direct_message) = processed.message.into_parts(); let evidence = Evidence::new_invalid_normal_broadcast(&from, normal_broadcast, error); self.register_provable_error(&from, evidence) } - ReceiveErrorType::Protocol(error) => { + BoxedReceiveError::Protocol(boxed_error) => { let (echo_broadcast, normal_broadcast, direct_message) = processed.message.into_parts(); - let evidence = Evidence::new_protocol_error( + let evidence = Evidence::new_provable_error( + format, &from, echo_broadcast, normal_broadcast, direct_message, - error, + boxed_error, transcript, )?; self.register_provable_error(&from, evidence) } - ReceiveErrorType::Unprovable(error) => { + BoxedReceiveError::Unprovable(error) => { self.unprovable_errors.insert(from.clone(), error); Ok(()) } - ReceiveErrorType::Echo(error) => { + BoxedReceiveError::Echo(error) => { let (_echo_broadcast, normal_broadcast, _direct_message) = processed.message.into_parts(); let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, *error)?; self.register_provable_error(&from, evidence) } - ReceiveErrorType::Local(error) => Err(error), + BoxedReceiveError::Local(error) => Err(error), } } @@ -745,8 +752,7 @@ where let cached = self.cached.entry(from.clone()).or_default(); if cached.insert(round_id.clone(), message).is_some() { return Err(LocalError::new(format!( - "A message from for {:?} has already been cached", - round_id + "A message from for {round_id:?} has already been cached", ))); } Ok(()) @@ -756,13 +762,13 @@ where #[derive(Debug)] pub struct ProcessedArtifact { destination: SP::Verifier, - artifact: Option, + artifact: Artifact, } #[derive(Debug)] -pub struct ProcessedMessage, SP: SessionParameters> { +pub struct ProcessedMessage { message: VerifiedMessage, - processed: Result>, + processed: Result>, } /// The result of preprocessing an incoming message. @@ -813,12 +819,10 @@ fn filter_messages( mod tests { use impls::impls; - use super::{ - BoxedFormat, Message, ProcessedArtifact, ProcessedMessage, RoundId, Session, SessionParameters, VerifiedMessage, - }; + use super::{Message, ProcessedArtifact, ProcessedMessage, Session, VerifiedMessage}; use crate::{ dev::{BinaryFormat, TestSessionParams, TestVerifier}, - protocol::{DirectMessage, EchoBroadcast, MessageValidationError, NoProtocolErrors, NormalBroadcast, Protocol}, + protocol::{Protocol, RoundId, RoundInfo}, }; #[test] @@ -835,31 +839,10 @@ mod tests { struct DummyProtocol; - impl Protocol<::Verifier> for DummyProtocol { + impl Protocol for DummyProtocol { type Result = (); - type ProtocolError = NoProtocolErrors; - - fn verify_direct_message_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_echo_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_normal_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { + type SharedData = (); + fn round_info(_round_id: &RoundId) -> Option> { unimplemented!() } } @@ -875,6 +858,6 @@ mod tests { assert!(impls!(Message: Send)); assert!(impls!(ProcessedArtifact: Send)); assert!(impls!(VerifiedMessage: Send)); - assert!(impls!(ProcessedMessage: Send)); + assert!(impls!(ProcessedMessage: Send)); } } diff --git a/manul/src/session/tokio.rs b/manul/src/session/tokio.rs index 8ba2f37..34bcd31 100644 --- a/manul/src/session/tokio.rs +++ b/manul/src/session/tokio.rs @@ -173,7 +173,7 @@ where session: new_session, cached_messages: new_cached_messages, } => { - session = new_session; + session = *new_session; cached_messages = new_cached_messages; } } @@ -198,7 +198,6 @@ where P: Protocol, SP: SessionParameters, ::Signer: Send + Sync, -

>::ProtocolError: Send + Sync, { let mut session = Arc::new(session); // Some rounds can finalize early and put off sending messages to the next round. Such messages @@ -221,7 +220,7 @@ where loop { debug!("{my_id}: *** starting round {:?} ***", session.round_id()); - let (processed_tx, mut processed_rx) = mpsc::channel::>(100); + let (processed_tx, mut processed_rx) = mpsc::channel::>(100); let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<(MessageOut, ProcessedArtifact)>(100); // This is kept in the main task since it's mutable, @@ -384,7 +383,7 @@ where session: new_session, cached_messages: new_cached_messages, } => { - session = Arc::new(new_session); + session = Arc::new(*new_session); cached_messages = new_cached_messages; } } diff --git a/manul/src/session/transcript.rs b/manul/src/session/transcript.rs index 265d44b..c8ac223 100644 --- a/manul/src/session/transcript.rs +++ b/manul/src/session/transcript.rs @@ -268,7 +268,7 @@ where let errors = self .unprovable_errors .iter() - .map(|(id, error)| format!(" {:?}: {}", id, error)) + .map(|(id, error)| format!(" {id:?}: {error}")) .collect::>(); format!("\nUnprovable errors:\n{}", errors.join("\n")) } else { @@ -279,7 +279,7 @@ where let faulty_parties = self .missing_messages .iter() - .map(|(round_id, parties)| format!(" {}: {:?}", round_id, parties)) + .map(|(round_id, parties)| format!(" {round_id}: {parties:?}")) .collect::>(); format!("\nMissing messages:\n{}", faulty_parties.join("\n")) } else { diff --git a/manul/src/session/wire_format.rs b/manul/src/session/wire_format.rs index 9675b2b..a528040 100644 --- a/manul/src/session/wire_format.rs +++ b/manul/src/session/wire_format.rs @@ -1,9 +1,27 @@ -use alloc::{boxed::Box, format}; +use alloc::{boxed::Box, format, string::String}; use core::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::protocol::{DeserializationError, LocalError}; +use crate::protocol::LocalError; + +/// An error that can be returned during deserialization error. +#[derive(displaydoc::Display, Debug, Clone)] +#[displaydoc("Error deserializing into {target_type}: {message}")] +pub struct DeserializationError { + target_type: String, + message: String, +} + +impl DeserializationError { + /// Creates a new deserialization error. + pub fn new(message: impl Into) -> Self { + Self { + target_type: core::any::type_name::().into(), + message: message.into(), + } + } +} /* Why the asymmetry between serialization and deserialization? @@ -15,7 +33,7 @@ and it's tricky to write a similar persistent wrapper as we do for the deseriali (see https://github.com/fjarri/serde-persistent-deserializer/issues/2). So for serialization we have to instead type-erase the value itself and pass it somewhere -where the serializer type is known (`ObjectSafeSerializer::serialize()` impl); +where the serializer type is known (`DynSerializer::serialize()` impl); but for the deserialization we instead type-erase the deserializer and pass it somewhere the type of the target value is known (`Deserializer::deserialize()`). @@ -36,10 +54,12 @@ pub trait WireFormat: 'static + Debug { fn deserializer(bytes: &[u8]) -> Self::Deserializer<'_>; // A helper method for use on the session level when both `WireFormat` and `T` are known at the same point. +} - /// Deserializes the given bytestring into `T`. - fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result { - let deserializer = Self::deserializer(bytes); - T::deserialize(deserializer).map_err(|err| DeserializationError::new(format!("Deserialization error: {err:?}"))) - } +/// Deserializes the given bytestring into `T`. +pub(crate) fn deserialize<'de, F: WireFormat, T: Deserialize<'de>>( + bytes: &'de [u8], +) -> Result { + let deserializer = F::deserializer(bytes); + T::deserialize(deserializer).map_err(|err| DeserializationError::new::(format!("{err:?}"))) } diff --git a/manul/src/tests/partial_echo.rs b/manul/src/tests/partial_echo.rs index 82e1bb5..d4c2f3d 100644 --- a/manul/src/tests/partial_echo.rs +++ b/manul/src/tests/partial_echo.rs @@ -1,10 +1,9 @@ use alloc::{ - boxed::Box, collections::{BTreeMap, BTreeSet}, vec, vec::Vec, }; -use core::{fmt::Debug, marker::PhantomData}; +use core::fmt::Debug; use rand_core::{CryptoRngCore, OsRng}; use serde::{Deserialize, Serialize}; @@ -12,42 +11,25 @@ use serde::{Deserialize, Serialize}; use crate::{ dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, protocol::{ - Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EchoRoundParticipation, - EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors, NormalBroadcast, PartyId, - Payload, Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round, RoundId, TransitionInfo, + BoxedRound, CommunicationInfo, EchoRoundParticipation, EntryPoint, FinalizeOutcome, LocalError, NoArtifact, + NoMessage, NoProtocolErrors, PartyId, Protocol, ProtocolMessage, ReceiveError, Round, RoundId, RoundInfo, + TransitionInfo, }, signature::Keypair, }; #[derive(Debug)] -struct PartialEchoProtocol(PhantomData); +struct PartialEchoProtocol; -impl Protocol for PartialEchoProtocol { +impl Protocol for PartialEchoProtocol { type Result = (); - type ProtocolError = NoProtocolErrors; - - fn verify_direct_message_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &DirectMessage, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } - - fn verify_echo_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &EchoBroadcast, - ) -> Result<(), MessageValidationError> { - unimplemented!() - } + type SharedData = (); - fn verify_normal_broadcast_is_invalid( - _format: &BoxedFormat, - _round_id: &RoundId, - _message: &NormalBroadcast, - ) -> Result<(), MessageValidationError> { - unimplemented!() + fn round_info(round_id: &RoundId) -> Option> { + match round_id { + round_id if round_id == &RoundId::new(1) => Some(RoundInfo::new::>()), + _ => None, + } } } @@ -70,7 +52,7 @@ struct Round1Echo { } impl Deserialize<'de>> EntryPoint for Inputs { - type Protocol = PartialEchoProtocol; + type Protocol = PartialEchoProtocol; fn entry_round_id() -> RoundId { 1.into() @@ -78,16 +60,24 @@ impl Deserialize<'de>> EntryPoint for Inp fn make_round( self, - _rng: &mut dyn CryptoRngCore, + _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], _id: &Id, ) -> Result, LocalError> { - Ok(BoxedRound::new_dynamic(Round1 { inputs: self })) + Ok(BoxedRound::new(Round1 { inputs: self })) } } impl Deserialize<'de>> Round for Round1 { - type Protocol = PartialEchoProtocol; + type Protocol = PartialEchoProtocol; + type ProtocolError = NoProtocolErrors; + + type DirectMessage = NoMessage; + type NormalBroadcast = NoMessage; + type EchoBroadcast = Round1Echo; + + type Payload = (); + type Artifact = NoArtifact; fn transition_info(&self) -> TransitionInfo { TransitionInfo::new_linear_terminating(1) @@ -101,48 +91,40 @@ impl Deserialize<'de>> Round for Round1 Result { + fn make_echo_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Result { if self.inputs.message_destinations.is_empty() { - Ok(EchoBroadcast::none()) + // TODO (#4): this branch is unreachable in the absense of bugs in the code + // (the method will not be called in the first place if the node does not send messages). + // Can it be eliminated using the type system? + Err(LocalError::new("This node does not send messages in this round")) } else { - EchoBroadcast::new( - format, - Round1Echo { - sender: self.inputs.id.clone(), - }, - ) + Ok(Round1Echo { + sender: self.inputs.id.clone(), + }) } } fn receive_message( &self, - format: &BoxedFormat, from: &Id, - message: ProtocolMessage, - ) -> Result> { - message.normal_broadcast.assert_is_none()?; - message.direct_message.assert_is_none()?; - + message: ProtocolMessage, + ) -> Result> { if self.inputs.expecting_messages_from.is_empty() { - message.echo_broadcast.assert_is_none()?; + panic!("Message received when none was expected, this would be a provable offense"); } else { - let echo = message.echo_broadcast.deserialize::>(format)?; + let echo = message.echo_broadcast; assert_eq!(&echo.sender, from); assert!(self.inputs.expecting_messages_from.contains(from)); } - Ok(Payload::new(())) + Ok(()) } fn finalize( - self: Box, - _rng: &mut dyn CryptoRngCore, - _payloads: BTreeMap, - _artifacts: BTreeMap, + self, + _rng: &mut impl CryptoRngCore, + _payloads: BTreeMap, + _artifacts: BTreeMap, ) -> Result, LocalError> { Ok(FinalizeOutcome::Result(())) } diff --git a/manul/src/utils.rs b/manul/src/utils.rs index 856394a..d1bdf71 100644 --- a/manul/src/utils.rs +++ b/manul/src/utils.rs @@ -1,5 +1,7 @@ //! Assorted utilities. mod serializable_map; +mod type_id; pub use serializable_map::SerializableMap; +pub(crate) use type_id::DynTypeId; diff --git a/manul/src/utils/type_id.rs b/manul/src/utils/type_id.rs new file mode 100644 index 0000000..24245ad --- /dev/null +++ b/manul/src/utils/type_id.rs @@ -0,0 +1,11 @@ +use core::any::TypeId; + +/// A dyn safe trait to get the type's ID. +pub(crate) trait DynTypeId: 'static { + /// Returns the type ID of the implementing type. + fn get_type_id(&self) -> TypeId { + TypeId::of::() + } +} + +impl DynTypeId for T {}