diff --git a/chains/evm/listener/eventHandlers/tss.go b/chains/evm/listener/eventHandlers/tss.go
index 4fe59244..05081c9e 100644
--- a/chains/evm/listener/eventHandlers/tss.go
+++ b/chains/evm/listener/eventHandlers/tss.go
@@ -14,6 +14,7 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/libp2p/go-libp2p/core/host"
+	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/sprintertech/sprinter-signing/chains/evm/calls/events"
 	"github.com/sprintertech/sprinter-signing/comm"
 	"github.com/sprintertech/sprinter-signing/comm/p2p"
@@ -86,7 +87,7 @@ func (eh *KeygenEventHandler) HandleEvents(
 	keygenBlockNumber := big.NewInt(0).SetUint64(keygenEvents[0].BlockNumber)
 
 	keygen := keygen.NewKeygen(eh.sessionID(keygenBlockNumber), eh.threshold, eh.host, eh.communication, eh.storer)
-	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1))
+	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{keygen}, make(chan interface{}, 1), peer.ID(""))
 	if err != nil {
 		log.Err(err).Msgf("Failed executing keygen")
 	}
@@ -178,7 +179,7 @@ func (eh *RefreshEventHandler) HandleEvents(
 	resharing := resharing.NewResharing(
 		eh.sessionID(startBlock), topology.Threshold, eh.host, eh.communication, eh.ecdsaStorer,
 	)
-	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{resharing}, make(chan interface{}, 1))
+	err = eh.coordinator.Execute(context.Background(), []tss.TssProcess{resharing}, make(chan interface{}, 1), peer.ID(""))
 	if err != nil {
 		log.Err(err).Msgf("Failed executing ecdsa key refresh")
 		return nil
diff --git a/tss/coordinator.go b/tss/coordinator.go
index 987454e9..17dd20b9 100644
--- a/tss/coordinator.go
+++ b/tss/coordinator.go
@@ -72,7 +72,7 @@ func NewCoordinator(
 // Execute calculates the process leader, coordinates party readiness and starts the tss processes.
 // An array of processes can be passed if all the processes have to have the same peer subset and
 // the result of all of them is needed. Each process should have a unique session ID.
-func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}) error {
+func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, resultChn chan interface{}, coordinator peer.ID) error {
 	sessionID := tssProcesses[0].SessionID()
 	value, ok := c.pendingProcesses[sessionID]
 	if ok && value {
@@ -98,7 +98,9 @@ func (c *Coordinator) Execute(ctx context.Context, tssProcesses []TssProcess, re
 	}()
 
 	coordinatorElector := c.electorFactory.CoordinatorElector(sessionID, elector.Static)
-	coordinator, _ := coordinatorElector.Coordinator(ctx, tssProcesses[0].ValidCoordinators())
+	if coordinator.String() == "" {
+		coordinator, _ = coordinatorElector.Coordinator(ctx, tssProcesses[0].ValidCoordinators())
+	}
 	log.Info().Str("SessionID", sessionID).Msgf("Starting process with coordinator %s", coordinator.String())
diff --git a/tss/ecdsa/keygen/keygen_test.go b/tss/ecdsa/keygen/keygen_test.go
index 1398fdcf..f1ef6a24 100644
--- a/tss/ecdsa/keygen/keygen_test.go
+++ b/tss/ecdsa/keygen/keygen_test.go
@@ -50,7 +50,9 @@ func (s *KeygenTestSuite) Test_ValidKeygenProcess() {
 	s.MockECDSAStorer.EXPECT().StoreKeyshare(gomock.Any()).Times(3)
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
+		})
 	}
 
 	err := pool.Wait()
@@ -81,7 +83,9 @@ func (s *KeygenTestSuite) Test_KeygenTimeout() {
 	s.MockECDSAStorer.EXPECT().StoreKeyshare(gomock.Any()).Times(0)
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
+		})
 	}
 
 	err := pool.Wait()
diff --git a/tss/ecdsa/resharing/resharing_test.go b/tss/ecdsa/resharing/resharing_test.go
index f71436cb..b8396cdd 100644
--- a/tss/ecdsa/resharing/resharing_test.go
+++ b/tss/ecdsa/resharing/resharing_test.go
@@ -69,7 +69,7 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_OldAndNewSubset() {
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
 		})
 	}
 
@@ -116,7 +116,7 @@ func (s *ResharingTestSuite) Test_ValidResharingProcess_RemovePeer() {
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
 		})
 	}
 
@@ -166,7 +166,7 @@ func (s *ResharingTestSuite) Test_InvalidResharingProcess_InvalidOldThreshold_Le
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
 		})
 	}
 	err := pool.Wait()
@@ -215,7 +215,7 @@ func (s *ResharingTestSuite) Test_InvalidResharingProcess_InvalidOldThreshold_Bi
 	pool := pool.New().WithContext(context.Background())
 	for i, coordinator := range coordinators {
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
 		})
 	}
diff --git a/tss/ecdsa/signing/signing_test.go b/tss/ecdsa/signing/signing_test.go
index 70e7e63b..120bbf5a 100644
--- a/tss/ecdsa/signing/signing_test.go
+++ b/tss/ecdsa/signing/signing_test.go
@@ -63,7 +63,7 @@ func (s *SigningTestSuite) Test_ValidSigningProcess() {
 	for i, coordinator := range coordinators {
 		coordinator := coordinator
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
 		})
 	}
 
@@ -80,6 +80,60 @@ func (s *SigningTestSuite) Test_ValidSigningProcess() {
 	s.Nil(err)
 }
 
+func (s *SigningTestSuite) Test_ValidSigningProcess_ManualCoordinator() {
+	communicationMap := make(map[peer.ID]*tsstest.TestCommunication)
+	coordinators := []*tss.Coordinator{}
+	processes := []tss.TssProcess{}
+
+	for i, host := range s.Hosts {
+		communication := tsstest.TestCommunication{
+			Host:          host,
+			Subscriptions: make(map[comm.SubscriptionID]chan *comm.WrappedMessage),
+		}
+		communicationMap[host.ID()] = &communication
+		fetcher := keyshare.NewECDSAKeyshareStore(fmt.Sprintf("../../test/keyshares/%d.keyshare", i))
+
+		msgBytes := []byte("Message")
+		msg := big.NewInt(0)
+		msg.SetBytes(msgBytes)
+		signing, err := signing.NewSigning(msg, "signing1", "signing1", host, &communication, fetcher)
+		if err != nil {
+			panic(err)
+		}
+		electorFactory := elector.NewCoordinatorElectorFactory(host, s.BullyConfig)
+		coordinators = append(coordinators, tss.NewCoordinator(host, &communication, electorFactory))
+		processes = append(processes, signing)
+	}
+	tsstest.SetupCommunication(communicationMap)
+
+	resultChn := make(chan interface{}, 2)
+
+	coordinatorPeerID := s.Hosts[1].ID()
+	ctx, cancel := context.WithCancel(context.Background())
+	pool := pool.New().WithContext(ctx)
+	for i, coordinator := range coordinators {
+		coordinator := coordinator
+
+		if s.Hosts[i].ID().String() == coordinatorPeerID.String() {
+			pool.Go(func(ctx context.Context) error {
+				return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, coordinatorPeerID)
+			})
+		} else {
+			pool.Go(func(ctx context.Context) error {
+				return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, make(chan interface{}, 1), coordinatorPeerID)
+			})
+		}
+	}
+
+	sig := <-resultChn
+	s.NotNil(sig)
+
+	time.Sleep(time.Millisecond * 100)
+	cancel()
+	err := pool.Wait()
+	s.Nil(err)
+}
+
 func (s *SigningTestSuite) Test_SigningTimeout() {
 	communicationMap := make(map[peer.ID]*tsstest.TestCommunication)
 	coordinators := []*tss.Coordinator{}
@@ -113,7 +167,7 @@ func (s *SigningTestSuite) Test_SigningTimeout() {
 	for i, coordinator := range coordinators {
 		coordinator := coordinator
 		pool.Go(func(ctx context.Context) error {
-			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn)
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, resultChn, peer.ID(""))
 		})
 	}
 
@@ -142,8 +196,12 @@ func (s *SigningTestSuite) Test_PendingProcessExists() {
 	s.MockECDSAStorer.EXPECT().UnlockKeyshare().AnyTimes()
 	pool := pool.New().WithContext(context.Background()).WithCancelOnError()
 	for i, coordinator := range coordinators {
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
-		pool.Go(func(ctx context.Context) error { return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil) })
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
+		})
+		pool.Go(func(ctx context.Context) error {
+			return coordinator.Execute(ctx, []tss.TssProcess{processes[i]}, nil, peer.ID(""))
+		})
 	}
 
 	err := pool.Wait()
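
Usage sketch (not part of the diff): a minimal caller of the new Execute signature. The runProcess helper and the tss import path are assumptions for illustration; the semantics follow tss/coordinator.go above, where an empty peer.ID falls back to the static elector and a non-empty ID pins that peer as the session coordinator.

package example

import (
	"context"

	"github.com/libp2p/go-libp2p/core/peer"

	// Assumed module path, inferred from the imports in tss.go above.
	"github.com/sprintertech/sprinter-signing/tss"
)

// runProcess is a hypothetical helper showing both call styles. Passing an
// empty peer.ID keeps the pre-change behaviour (static election over the
// process's ValidCoordinators()); passing a concrete ID skips election and
// pins that peer as coordinator. Every party in a session must pass the
// same coordinator argument, as the new ManualCoordinator test does.
func runProcess(ctx context.Context, c *tss.Coordinator, p tss.TssProcess, manual peer.ID) error {
	resultChn := make(chan interface{}, 1)
	return c.Execute(ctx, []tss.TssProcess{p}, resultChn, manual)
}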