Skip to content

Commit 7226df2

Browse files
committed
Temporary: RoundInfo
1 parent 86ef2b0 commit 7226df2

File tree

5 files changed

+190
-50
lines changed

5 files changed

+190
-50
lines changed

manul/src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod errors;
1717
mod message;
1818
mod round;
1919
mod round_id;
20+
mod round_info;
2021
mod static_round;
2122

2223
pub use boxed_format::BoxedFormat;
@@ -31,6 +32,7 @@ pub use round::{
3132
Payload, Protocol, ProtocolError, RequiredMessageParts, RequiredMessages, Round,
3233
};
3334
pub use round_id::{RoundId, TransitionInfo};
35+
pub use round_info::BoxedRoundInfo;
3436
pub use static_round::{NoMessage, StaticProtocolMessage, StaticRound};
3537

3638
pub(crate) use errors::ReceiveErrorType;

manul/src/protocol/round.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use alloc::{
22
boxed::Box,
33
collections::{BTreeMap, BTreeSet},
44
format,
5+
vec::Vec,
56
};
67
use core::{
78
any::Any,
@@ -17,6 +18,7 @@ use super::{
1718
errors::{LocalError, MessageValidationError, ProtocolValidationError, ReceiveError},
1819
message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessage, ProtocolMessagePart},
1920
round_id::{RoundId, TransitionInfo},
21+
round_info::BoxedRoundInfo,
2022
};
2123

2224
/// Describes what other parties this rounds sends messages to, and what other parties it expects messages from.
@@ -73,6 +75,9 @@ pub trait Protocol<Id>: 'static {
7375
/// An object of this type will be returned when a provable error happens during [`Round::receive_message`].
7476
type ProtocolError: ProtocolError<Id>;
7577

78+
fn rounds() -> Vec<BoxedRoundInfo<Id>>;
79+
80+
// TODO: move out of `Protocol`. To `evidence.rs`, perhaps?
7681
/// Returns `Ok(())` if the given direct message cannot be deserialized
7782
/// assuming it is a direct message from the round `round_id`.
7883
///
@@ -82,7 +87,17 @@ pub trait Protocol<Id>: 'static {
8287
format: &BoxedFormat,
8388
round_id: &RoundId,
8489
message: &DirectMessage,
85-
) -> Result<(), MessageValidationError>;
90+
) -> Result<(), MessageValidationError> {
91+
let rounds = Self::rounds()
92+
.into_iter()
93+
.map(|r| {
94+
let rid = r.transition_info().id;
95+
(rid, r)
96+
})
97+
.collect::<BTreeMap<_, _>>();
98+
let round = rounds.get(round_id).unwrap();
99+
round.verify_direct_message_is_invalid(format, message)
100+
}
86101

87102
/// Returns `Ok(())` if the given echo broadcast cannot be deserialized
88103
/// assuming it is an echo broadcast from the round `round_id`.
@@ -93,7 +108,17 @@ pub trait Protocol<Id>: 'static {
93108
format: &BoxedFormat,
94109
round_id: &RoundId,
95110
message: &EchoBroadcast,
96-
) -> Result<(), MessageValidationError>;
111+
) -> Result<(), MessageValidationError> {
112+
let rounds = Self::rounds()
113+
.into_iter()
114+
.map(|r| {
115+
let rid = r.transition_info().id;
116+
(rid, r)
117+
})
118+
.collect::<BTreeMap<_, _>>();
119+
let round = rounds.get(round_id).unwrap();
120+
round.verify_echo_broadcast_is_invalid(format, message)
121+
}
97122

98123
/// Returns `Ok(())` if the given echo broadcast cannot be deserialized
99124
/// assuming it is an echo broadcast from the round `round_id`.
@@ -104,7 +129,17 @@ pub trait Protocol<Id>: 'static {
104129
format: &BoxedFormat,
105130
round_id: &RoundId,
106131
message: &NormalBroadcast,
107-
) -> Result<(), MessageValidationError>;
132+
) -> Result<(), MessageValidationError> {
133+
let rounds = Self::rounds()
134+
.into_iter()
135+
.map(|r| {
136+
let rid = r.transition_info().id;
137+
(rid, r)
138+
})
139+
.collect::<BTreeMap<_, _>>();
140+
let round = rounds.get(round_id).unwrap();
141+
round.verify_normal_broadcast_is_invalid(format, message)
142+
}
108143
}
109144

110145
/// Declares which parts of the message from a round have to be stored to serve as the evidence of malicious behavior.

manul/src/protocol/round_info.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#![allow(dead_code, unused_variables, missing_docs)]
2+
3+
use alloc::boxed::Box;
4+
use core::marker::PhantomData;
5+
6+
use super::{
7+
boxed_format::BoxedFormat,
8+
errors::MessageValidationError,
9+
message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart},
10+
round::PartyId,
11+
round_id::TransitionInfo,
12+
static_round::{NoMessage, StaticRound},
13+
};
14+
15+
trait RoundInfo<Id> {
16+
fn transition_info(&self) -> TransitionInfo;
17+
fn verify_direct_message_is_invalid(
18+
&self,
19+
format: &BoxedFormat,
20+
message: &DirectMessage,
21+
) -> Result<(), MessageValidationError>;
22+
fn verify_echo_broadcast_is_invalid(
23+
&self,
24+
format: &BoxedFormat,
25+
message: &EchoBroadcast,
26+
) -> Result<(), MessageValidationError>;
27+
fn verify_normal_broadcast_is_invalid(
28+
&self,
29+
format: &BoxedFormat,
30+
message: &NormalBroadcast,
31+
) -> Result<(), MessageValidationError>;
32+
}
33+
34+
pub(crate) struct StaticRoundInfoAdapter<R>(PhantomData<R>);
35+
36+
impl<Id, R> RoundInfo<Id> for StaticRoundInfoAdapter<R>
37+
where
38+
Id: PartyId,
39+
R: StaticRound<Id>,
40+
{
41+
fn transition_info(&self) -> TransitionInfo {
42+
R::transition_info()
43+
}
44+
45+
fn verify_direct_message_is_invalid(
46+
&self,
47+
format: &BoxedFormat,
48+
message: &DirectMessage,
49+
) -> Result<(), MessageValidationError> {
50+
if NoMessage::equals::<R::DirectMessage>() {
51+
message.verify_is_not::<R::DirectMessage>(format)
52+
} else {
53+
message.verify_is_some()
54+
}
55+
}
56+
57+
fn verify_echo_broadcast_is_invalid(
58+
&self,
59+
format: &BoxedFormat,
60+
message: &EchoBroadcast,
61+
) -> Result<(), MessageValidationError> {
62+
if NoMessage::equals::<R::EchoBroadcast>() {
63+
message.verify_is_not::<R::EchoBroadcast>(format)
64+
} else {
65+
message.verify_is_some()
66+
}
67+
}
68+
69+
fn verify_normal_broadcast_is_invalid(
70+
&self,
71+
format: &BoxedFormat,
72+
message: &NormalBroadcast,
73+
) -> Result<(), MessageValidationError> {
74+
if NoMessage::equals::<R::NormalBroadcast>() {
75+
message.verify_is_not::<R::NormalBroadcast>(format)
76+
} else {
77+
message.verify_is_some()
78+
}
79+
}
80+
}
81+
82+
pub struct BoxedRoundInfo<Id>(Box<dyn RoundInfo<Id>>);
83+
84+
impl<Id> BoxedRoundInfo<Id> {
85+
pub fn new<R>() -> Self
86+
where
87+
Id: PartyId,
88+
R: StaticRound<Id>,
89+
{
90+
Self(Box::new(StaticRoundInfoAdapter(PhantomData::<R>)))
91+
}
92+
93+
pub(crate) fn transition_info(&self) -> TransitionInfo {
94+
self.0.transition_info()
95+
}
96+
97+
pub(crate) fn verify_direct_message_is_invalid(
98+
&self,
99+
format: &BoxedFormat,
100+
message: &DirectMessage,
101+
) -> Result<(), MessageValidationError> {
102+
self.0.verify_direct_message_is_invalid(format, message)
103+
}
104+
105+
pub(crate) fn verify_echo_broadcast_is_invalid(
106+
&self,
107+
format: &BoxedFormat,
108+
message: &EchoBroadcast,
109+
) -> Result<(), MessageValidationError> {
110+
self.0.verify_echo_broadcast_is_invalid(format, message)
111+
}
112+
113+
pub(crate) fn verify_normal_broadcast_is_invalid(
114+
&self,
115+
format: &BoxedFormat,
116+
message: &NormalBroadcast,
117+
) -> Result<(), MessageValidationError> {
118+
self.0.verify_normal_broadcast_is_invalid(format, message)
119+
}
120+
}

manul/src/protocol/static_round.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,20 @@ use super::{
1717
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
1818
pub struct NoMessage;
1919

20-
fn is_no_message<T: 'static>() -> bool {
21-
TypeId::of::<T>() == TypeId::of::<NoMessage>()
22-
}
20+
impl NoMessage {
21+
pub(crate) fn equals<T: 'static>() -> bool {
22+
TypeId::of::<T>() == TypeId::of::<NoMessage>()
23+
}
2324

24-
fn make_no_message<T: 'static>() -> Option<T> {
25-
if TypeId::of::<T>() == TypeId::of::<NoMessage>() {
26-
let boxed = Box::new(NoMessage);
27-
// SAFETY: can cast since we checked that T == NoMessage
28-
let boxed_downcast = unsafe { Box::<T>::from_raw(Box::into_raw(boxed) as *mut T) };
29-
Some(*boxed_downcast)
30-
} else {
31-
None
25+
fn new<T: 'static>() -> Option<T> {
26+
if Self::equals::<T>() {
27+
let boxed = Box::new(NoMessage);
28+
// SAFETY: can cast since we checked that T == NoMessage
29+
let boxed_downcast = unsafe { Box::<T>::from_raw(Box::into_raw(boxed) as *mut T) };
30+
Some(*boxed_downcast)
31+
} else {
32+
None
33+
}
3234
}
3335
}
3436

@@ -72,7 +74,7 @@ pub trait StaticRound<Id: PartyId>: 'static + Debug + Send + Sync + DynTypeId {
7274
#[allow(unused_variables)] rng: &mut dyn CryptoRngCore,
7375
#[allow(unused_variables)] destination: &Id,
7476
) -> Result<(Self::DirectMessage, Option<Self::Artifact>), LocalError> {
75-
let direct_message = make_no_message::<Self::DirectMessage>().ok_or_else(|| {
77+
let direct_message = NoMessage::new::<Self::DirectMessage>().ok_or_else(|| {
7678
LocalError::new("If `DirectMessage` is not `NoMessage`, `make_direct_message()` must be implemented.")
7779
})?;
7880
Ok((direct_message, None))
@@ -90,7 +92,7 @@ pub trait StaticRound<Id: PartyId>: 'static + Debug + Send + Sync + DynTypeId {
9092
&self,
9193
#[allow(unused_variables)] rng: &mut dyn CryptoRngCore,
9294
) -> Result<Self::EchoBroadcast, LocalError> {
93-
let echo_broadcast = make_no_message::<Self::EchoBroadcast>().ok_or_else(|| {
95+
let echo_broadcast = NoMessage::new::<Self::EchoBroadcast>().ok_or_else(|| {
9496
LocalError::new("If `EchoBroadcast` is not `NoMessage`, `make_echo_broadcast()` must be implemented.")
9597
})?;
9698
Ok(echo_broadcast)
@@ -107,7 +109,7 @@ pub trait StaticRound<Id: PartyId>: 'static + Debug + Send + Sync + DynTypeId {
107109
&self,
108110
#[allow(unused_variables)] rng: &mut dyn CryptoRngCore,
109111
) -> Result<Self::NormalBroadcast, LocalError> {
110-
let normal_broadcast = make_no_message::<Self::NormalBroadcast>().ok_or_else(|| {
112+
let normal_broadcast = NoMessage::new::<Self::NormalBroadcast>().ok_or_else(|| {
111113
LocalError::new("If `NormalBroadcast` is not `NoMessage`, `make_normal_broadcast()` must be implemented.")
112114
})?;
113115
Ok(normal_broadcast)
@@ -178,7 +180,7 @@ where
178180
format: &BoxedFormat,
179181
destination: &Id,
180182
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
181-
Ok(if is_no_message::<R::DirectMessage>() {
183+
Ok(if NoMessage::equals::<R::DirectMessage>() {
182184
(DirectMessage::none(), None)
183185
} else {
184186
let (direct_message, artifact) = self.round.make_direct_message(rng, destination)?;
@@ -191,7 +193,7 @@ where
191193
rng: &mut dyn CryptoRngCore,
192194
format: &BoxedFormat,
193195
) -> Result<EchoBroadcast, LocalError> {
194-
Ok(if is_no_message::<R::EchoBroadcast>() {
196+
Ok(if NoMessage::equals::<R::EchoBroadcast>() {
195197
EchoBroadcast::none()
196198
} else {
197199
let echo_broadcast = self.round.make_echo_broadcast(rng)?;
@@ -204,7 +206,7 @@ where
204206
rng: &mut dyn CryptoRngCore,
205207
format: &BoxedFormat,
206208
) -> Result<NormalBroadcast, LocalError> {
207-
Ok(if is_no_message::<R::NormalBroadcast>() {
209+
Ok(if NoMessage::equals::<R::NormalBroadcast>() {
208210
NormalBroadcast::none()
209211
} else {
210212
let normal_broadcast = self.round.make_normal_broadcast(rng)?;
@@ -218,23 +220,23 @@ where
218220
from: &Id,
219221
message: ProtocolMessage,
220222
) -> Result<Payload, ReceiveError<Id, <Self as Round<Id>>::Protocol>> {
221-
let direct_message = if is_no_message::<R::DirectMessage>() {
223+
let direct_message = if NoMessage::equals::<R::DirectMessage>() {
222224
message.direct_message.assert_is_none()?;
223225
&DirectMessage::new(format, NoMessage)?
224226
} else {
225227
&message.direct_message
226228
};
227229
let direct_message = direct_message.deserialize::<R::DirectMessage>(format)?;
228230

229-
let echo_broadcast = if is_no_message::<R::EchoBroadcast>() {
231+
let echo_broadcast = if NoMessage::equals::<R::EchoBroadcast>() {
230232
message.echo_broadcast.assert_is_none()?;
231233
&EchoBroadcast::new(format, NoMessage)?
232234
} else {
233235
&message.echo_broadcast
234236
};
235237
let echo_broadcast = echo_broadcast.deserialize::<R::EchoBroadcast>(format)?;
236238

237-
let normal_broadcast = if is_no_message::<R::NormalBroadcast>() {
239+
let normal_broadcast = if NoMessage::equals::<R::NormalBroadcast>() {
238240
message.normal_broadcast.assert_is_none()?;
239241
&NormalBroadcast::new(format, NoMessage)?
240242
} else {

manul/src/tests/partial_echo.rs

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ use serde::{Deserialize, Serialize};
1212
use crate::{
1313
dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
1414
protocol::{
15-
Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EchoRoundParticipation,
16-
EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors, NormalBroadcast, PartyId,
17-
Payload, Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round, RoundId, TransitionInfo,
15+
Artifact, BoxedFormat, BoxedRound, BoxedRoundInfo, CommunicationInfo, DirectMessage, EchoBroadcast,
16+
EchoRoundParticipation, EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors,
17+
NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round,
18+
RoundId, StaticRound, TransitionInfo,
1819
},
1920
signature::Keypair,
2021
};
@@ -26,28 +27,8 @@ impl<Id: PartyId> Protocol<Id> for PartialEchoProtocol<Id> {
2627
type Result = ();
2728
type ProtocolError = NoProtocolErrors;
2829

29-
fn verify_direct_message_is_invalid(
30-
_format: &BoxedFormat,
31-
_round_id: &RoundId,
32-
_message: &DirectMessage,
33-
) -> Result<(), MessageValidationError> {
34-
unimplemented!()
35-
}
36-
37-
fn verify_echo_broadcast_is_invalid(
38-
_format: &BoxedFormat,
39-
_round_id: &RoundId,
40-
_message: &EchoBroadcast,
41-
) -> Result<(), MessageValidationError> {
42-
unimplemented!()
43-
}
44-
45-
fn verify_normal_broadcast_is_invalid(
46-
_format: &BoxedFormat,
47-
_round_id: &RoundId,
48-
_message: &NormalBroadcast,
49-
) -> Result<(), MessageValidationError> {
50-
unimplemented!()
30+
fn rounds() -> Vec<BoxedRoundInfo<Id>> {
31+
[BoxedRoundInfo::new::<Round1<Id>>()].into()
5132
}
5233
}
5334

@@ -82,11 +63,11 @@ impl<Id: PartyId + Serialize + for<'de> Deserialize<'de>> EntryPoint<Id> for Inp
8263
_shared_randomness: &[u8],
8364
_id: &Id,
8465
) -> Result<BoxedRound<Id, Self::Protocol>, LocalError> {
85-
Ok(BoxedRound::new_dynamic(Round1 { inputs: self }))
66+
Ok(BoxedRound::new_static(Round1 { inputs: self }))
8667
}
8768
}
8869

89-
impl<Id: PartyId + Serialize + for<'de> Deserialize<'de>> Round<Id> for Round1<Id> {
70+
impl<Id: PartyId + Serialize + for<'de> Deserialize<'de>> StaticRound<Id> for Round1<Id> {
9071
type Protocol = PartialEchoProtocol<Id>;
9172

9273
fn transition_info(&self) -> TransitionInfo {

0 commit comments

Comments
 (0)