Skip to content

Commit 313d20c

Browse files
authored
fix concurrent writes/reads in keyring (#342)
The keyring has several data races between its public methods where we read from or iterate over the keyring without taking the mutex. A concurrent `AddKey`, `RemoveKey`, or `UseKey` call can write to the location being read. Move some of the data read functions into private methods that expect the lock to be held, and hold that lock in the public methods. Fixes: #341
1 parent 21a632a commit 313d20c

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

keyring.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ func (k *Keyring) AddKey(key []byte) error {
8383
return err
8484
}
8585

86+
k.l.Lock()
87+
defer k.l.Unlock()
88+
8689
// No-op if key is already installed
8790
for _, installedKey := range k.keys {
8891
if bytes.Equal(installedKey, key) {
@@ -91,20 +94,22 @@ func (k *Keyring) AddKey(key []byte) error {
9194
}
9295

9396
keys := append(k.keys, key)
94-
primaryKey := k.GetPrimaryKey()
97+
primaryKey := k.getPrimaryKeyLocked()
9598
if primaryKey == nil {
9699
primaryKey = key
97100
}
98-
k.installKeys(keys, primaryKey)
101+
k.installKeysLocked(keys, primaryKey)
99102
return nil
100103
}
101104

102105
// UseKey changes the key used to encrypt messages. This is the only key used to
103106
// encrypt messages, so peers should know this key before this method is called.
104107
func (k *Keyring) UseKey(key []byte) error {
108+
k.l.Lock()
109+
defer k.l.Unlock()
105110
for _, installedKey := range k.keys {
106111
if bytes.Equal(key, installedKey) {
107-
k.installKeys(k.keys, key)
112+
k.installKeysLocked(k.keys, key)
108113
return nil
109114
}
110115
}
@@ -114,25 +119,25 @@ func (k *Keyring) UseKey(key []byte) error {
114119
// RemoveKey drops a key from the keyring. This will return an error if the key
115120
// requested for removal is currently at position 0 (primary key).
116121
func (k *Keyring) RemoveKey(key []byte) error {
122+
k.l.Lock()
123+
defer k.l.Unlock()
124+
117125
if bytes.Equal(key, k.keys[0]) {
118126
return fmt.Errorf("removing the primary key is not allowed")
119127
}
120128
for i, installedKey := range k.keys {
121129
if bytes.Equal(key, installedKey) {
122130
keys := append(k.keys[:i], k.keys[i+1:]...)
123-
k.installKeys(keys, k.keys[0])
131+
k.installKeysLocked(keys, k.keys[0])
124132
}
125133
}
126134
return nil
127135
}
128136

129-
// installKeys will take out a lock on the keyring, and replace the keys with a
137+
// installKeysLocked will take out a lock on the keyring, and replace the keys with a
130138
// new set of keys. The key indicated by primaryKey will be installed as the new
131-
// primary key.
132-
func (k *Keyring) installKeys(keys [][]byte, primaryKey []byte) {
133-
k.l.Lock()
134-
defer k.l.Unlock()
135-
139+
// primary key. The caller must be holding the lock.
140+
func (k *Keyring) installKeysLocked(keys [][]byte, primaryKey []byte) {
136141
newKeys := [][]byte{primaryKey}
137142
for _, key := range keys {
138143
if !bytes.Equal(key, primaryKey) {
@@ -155,7 +160,11 @@ func (k *Keyring) GetKeys() [][]byte {
155160
func (k *Keyring) GetPrimaryKey() (key []byte) {
156161
k.l.Lock()
157162
defer k.l.Unlock()
163+
return k.getPrimaryKeyLocked()
164+
}
158165

166+
// getPrimaryKeyLocked must be called while holding the mutex
167+
func (k *Keyring) getPrimaryKeyLocked() (key []byte) {
159168
if len(k.keys) > 0 {
160169
key = k.keys[0]
161170
}

keyring_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package memberlist
55

66
import (
77
"bytes"
8+
"sync"
89
"testing"
910
)
1011

@@ -155,3 +156,30 @@ func TestKeyRing_MultiKeyEncryptDecrypt(t *testing.T) {
155156
t.Fatalf("Expected no keys to decrypt message")
156157
}
157158
}
159+
160+
func TestKeyring_AddConcurrentKeys(t *testing.T) {
161+
keyring, err := NewKeyring(nil, TestKeys[0])
162+
if err != nil {
163+
t.Fatalf("err :%s", err)
164+
}
165+
166+
var wg sync.WaitGroup
167+
errs := make(chan error)
168+
wg.Add(2)
169+
go func() {
170+
defer wg.Done()
171+
errs <- keyring.AddKey(TestKeys[1])
172+
}()
173+
go func() {
174+
defer wg.Done()
175+
errs <- keyring.AddKey(TestKeys[2])
176+
}()
177+
178+
if err := <-errs; err != nil {
179+
t.Fatal(err)
180+
}
181+
if err := <-errs; err != nil {
182+
t.Fatal(err)
183+
}
184+
wg.Wait()
185+
}

0 commit comments

Comments
 (0)