Skip to content

Commit a51009d

Browse files
authored
resolver: convert EndpointMap to use generics (#8189)
1 parent b0d1203 commit a51009d

File tree

10 files changed

+91
-114
lines changed

10 files changed

+91
-114
lines changed

balancer/endpointsharding/endpointsharding.go

+9-11
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions, childBuilde
7373
esOpts: esOpts,
7474
childBuilder: childBuilder,
7575
}
76-
es.children.Store(resolver.NewEndpointMap())
76+
es.children.Store(resolver.NewEndpointMap[*balancerWrapper]())
7777
return es
7878
}
7979

@@ -90,7 +90,7 @@ type endpointSharding struct {
9090
// calls into a child. To avoid deadlocks, do not acquire childMu while
9191
// holding mu.
9292
childMu sync.Mutex
93-
children atomic.Pointer[resolver.EndpointMap] // endpoint -> *balancerWrapper
93+
children atomic.Pointer[resolver.EndpointMap[*balancerWrapper]]
9494

9595
// inhibitChildUpdates is set during UpdateClientConnState/ResolverError
9696
// calls (calls to children will each produce an update, only want one
@@ -122,7 +122,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
122122
var ret error
123123

124124
children := es.children.Load()
125-
newChildren := resolver.NewEndpointMap()
125+
newChildren := resolver.NewEndpointMap[*balancerWrapper]()
126126

127127
// Update/Create new children.
128128
for _, endpoint := range state.ResolverState.Endpoints {
@@ -131,9 +131,8 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
131131
// update.
132132
continue
133133
}
134-
var childBalancer *balancerWrapper
135-
if val, ok := children.Get(endpoint); ok {
136-
childBalancer = val.(*balancerWrapper)
134+
childBalancer, ok := children.Get(endpoint)
135+
if ok {
137136
// Endpoint attributes may have changed, update the stored endpoint.
138137
es.mu.Lock()
139138
childBalancer.childState.Endpoint = endpoint
@@ -166,7 +165,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
166165
for _, e := range children.Keys() {
167166
child, _ := children.Get(e)
168167
if _, ok := newChildren.Get(e); !ok {
169-
child.(*balancerWrapper).closeLocked()
168+
child.closeLocked()
170169
}
171170
}
172171
es.children.Store(newChildren)
@@ -189,7 +188,7 @@ func (es *endpointSharding) ResolverError(err error) {
189188
}()
190189
children := es.children.Load()
191190
for _, child := range children.Values() {
192-
child.(*balancerWrapper).resolverErrorLocked(err)
191+
child.resolverErrorLocked(err)
193192
}
194193
}
195194

@@ -202,7 +201,7 @@ func (es *endpointSharding) Close() {
202201
defer es.childMu.Unlock()
203202
children := es.children.Load()
204203
for _, child := range children.Values() {
205-
child.(*balancerWrapper).closeLocked()
204+
child.closeLocked()
206205
}
207206
}
208207

@@ -222,8 +221,7 @@ func (es *endpointSharding) updateState() {
222221
childStates := make([]ChildState, 0, children.Len())
223222

224223
for _, child := range children.Values() {
225-
bw := child.(*balancerWrapper)
226-
childState := bw.childState
224+
childState := child.childState
227225
childStates = append(childStates, childState)
228226
childPicker := childState.State.Picker
229227
switch childState.State.ConnectivityState {

balancer/leastrequest/leastrequest.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (bb) Name() string {
8888
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
8989
b := &leastRequestBalancer{
9090
ClientConn: cc,
91-
endpointRPCCounts: resolver.NewEndpointMap(),
91+
endpointRPCCounts: resolver.NewEndpointMap[*atomic.Int32](),
9292
}
9393
b.child = endpointsharding.NewBalancer(b, bOpts, balancer.Get(pickfirstleaf.Name).Build, endpointsharding.Options{})
9494
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
@@ -110,7 +110,7 @@ type leastRequestBalancer struct {
110110
choiceCount uint32
111111
// endpointRPCCounts holds RPC counts to keep track for subsequent picker
112112
// updates.
113-
endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32
113+
endpointRPCCounts *resolver.EndpointMap[*atomic.Int32]
114114
}
115115

116116
func (lrb *leastRequestBalancer) Close() {
@@ -164,7 +164,7 @@ func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
164164
}
165165

166166
// Reconcile endpoints.
167-
newEndpoints := resolver.NewEndpointMap() // endpoint -> nil
167+
newEndpoints := resolver.NewEndpointMap[any]()
168168
for _, child := range readyEndpoints {
169169
newEndpoints.Set(child.Endpoint, nil)
170170
}
@@ -179,13 +179,11 @@ func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
179179
// Copy refs to counters into picker.
180180
endpointStates := make([]endpointState, 0, len(readyEndpoints))
181181
for _, child := range readyEndpoints {
182-
var counter *atomic.Int32
183-
if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok {
182+
counter, ok := lrb.endpointRPCCounts.Get(child.Endpoint)
183+
if !ok {
184184
// Create new counts if needed.
185185
counter = new(atomic.Int32)
186186
lrb.endpointRPCCounts.Set(child.Endpoint, counter)
187-
} else {
188-
counter = val.(*atomic.Int32)
189187
}
190188
endpointStates = append(endpointStates, endpointState{
191189
picker: child.State.Picker,

balancer/weightedroundrobin/balancer.go

+7-11
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Ba
105105
target: bOpts.Target.String(),
106106
metricsRecorder: cc.MetricsRecorder(),
107107
addressWeights: resolver.NewAddressMapV2[*endpointWeight](),
108-
endpointToWeight: resolver.NewEndpointMap(),
108+
endpointToWeight: resolver.NewEndpointMap[*endpointWeight](),
109109
scToWeight: make(map[balancer.SubConn]*endpointWeight),
110110
}
111111

@@ -155,17 +155,15 @@ func (bb) Name() string {
155155
//
156156
// Caller must hold b.mu.
157157
func (b *wrrBalancer) updateEndpointsLocked(endpoints []resolver.Endpoint) {
158-
endpointSet := resolver.NewEndpointMap()
158+
endpointSet := resolver.NewEndpointMap[*endpointWeight]()
159159
addressSet := resolver.NewAddressMapV2[*endpointWeight]()
160160
for _, endpoint := range endpoints {
161161
endpointSet.Set(endpoint, nil)
162162
for _, addr := range endpoint.Addresses {
163163
addressSet.Set(addr, nil)
164164
}
165-
var ew *endpointWeight
166-
if ewi, ok := b.endpointToWeight.Get(endpoint); ok {
167-
ew = ewi.(*endpointWeight)
168-
} else {
165+
ew, ok := b.endpointToWeight.Get(endpoint)
166+
if !ok {
169167
ew = &endpointWeight{
170168
logger: b.logger,
171169
connectivityState: connectivity.Connecting,
@@ -215,7 +213,7 @@ type wrrBalancer struct {
215213
locality string
216214
stopPicker *grpcsync.Event
217215
addressWeights *resolver.AddressMapV2[*endpointWeight]
218-
endpointToWeight *resolver.EndpointMap // endpoint -> endpointWeight
216+
endpointToWeight *resolver.EndpointMap[*endpointWeight]
219217
scToWeight map[balancer.SubConn]*endpointWeight
220218
}
221219

@@ -260,13 +258,12 @@ func (b *wrrBalancer) UpdateState(state balancer.State) {
260258

261259
for _, childState := range childStates {
262260
if childState.State.ConnectivityState == connectivity.Ready {
263-
ewv, ok := b.endpointToWeight.Get(childState.Endpoint)
261+
ew, ok := b.endpointToWeight.Get(childState.Endpoint)
264262
if !ok {
265263
// Should never happen, simply continue and ignore this endpoint
266264
// for READY pickers.
267265
continue
268266
}
269-
ew := ewv.(*endpointWeight)
270267
readyPickersWeight = append(readyPickersWeight, pickerWeightedEndpoint{
271268
picker: childState.State.Picker,
272269
weightedEndpoint: ew,
@@ -398,8 +395,7 @@ func (b *wrrBalancer) Close() {
398395
b.mu.Unlock()
399396

400397
// Ensure any lingering OOB watchers are stopped.
401-
for _, ewv := range b.endpointToWeight.Values() {
402-
ew := ewv.(*endpointWeight)
398+
for _, ew := range b.endpointToWeight.Values() {
403399
if ew.stopORCAListener != nil {
404400
ew.stopORCAListener()
405401
}

resolver/map.go

+16-16
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,21 @@ type endpointMapKey string
162162
// unordered set of address strings within an endpoint. This map is not thread
163163
// safe, thus it is unsafe to access concurrently. Must be created via
164164
// NewEndpointMap; do not construct directly.
165-
type EndpointMap struct {
166-
endpoints map[endpointMapKey]endpointData
165+
type EndpointMap[T any] struct {
166+
endpoints map[endpointMapKey]endpointData[T]
167167
}
168168

169-
type endpointData struct {
169+
type endpointData[T any] struct {
170170
// decodedKey stores the original key to avoid decoding when iterating on
171171
// EndpointMap keys.
172172
decodedKey Endpoint
173-
value any
173+
value T
174174
}
175175

176176
// NewEndpointMap creates a new EndpointMap.
177-
func NewEndpointMap() *EndpointMap {
178-
return &EndpointMap{
179-
endpoints: make(map[endpointMapKey]endpointData),
177+
func NewEndpointMap[T any]() *EndpointMap[T] {
178+
return &EndpointMap[T]{
179+
endpoints: make(map[endpointMapKey]endpointData[T]),
180180
}
181181
}
182182

@@ -196,25 +196,25 @@ func encodeEndpoint(e Endpoint) endpointMapKey {
196196
}
197197

198198
// Get returns the value for the address in the map, if present.
199-
func (em *EndpointMap) Get(e Endpoint) (value any, ok bool) {
199+
func (em *EndpointMap[T]) Get(e Endpoint) (value T, ok bool) {
200200
val, found := em.endpoints[encodeEndpoint(e)]
201201
if found {
202202
return val.value, true
203203
}
204-
return nil, false
204+
return value, false
205205
}
206206

207207
// Set updates or adds the value to the address in the map.
208-
func (em *EndpointMap) Set(e Endpoint, value any) {
208+
func (em *EndpointMap[T]) Set(e Endpoint, value T) {
209209
en := encodeEndpoint(e)
210-
em.endpoints[en] = endpointData{
210+
em.endpoints[en] = endpointData[T]{
211211
decodedKey: Endpoint{Addresses: e.Addresses},
212212
value: value,
213213
}
214214
}
215215

216216
// Len returns the number of entries in the map.
217-
func (em *EndpointMap) Len() int {
217+
func (em *EndpointMap[T]) Len() int {
218218
return len(em.endpoints)
219219
}
220220

@@ -223,7 +223,7 @@ func (em *EndpointMap) Len() int {
223223
// the unordered set of addresses. Thus, endpoint information returned is not
224224
// the full endpoint data (drops duplicated addresses and attributes) but can be
225225
// used for EndpointMap accesses.
226-
func (em *EndpointMap) Keys() []Endpoint {
226+
func (em *EndpointMap[T]) Keys() []Endpoint {
227227
ret := make([]Endpoint, 0, len(em.endpoints))
228228
for _, en := range em.endpoints {
229229
ret = append(ret, en.decodedKey)
@@ -232,16 +232,16 @@ func (em *EndpointMap) Keys() []Endpoint {
232232
}
233233

234234
// Values returns a slice of all current map values.
235-
func (em *EndpointMap) Values() []any {
236-
ret := make([]any, 0, len(em.endpoints))
235+
func (em *EndpointMap[T]) Values() []T {
236+
ret := make([]T, 0, len(em.endpoints))
237237
for _, val := range em.endpoints {
238238
ret = append(ret, val.value)
239239
}
240240
return ret
241241
}
242242

243243
// Delete removes the specified endpoint from the map.
244-
func (em *EndpointMap) Delete(e Endpoint) {
244+
func (em *EndpointMap[T]) Delete(e Endpoint) {
245245
en := encodeEndpoint(e)
246246
delete(em.endpoints, en)
247247
}

0 commit comments

Comments
 (0)