diff --git a/controller/sharding/consistent/consistent.go b/controller/sharding/consistent/consistent.go index 3de3337846257..ba0d5f52fccc8 100644 --- a/controller/sharding/consistent/consistent.go +++ b/controller/sharding/consistent/consistent.go @@ -33,25 +33,17 @@ type Host struct { type Consistent struct { servers map[uint64]string - clients *btree.BTree + clients *btree.BTreeG[uint64] loadMap map[string]*Host totalLoad int64 replicationFactor int lock sync.RWMutex } -type item struct { - value uint64 -} - -func (i item) Less(than btree.Item) bool { - return i.value < than.(item).value -} - func New() *Consistent { return &Consistent{ servers: map[uint64]string{}, - clients: btree.New(2), + clients: btree.NewOrderedG[uint64](2), loadMap: map[string]*Host{}, replicationFactor: 1000, } @@ -60,7 +52,7 @@ func New() *Consistent { func NewWithReplicationFactor(replicationFactor int) *Consistent { return &Consistent{ servers: map[uint64]string{}, - clients: btree.New(2), + clients: btree.NewOrderedG[uint64](2), loadMap: map[string]*Host{}, replicationFactor: replicationFactor, } @@ -78,7 +70,7 @@ func (c *Consistent) Add(server string) { for i := 0; i < c.replicationFactor; i++ { h := c.hash(fmt.Sprintf("%s%d", server, i)) c.servers[h] = server - c.clients.ReplaceOrInsert(item{h}) + c.clients.ReplaceOrInsert(h) } } @@ -94,18 +86,21 @@ func (c *Consistent) Get(client string) (string, error) { } h := c.hash(client) - var foundItem btree.Item - c.clients.AscendGreaterOrEqual(item{h}, func(i btree.Item) bool { - foundItem = i + var foundKey uint64 + c.clients.AscendGreaterOrEqual(h, func(i uint64) bool { + foundKey = i return false // stop the iteration }) - if foundItem == nil { - // If no host found, wrap around to the first one. - foundItem = c.clients.Min() + if foundKey == 0 { + // If no key found, get the minimum key + c.clients.Ascend(func(i uint64) bool { + foundKey = i + return false // stop the iteration + }) } - host := c.servers[foundItem.(item).value] + host := c.servers[foundKey] return host, nil } @@ -122,30 +117,46 @@ func (c *Consistent) GetLeast(client string) (string, error) { return "", ErrNoHosts } h := c.hash(client) + start := h for { - var foundItem btree.Item - c.clients.AscendGreaterOrEqual(item{h}, func(bItem btree.Item) bool { - if h != bItem.(item).value { - foundItem = bItem + var foundKey uint64 + c.clients.AscendGreaterOrEqual(h, func(i uint64) bool { + if h != i { + foundKey = i return false // stop the iteration } return true }) - if foundItem == nil { - // If no host found, wrap around to the first one. - foundItem = c.clients.Min() + if foundKey == 0 { + // If no key found, get the minimum key + c.clients.Ascend(func(i uint64) bool { + foundKey = i + return false // stop the iteration + }) + } + + // Check if we have looped all the way around + if foundKey == start { + break } - key := c.clients.Get(foundItem) - if key == nil { - return client, nil + + host, exists := c.servers[foundKey] + if !exists { + return "", ErrNoHosts } - host := c.servers[key.(item).value] if c.loadOK(host) { return host, nil } - h = key.(item).value + // Start searching from the next point on the ring + h = foundKey + 1 } + // If no suitable host is found, return the first one or an error + host, exists := c.servers[start] + if !exists { + return "", ErrNoHosts + } + return host, nil } // Sets the load of `server` to the given `load` @@ -264,7 +275,7 @@ func (c *Consistent) loadOK(server string) bool { } func (c *Consistent) delSlice(val uint64) { - c.clients.Delete(item{val}) + c.clients.Delete(val) } func (c *Consistent) hash(key string) uint64 { diff --git a/controller/sharding/consistent/consistent_test.go b/controller/sharding/consistent/consistent_test.go new file mode 100644 index 0000000000000..ba479e9713e29 --- /dev/null +++ b/controller/sharding/consistent/consistent_test.go @@ -0,0 +1,177 @@ +package consistent + +import ( + "encoding/binary" + "fmt" + "sync" + "testing" + + "github.com/google/btree" + + blake2b "github.com/minio/blake2b-simd" +) + +const ( + testNumShards = 3 +) + +type OldConsistent struct { + servers map[uint64]string + clients *btree.BTree + loadMap map[string]*Host + replicationFactor int + lock sync.RWMutex +} + +type item struct { + value uint64 +} + +func (i item) Less(than btree.Item) bool { + return i.value < than.(item).value +} + +func NewOld() *OldConsistent { + return &OldConsistent{ + servers: map[uint64]string{}, + clients: btree.New(2), + loadMap: map[string]*Host{}, + replicationFactor: 1000, + } +} + +func (c *OldConsistent) Add(server string) { + c.lock.Lock() + defer c.lock.Unlock() + + if _, ok := c.loadMap[server]; ok { + return + } + + c.loadMap[server] = &Host{Name: server, Load: 0} + for i := 0; i < c.replicationFactor; i++ { + h := c.hash(fmt.Sprintf("%s%d", server, i)) + c.servers[h] = server + c.clients.ReplaceOrInsert(item{h}) + } +} + +func (c *OldConsistent) Get(client string) (string, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.clients.Len() == 0 { + return "", ErrNoHosts + } + + h := c.hash(client) + var foundItem btree.Item + c.clients.AscendGreaterOrEqual(item{h}, func(i btree.Item) bool { + foundItem = i + return false // stop the iteration + }) + + if foundItem == nil { + // If no host found, wrap around to the first one. + foundItem = c.clients.Min() + } + + host := c.servers[foundItem.(item).value] + + return host, nil +} + +func (c *OldConsistent) hash(key string) uint64 { + out := blake2b.Sum512([]byte(key)) + return binary.LittleEndian.Uint64(out[:]) +} + +func BenchmarkOldBTreeAdd(b *testing.B) { + for i := 0; i < b.N; i++ { + c := NewOld() + for j := 0; j < testNumShards; j++ { + c.Add(fmt.Sprintf("server%d", j)) + } + } +} + +func BenchmarkBTreeGAdd(b *testing.B) { + for i := 0; i < b.N; i++ { + c := New() + for j := 0; j < testNumShards; j++ { + c.Add(fmt.Sprintf("server%d", j)) + } + } +} + +func BenchmarkOldBTreeGet(b *testing.B) { + c := NewOld() + for j := 0; j < testNumShards; j++ { + c.Add(fmt.Sprintf("server%d", j)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Get(fmt.Sprintf("client%d", i)) + } +} + +func BenchmarkBTreeGGet(b *testing.B) { + c := New() + for j := 0; j < testNumShards; j++ { + c.Add(fmt.Sprintf("server%d", j)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Get(fmt.Sprintf("client%d", i)) + } +} + +func BenchmarkOldBTreeAddAndGet(b *testing.B) { + for i := 0; i < b.N; i++ { + c := NewOld() + for j := 0; j < testNumShards; j++ { + c.Add(fmt.Sprintf("server%d", j)) + } + for k := 0; k < 10; k++ { + _, _ = c.Get(fmt.Sprintf("client%d", k)) + } + } +} + +func BenchmarkBTreeGAddAndGet(b *testing.B) { + for i := 0; i < b.N; i++ { + c := New() + for j := 0; j < testNumShards; j++ { + c.Add(fmt.Sprintf("server%d", j)) + } + for k := 0; k < 10; k++ { + _, _ = c.Get(fmt.Sprintf("client%d", k)) + } + } +} + +func BenchmarkLargeOldBTreeAddAndGet(b *testing.B) { + for i := 0; i < b.N; i++ { + c := NewOld() + for j := 0; j < 100; j++ { + c.Add(fmt.Sprintf("server%03d", j)) + } + for k := 0; k < 1000; k++ { + _, _ = c.Get(fmt.Sprintf("client%04d", k)) + } + } +} + +func BenchmarkLargeBTreeGAddAndGet(b *testing.B) { + for i := 0; i < b.N; i++ { + c := New() + for j := 0; j < 100; j++ { + c.Add(fmt.Sprintf("server%03d", j)) + } + for k := 0; k < 1000; k++ { + _, _ = c.Get(fmt.Sprintf("client%04d", k)) + } + } +}