-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathkbucket_fix_test.go
125 lines (107 loc) · 3.25 KB
/
kbucket_fix_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package kbucketfix
import (
"context"
"fmt"
"log"
"testing"
"time"
"github.com/libp2p/go-libp2p-core/peer"
)
func TestPatcher1(t *testing.T) {
testPatcher(t, 0.5, 0, false)
}
func TestPatcher2(t *testing.T) {
testPatcher(t, 0.3, 0, true)
}
func TestPatcher3(t *testing.T) {
testPatcher(t, 0.8, 0, false)
}
func TestPatcher4(t *testing.T) {
testPatcher(t, 0.5, 10, false)
}
func TestPatcher5(t *testing.T) {
testPatcher(t, 0.5, 10, true)
}
func testPatcher(t *testing.T, targetProtectionRate float32, maxProtected int, heartbeat bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
host, hostDHT := makeHost(t, ctx)
patcher := NewPatcher()
patcher.ProtectionRate = targetProtectionRate
patcher.MaxProtected = maxProtected
patcher.Patch(hostDHT)
rt := hostDHT.RoutingTable()
if rt == nil {
t.Error("No routing table found")
}
connMgr := host.ConnManager()
added := 0
removed := 0
protected := 0
peerAdded := rt.PeerAdded
rt.PeerAdded = func(p peer.ID) {
peerAdded(p)
added += 1
if connMgr.IsProtected(p, "") {
protected += 1
}
}
peerRemoved := rt.PeerRemoved
rt.PeerRemoved = func(p peer.ID) {
peerRemoved(p)
removed += 1
}
host2, _ := makeHost(t, ctx)
connect(host, host2, ctx)
for i := 0; i < 3000; i += 1 {
peerHost, _ := makeHost(t, ctx)
// connect(host2, peerHost, ctx)
connect(host, peerHost, ctx)
if heartbeat {
patcher.Heartbeat(peerHost.ID())
}
}
hostDHT.RefreshRoutingTable()
log.Println("Done refreshing routing table")
time.Sleep(time.Second * 2)
rt.PeerAdded = func(p peer.ID) { log.Println("PeerAdded: " + p.String()) }
rt.PeerRemoved = func(p peer.ID) { log.Println("PeerRemoved: " + p.String()) }
// Ensure numbers of active peers matches between the patcher and connectiion manager
if added-removed != patcher.getProtectedLenThreadUnsafe()+patcher.getTaggedLenThreadUnsafe() {
t.Error(fmt.Sprintf("%d - %d != %d + %d", added, removed, patcher.getProtectedLenThreadUnsafe(), patcher.getTaggedLenThreadUnsafe()))
}
percentage := patcher.getProtectionRateThreadUnsafe()
if maxProtected > 0 {
// Ensure MaxProtected (peers) restriction works as expected
if patcher.getProtectedLenThreadUnsafe() > maxProtected || percentage > targetProtectionRate {
t.Error(fmt.Sprintf("%d - %f", patcher.getProtectedLenThreadUnsafe(), percentage))
}
} else {
// Ensure the actual peer protection rate approximately matches the target
// when MaxProtected (peers) is not specified
const BIAS float32 = .05
if percentage < targetProtectionRate-BIAS || percentage > targetProtectionRate+BIAS {
t.Error(percentage)
}
}
// Ensure all peers that are marked as protected in the patcher
// are actually protected in the connection manager
for _, m := range patcher.dist2protected {
for _, k := range m.Keys() {
pid := k.(peer.ID)
if !connMgr.IsProtected(pid, kbucketTag) {
t.Error(fmt.Sprintf("Peer %s should be protected", pid))
}
}
}
// Ensure all peers that are marked as tagged in the patcher
// are actually not protected in the connection manager
for _, m := range patcher.dist2tagged {
for _, k := range m.Keys() {
pid := k.(peer.ID)
if connMgr.IsProtected(pid, kbucketTag) {
t.Error(fmt.Sprintf("Peer %s should not be protected", pid))
}
}
}
}