Skip to content

Commit 49bb410

Browse files
committed
Fully switch to statically typed rounds
1 parent 959da23 commit 49bb410

34 files changed

+1649
-1728
lines changed

examples/dining_cryptographers.rs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,8 @@ use std::collections::{BTreeMap, BTreeSet};
5555
use manul::{
5656
dev::{run_sync, BinaryFormat, TestHasher, TestSignature, TestSigner, TestVerifier},
5757
protocol::{
58-
BoxedRound, BoxedRoundInfo, CommunicationInfo, EchoRoundParticipation, EntryPoint, FinalizeOutcome, LocalError,
59-
NoMessage, NoProtocolErrors, NoProvableErrors, Protocol, ReceiveError, RoundId, StaticProtocolMessage,
60-
StaticRound, TransitionInfo,
58+
BoxedRound, CommunicationInfo, EchoRoundParticipation, EntryPoint, FinalizeOutcome, LocalError, MessageParts,
59+
NoMessage, NoProvableErrors, Protocol, ReceiveError, Round, RoundId, RoundInfo, TransitionInfo,
6160
},
6261
session::SessionParameters,
6362
};
@@ -79,12 +78,10 @@ impl Protocol<DinerId> for DiningCryptographersProtocol {
7978
type Result = (bool, bool, bool);
8079
type SharedData = ();
8180

82-
type ProtocolError = NoProtocolErrors;
83-
84-
fn round_info(round_id: &RoundId) -> Option<BoxedRoundInfo<DinerId, Self>> {
81+
fn round_info(round_id: &RoundId) -> Option<RoundInfo<DinerId, Self>> {
8582
match round_id {
86-
_ if round_id == 1 => Some(BoxedRoundInfo::new::<Round1>()),
87-
_ if round_id == 2 => Some(BoxedRoundInfo::new::<Round2>()),
83+
_ if round_id == 1 => Some(RoundInfo::new::<Round1>()),
84+
_ if round_id == 2 => Some(RoundInfo::new::<Round2>()),
8885
_ => None,
8986
}
9087
}
@@ -108,7 +105,7 @@ pub struct Round2 {
108105
paid: bool,
109106
}
110107

111-
impl StaticRound<DinerId> for Round1 {
108+
impl Round<DinerId> for Round1 {
112109
type Protocol = DiningCryptographersProtocol;
113110
type ProvableError = NoProvableErrors<Self>;
114111

@@ -164,8 +161,8 @@ impl StaticRound<DinerId> for Round1 {
164161
fn receive_message(
165162
&self,
166163
from: &DinerId,
167-
message: StaticProtocolMessage<DinerId, Self>,
168-
) -> Result<Self::Payload, ReceiveError<DinerId, Self::Protocol>> {
164+
message: MessageParts<DinerId, Self>,
165+
) -> Result<Self::Payload, ReceiveError<DinerId, Self>> {
169166
let dm = message.direct_message;
170167
debug!(
171168
"[Round1, receive_message] {:?} was dm'd by {from:?}: {dm:?}",
@@ -192,7 +189,7 @@ impl StaticRound<DinerId> for Round1 {
192189
"[Round1, finalize] {:?} is finalizing to Round 2. Own cointoss: {}, neighbour cointoss: {neighbour_toss}",
193190
self.diner_id, self.own_toss
194191
);
195-
Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_static(Round2 {
192+
Ok(FinalizeOutcome::AnotherRound(BoxedRound::new(Round2 {
196193
diner_id: self.diner_id,
197194
own_toss: self.own_toss,
198195
neighbour_toss,
@@ -201,7 +198,7 @@ impl StaticRound<DinerId> for Round1 {
201198
}
202199
}
203200

204-
impl StaticRound<DinerId> for Round2 {
201+
impl Round<DinerId> for Round2 {
205202
type Protocol = DiningCryptographersProtocol;
206203
type ProvableError = NoProvableErrors<Self>;
207204

@@ -261,8 +258,8 @@ impl StaticRound<DinerId> for Round2 {
261258
fn receive_message(
262259
&self,
263260
from: &DinerId,
264-
message: StaticProtocolMessage<DinerId, Self>,
265-
) -> Result<Self::Payload, ReceiveError<DinerId, Self::Protocol>> {
261+
message: MessageParts<DinerId, Self>,
262+
) -> Result<Self::Payload, ReceiveError<DinerId, Self>> {
266263
debug!("[Round2, receive_message] from {from:?} to {:?}", self.diner_id);
267264
let bcast = message.normal_broadcast;
268265
trace!("[Round2, receive_message] message (deserialized bcast): {:?}", bcast);
@@ -336,7 +333,7 @@ impl EntryPoint<DinerId> for DiningEntryPoint {
336333
"[DiningEntryPoint, make_round] diner {id:?} tossed: {:?} (paid? {paid})",
337334
round.own_toss
338335
);
339-
let round = BoxedRound::new_static(round);
336+
let round = BoxedRound::new(round);
340337
Ok(round)
341338
}
342339
}

examples/src/simple.rs

Lines changed: 29 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ use alloc::collections::{BTreeMap, BTreeSet};
22
use core::fmt::Debug;
33

44
use manul::protocol::{
5-
BoxedFormat, BoxedRound, BoxedRoundInfo, CommunicationInfo, EchoBroadcast, EntryPoint, EvidenceMessages,
6-
FinalizeOutcome, LocalError, NoMessage, PartyId, Protocol, ProtocolError, ProtocolMessage, ProtocolMessagePart,
7-
ProtocolValidationError, ProvableError, ReceiveError, RequiredMessageParts, RequiredMessages, RoundId,
8-
StaticProtocolMessage, StaticRound, TransitionInfo,
5+
BoxedRound, CommunicationInfo, EntryPoint, EvidenceError, EvidenceMessages, FinalizeOutcome, LocalError,
6+
MessageParts, NoMessage, PartyId, Protocol, ProvableError, ReceiveError, RequiredMessageParts, RequiredMessages,
7+
Round, RoundId, RoundInfo, TransitionInfo,
98
};
109
use rand_core::CryptoRngCore;
1110
use serde::{Deserialize, Serialize};
@@ -14,42 +13,37 @@ use tracing::debug;
1413
#[derive(Debug)]
1514
pub struct SimpleProtocol;
1615

17-
#[derive(displaydoc::Display, Debug, Clone, Serialize, Deserialize)]
18-
/// An example error.
19-
pub enum SimpleProtocolError {
20-
/// Invalid position in Round 1.
21-
Round1InvalidPosition,
22-
/// Invalid position in Round 2.
23-
Round2InvalidPosition,
24-
}
25-
2616
#[derive(displaydoc::Display, Debug, Clone, Copy, Serialize, Deserialize)]
2717
pub(crate) struct Round1ProvableError;
2818

2919
impl<Id: PartyId> ProvableError<Id> for Round1ProvableError {
3020
type Round = Round1<Id>;
31-
fn required_previous_messages(&self) -> RequiredMessages {
21+
fn required_messages(&self, _round_id: &RoundId) -> RequiredMessages {
3222
RequiredMessages::new(RequiredMessageParts::direct_message(), None, None)
3323
}
3424
fn verify_evidence(
3525
&self,
26+
_round_id: &RoundId,
3627
_from: &Id,
3728
_shared_randomness: &[u8],
38-
_shared_data: &<<Self::Round as StaticRound<Id>>::Protocol as Protocol<Id>>::SharedData,
29+
_shared_data: &<<Self::Round as Round<Id>>::Protocol as Protocol<Id>>::SharedData,
3930
messages: EvidenceMessages<Id, Self::Round>,
40-
) -> std::result::Result<(), ProtocolValidationError> {
31+
) -> std::result::Result<(), EvidenceError> {
4132
let _message: Round1Message = messages.direct_message()?;
4233
// Message contents would be checked here
4334
Ok(())
4435
}
36+
fn description(&self) -> std::string::String {
37+
"Invalid position".into()
38+
}
4539
}
4640

4741
#[derive(displaydoc::Display, Debug, Clone, Copy, Serialize, Deserialize)]
4842
pub(crate) struct Round2ProvableError;
4943

5044
impl<Id: PartyId> ProvableError<Id> for Round2ProvableError {
5145
type Round = Round2<Id>;
52-
fn required_previous_messages(&self) -> RequiredMessages {
46+
fn required_messages(&self, _round_id: &RoundId) -> RequiredMessages {
5347
RequiredMessages::new(
5448
RequiredMessageParts::direct_message(),
5549
Some([(1.into(), RequiredMessageParts::direct_message())].into()),
@@ -58,75 +52,29 @@ impl<Id: PartyId> ProvableError<Id> for Round2ProvableError {
5852
}
5953
fn verify_evidence(
6054
&self,
55+
_round_id: &RoundId,
6156
_from: &Id,
6257
_shared_randomness: &[u8],
63-
_shared_data: &<<Self::Round as StaticRound<Id>>::Protocol as Protocol<Id>>::SharedData,
58+
_shared_data: &<<Self::Round as Round<Id>>::Protocol as Protocol<Id>>::SharedData,
6459
messages: EvidenceMessages<Id, Self::Round>,
65-
) -> std::result::Result<(), ProtocolValidationError> {
60+
) -> std::result::Result<(), EvidenceError> {
6661
let _r2_message: Round2Message = messages.direct_message()?;
6762
let _r1_echos: BTreeMap<Id, Round1Echo> = messages.combined_echos::<Round1<Id>>(1)?;
6863
// Message contents would be checked here
6964
Ok(())
7065
}
71-
}
72-
73-
impl<Id> ProtocolError<Id> for SimpleProtocolError {
74-
type AssociatedData = ();
75-
76-
fn required_messages(&self) -> RequiredMessages {
77-
match self {
78-
Self::Round1InvalidPosition => RequiredMessages::new(RequiredMessageParts::direct_message(), None, None),
79-
Self::Round2InvalidPosition => RequiredMessages::new(
80-
RequiredMessageParts::direct_message(),
81-
Some([(1.into(), RequiredMessageParts::direct_message())].into()),
82-
Some([1.into()].into()),
83-
),
84-
}
85-
}
86-
87-
fn verify_messages_constitute_error(
88-
&self,
89-
format: &BoxedFormat,
90-
_guilty_party: &Id,
91-
_shared_randomness: &[u8],
92-
_associated_data: &Self::AssociatedData,
93-
message: ProtocolMessage,
94-
_previous_messages: BTreeMap<RoundId, ProtocolMessage>,
95-
combined_echos: BTreeMap<RoundId, BTreeMap<Id, EchoBroadcast>>,
96-
) -> Result<(), ProtocolValidationError> {
97-
match self {
98-
SimpleProtocolError::Round1InvalidPosition => {
99-
let _message = message.direct_message.deserialize::<Round1Message>(format)?;
100-
// Message contents would be checked here
101-
Ok(())
102-
}
103-
SimpleProtocolError::Round2InvalidPosition => {
104-
let _r1_message = message.direct_message.deserialize::<Round1Message>(format)?;
105-
let r1_echos_serialized = combined_echos
106-
.get(&1.into())
107-
.ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?;
108-
109-
// Deserialize the echos
110-
let _r1_echos = r1_echos_serialized
111-
.iter()
112-
.map(|(_id, echo)| echo.deserialize::<Round1Echo>(format))
113-
.collect::<Result<Vec<_>, _>>()?;
114-
115-
// Message contents would be checked here
116-
Ok(())
117-
}
118-
}
66+
fn description(&self) -> std::string::String {
67+
"Invalid position".into()
11968
}
12069
}
12170

12271
impl<Id: PartyId> Protocol<Id> for SimpleProtocol {
12372
type Result = u8;
12473
type SharedData = ();
125-
type ProtocolError = SimpleProtocolError;
126-
fn round_info(round_id: &RoundId) -> Option<BoxedRoundInfo<Id, Self>> {
74+
fn round_info(round_id: &RoundId) -> Option<RoundInfo<Id, Self>> {
12775
match round_id {
128-
_ if round_id == 1 => Some(BoxedRoundInfo::new::<Round1<Id>>()),
129-
_ if round_id == 2 => Some(BoxedRoundInfo::new::<Round2<Id>>()),
76+
_ if round_id == 1 => Some(RoundInfo::new::<Round1<Id>>()),
77+
_ if round_id == 2 => Some(RoundInfo::new::<Round2<Id>>()),
13078
_ => None,
13179
}
13280
}
@@ -201,7 +149,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
201149
let mut ids = self.all_ids;
202150
ids.remove(id);
203151

204-
Ok(BoxedRound::new_static(Round1 {
152+
Ok(BoxedRound::new(Round1 {
205153
context: Context {
206154
id: id.clone(),
207155
other_ids: ids,
@@ -211,7 +159,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
211159
}
212160
}
213161

214-
impl<Id: PartyId> StaticRound<Id> for Round1<Id> {
162+
impl<Id: PartyId> Round<Id> for Round1<Id> {
215163
type Protocol = SimpleProtocol;
216164
type ProvableError = Round1ProvableError;
217165

@@ -261,13 +209,13 @@ impl<Id: PartyId> StaticRound<Id> for Round1<Id> {
261209
fn receive_message(
262210
&self,
263211
from: &Id,
264-
message: StaticProtocolMessage<Id, Self>,
265-
) -> Result<Self::Payload, ReceiveError<Id, Self::Protocol>> {
212+
message: MessageParts<Id, Self>,
213+
) -> Result<Self::Payload, ReceiveError<Id, Self>> {
266214
debug!("{:?}: receiving message from {:?}", self.context.id, from);
267215
let message = message.direct_message;
268216

269217
if self.context.ids_to_positions[&self.context.id] != message.your_position {
270-
return Err(ReceiveError::protocol(SimpleProtocolError::Round1InvalidPosition));
218+
return Err(ReceiveError::Provable(Round1ProvableError));
271219
}
272220
Ok(Round1Payload { x: message.my_position })
273221
}
@@ -287,7 +235,7 @@ impl<Id: PartyId> StaticRound<Id> for Round1<Id> {
287235
let sum =
288236
self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::<u8>();
289237

290-
let round2 = BoxedRound::new_static(Round2 {
238+
let round2 = BoxedRound::new(Round2 {
291239
round1_sum: sum,
292240
context: self.context,
293241
});
@@ -307,7 +255,7 @@ pub(crate) struct Round2Message {
307255
pub(crate) your_position: u8,
308256
}
309257

310-
impl<Id: PartyId> StaticRound<Id> for Round2<Id> {
258+
impl<Id: PartyId> Round<Id> for Round2<Id> {
311259
type Protocol = SimpleProtocol;
312260
type ProvableError = Round2ProvableError;
313261

@@ -343,16 +291,16 @@ impl<Id: PartyId> StaticRound<Id> for Round2<Id> {
343291
fn receive_message(
344292
&self,
345293
from: &Id,
346-
message: StaticProtocolMessage<Id, Self>,
347-
) -> Result<Self::Payload, ReceiveError<Id, Self::Protocol>> {
294+
message: MessageParts<Id, Self>,
295+
) -> Result<Self::Payload, ReceiveError<Id, Self>> {
348296
debug!("{:?}: receiving message from {:?}", self.context.id, from);
349297

350298
let message = message.direct_message;
351299

352300
debug!("{:?}: received message: {:?}", self.context.id, message);
353301

354302
if self.context.ids_to_positions[&self.context.id] != message.your_position {
355-
return Err(ReceiveError::protocol(SimpleProtocolError::Round2InvalidPosition));
303+
return Err(ReceiveError::Provable(Round2ProvableError));
356304
}
357305

358306
Ok(Round1Payload { x: message.my_position })

examples/src/simple_test.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use core::marker::PhantomData;
55
use manul::{
66
combinators::extend::{Extendable, Extension},
77
dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
8-
protocol::{LocalError, PartyId, StaticRound},
8+
protocol::{LocalError, PartyId, Round},
99
signature::Keypair,
1010
};
1111
use rand_core::{CryptoRngCore, OsRng};
@@ -29,8 +29,8 @@ where
2929
_destination: &Id,
3030
) -> Result<
3131
Option<(
32-
<Self::Round as StaticRound<Id>>::DirectMessage,
33-
<Self::Round as StaticRound<Id>>::Artifact,
32+
<Self::Round as Round<Id>>::DirectMessage,
33+
<Self::Round as Round<Id>>::Artifact,
3434
)>,
3535
LocalError,
3636
> {
@@ -98,8 +98,8 @@ where
9898
_destination: &Id,
9999
) -> Result<
100100
Option<(
101-
<Self::Round as StaticRound<Id>>::DirectMessage,
102-
<Self::Round as StaticRound<Id>>::Artifact,
101+
<Self::Round as Round<Id>>::DirectMessage,
102+
<Self::Round as Round<Id>>::Artifact,
103103
)>,
104104
LocalError,
105105
> {

0 commit comments

Comments
 (0)