Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: allow for coordinator to be set manually #9

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions chains/evm/listener/eventHandlers/tss.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/sprintertech/sprinter-signing/chains/evm/calls/events"
"github.com/sprintertech/sprinter-signing/comm"
"github.com/sprintertech/sprinter-signing/comm/p2p"
Expand Down Expand Up @@ -86,7 +87,7 @@ func (eh *KeygenEventHandler) HandleEvents(

keygenBlockNumber := big.NewInt(0).SetUint64(keygenEvents[0].BlockNumber)
keygen := keygen.NewKeygen(eh.sessionID(keygenBlockNumber), eh.threshold, eh.host, eh.communication, eh.storer)
err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1))
err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1), peer.ID(""))
if err != nil {
log.Err(err).Msgf("Failed executing keygen")
}
Expand Down Expand Up @@ -178,7 +179,7 @@ func (eh *RefreshEventHandler) HandleEvents(
resharing := resharing.NewResharing(
eh.sessionID(startBlock), topology.Threshold, eh.host, eh.communication, eh.ecdsaStorer,
)
err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{resharing}, make(chan interface{}, 1))
err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{resharing}, make(chan interface{}, 1), peer.ID(""))
if err != nil {
log.Err(err).Msgf("Failed executing ecdsa key refresh")
return nil
Expand Down
6 changes: 4 additions & 2 deletions tss/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func NewCoordinator(
// Execute calculates process leader and coordinates party readiness and start the tss processes.
// Array of processes can be passed if all the processes have to have the same peer subset and
// the result of all of them is needed. The processes should have a unique session ID for each one.
func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}) error {
func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}, coordinator peer.ID) error {
sessionID := tssProcesses[0].SessionID()
value, ok := c.pendingProcesses[sessionID]
if ok && value {
Expand All @@ -98,7 +98,9 @@ func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, re
}()

coordinatorElector := c.electorFactory.CoordinatorElector(sessionID, elector.Static)
coordinator, _ := coordinatorElector.Coordinator(ctx, tssProcesses[0].ValidCoordinators())
if coordinator.String() == "" {
coordinator, _ = coordinatorElector.Coordinator(ctx, tssProcesses[0].ValidCoordinators())
}

log.Info().Str("SessionID", sessionID).Msgf("Starting process with coordinator %s", coordinator.String())

Expand Down
8 changes: 6 additions & 2 deletions tss/ecdsa/keygen/keygen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ func (s *KeygenTestSuite) Test_ValidKeygenProcess() {
s.MockECDSAStorer.EXPECT().StoreKeyshare(gomock.Any()).Times(3)
pool := pool.New().WithContext(context.Background()).WithCancelOnError()
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
})
}

err := pool.Wait()
Expand Down Expand Up @@ -81,7 +83,9 @@ func (s *KeygenTestSuite) Test_KeygenTimeout() {
s.MockECDSAStorer.EXPECT().StoreKeyshare(gomock.Any()).Times(0)
pool := pool.New().WithContext(context.Background())
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
})
}

err := pool.Wait()
Expand Down
8 changes: 4 additions & 4 deletions tss/ecdsa/resharing/resharing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_OldAndNewSubset() {
pool := pool.New().WithContext(context.Background()).WithCancelOnError()
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
})
}

Expand Down Expand Up @@ -116,7 +116,7 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_RemovePeer() {
pool := pool.New().WithContext(context.Background()).WithCancelOnError()
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
})
}

Expand Down Expand Up @@ -166,7 +166,7 @@ func (s *ResharingTestSuite) Test_InvalidResharingProcess_InvalidOldThreshold_Le
pool := pool.New().WithContext(context.Background())
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
})
}
err := pool.Wait()
Expand Down Expand Up @@ -215,7 +215,7 @@ func (s *ResharingTestSuite) Test_InvalidResharingProcess_InvalidOldThreshold_Bi
pool := pool.New().WithContext(context.Background())
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
})
}

Expand Down
66 changes: 62 additions & 4 deletions tss/ecdsa/signing/signing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (s *SigningTestSuite) Test_ValidSigningProcess() {
for i, coordinator := range coordinators {
coordinator := coordinator
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
})
}

Expand All @@ -80,6 +80,60 @@ func (s *SigningTestSuite) Test_ValidSigningProcess() {
s.Nil(err)
}

// Test_ValidSigningProcess_ManualCoordinator verifies that a signing process
// completes when the coordinator peer is supplied explicitly to Execute
// instead of being chosen by the elector (host[1] is forced as coordinator).
func (s *SigningTestSuite) Test_ValidSigningProcess_ManualCoordinator() {
	communicationMap := make(map[peer.ID]*tsstest.TestCommunication)
	coordinators := []*tss.Coordinator{}
	processes := []tss.TssProcess{}

	for i, host := range s.Hosts {
		communication := tsstest.TestCommunication{
			Host:          host,
			Subscriptions: make(map[comm.SubscriptionID]chan *comm.WrappedMessage),
		}
		communicationMap[host.ID()] = &communication
		fetcher := keyshare.NewECDSAKeyshareStore(fmt.Sprintf("../../test/keyshares/%d.keyshare", i))

		msgBytes := []byte("Message")
		msg := big.NewInt(0)
		msg.SetBytes(msgBytes)
		signing, err := signing.NewSigning(msg, "signing1", "signing1", host, &communication, fetcher)
		if err != nil {
			panic(err)
		}
		electorFactory := elector.NewCoordinatorElectorFactory(host, s.BullyConfig)
		coordinators = append(coordinators, tss.NewCoordinator(host, &communication, electorFactory))
		processes = append(processes, signing)
	}
	tsstest.SetupCommunication(communicationMap)

	resultChn := make(chan interface{}, 2)

	// Manually pin host[1] as the coordinator for every participant.
	coordinatorPeerID := s.Hosts[1].ID()
	ctx, cancel := context.WithCancel(context.Background())
	pool := pool.New().WithContext(ctx)
	for i, coordinator := range coordinators {
		// Shadow both loop variables: pre-Go-1.22 the closures below would
		// otherwise share a single i/coordinator and race with the loop.
		i := i
		coordinator := coordinator

		// Only the pinned coordinator reports into the shared result channel;
		// the other parties get throwaway buffered channels.
		outChn := resultChn
		if s.Hosts[i].ID().String() != coordinatorPeerID.String() {
			outChn = make(chan interface{}, 1)
		}
		pool.Go(func(ctx context.Context) error {
			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, outChn, coordinatorPeerID)
		})
	}

	sig := <-resultChn
	s.NotNil(sig)

	// Give the remaining parties a moment to finish before cancelling.
	time.Sleep(time.Millisecond * 100)
	cancel()
	err := pool.Wait()
	s.Nil(err)
}

func (s *SigningTestSuite) Test_SigningTimeout() {
communicationMap := make(map[peer.ID]*tsstest.TestCommunication)
coordinators := []*tss.Coordinator{}
Expand Down Expand Up @@ -113,7 +167,7 @@ func (s *SigningTestSuite) Test_SigningTimeout() {
for i, coordinator := range coordinators {
coordinator := coordinator
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
})
}

Expand Down Expand Up @@ -142,8 +196,12 @@ func (s *SigningTestSuite) Test_PendingProcessExists() {
s.MockECDSAStorer.EXPECT().UnlockKeyshare().AnyTimes()
pool := pool.New().WithContext(context.Background()).WithCancelOnError()
for i, coordinator := range coordinators {
pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
})
pool.Go(func(ctx context.Context) error {
return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
})
}

err := pool.Wait()
Expand Down
Loading