Skip to content

Commit 553a40b

Browse files
committed
Better concurrency handling
1 parent dcf8f19 commit 553a40b

5 files changed

Lines changed: 53 additions & 39 deletions

File tree

cache.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (f *fronted) prepopulateFronts(cacheFile string) {
4242
now := time.Now()
4343

4444
// update last succeeded status of masquerades based on cached values
45-
for _, fr := range f.fronts {
45+
for _, fr := range f.fronts.fronts {
4646
for _, cf := range cachedFronts {
4747
sameFront := cf.ProviderID == fr.getProviderID() && cf.Domain == fr.getDomain() && cf.IpAddress == fr.getIpAddress()
4848
cachedValueFresh := now.Sub(fr.lastSucceeded()) < f.maxAllowedCachedAge

cache_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestCaching(t *testing.T) {
2929
log.Debug("Creating fronted")
3030
makeFronted := func() *fronted {
3131
f := &fronted{
32-
fronts: make(sortedFronts, 0, 1000),
32+
fronts: newThreadSafeFronts(1000),
3333
maxAllowedCachedAge: 250 * time.Millisecond,
3434
maxCacheSize: 4,
3535
cacheSaveInterval: 50 * time.Millisecond,
@@ -51,7 +51,7 @@ func TestCaching(t *testing.T) {
5151
f := makeFronted()
5252

5353
log.Debug("Adding fronts")
54-
f.fronts = append(f.fronts, mb, mc, md)
54+
f.fronts.fronts = append(f.fronts.fronts, mb, mc, md)
5555

5656
readCached := func() []*front {
5757
log.Debug("Reading cached fronts")

front.go

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,45 @@ func NewStatusCodeValidator(reject []int) ResponseValidator {
360360
}
361361
}
362362

363+
type threadSafeFronts struct {
364+
fronts sortedFronts
365+
mx sync.RWMutex
366+
}
367+
368+
func newThreadSafeFronts(size int) *threadSafeFronts {
369+
return &threadSafeFronts{
370+
fronts: make(sortedFronts, 0, size),
371+
mx: sync.RWMutex{},
372+
}
373+
}
374+
375+
func (tsf *threadSafeFronts) sortedCopy() sortedFronts {
376+
tsf.mx.RLock()
377+
defer tsf.mx.RUnlock()
378+
c := make(sortedFronts, len(tsf.fronts))
379+
copy(c, tsf.fronts)
380+
sort.Sort(c)
381+
return c
382+
}
383+
384+
func (tsf *threadSafeFronts) addFronts(newFronts ...Front) {
385+
tsf.mx.Lock()
386+
defer tsf.mx.Unlock()
387+
tsf.fronts = append(tsf.fronts, newFronts...)
388+
}
389+
390+
func (tsf *threadSafeFronts) frontSize() int {
391+
tsf.mx.RLock()
392+
defer tsf.mx.RUnlock()
393+
return len(tsf.fronts)
394+
}
395+
396+
func (tsf *threadSafeFronts) frontAt(i int) Front {
397+
tsf.mx.RLock()
398+
defer tsf.mx.RUnlock()
399+
return tsf.fronts[i]
400+
}
401+
363402
// slice of masquerade sorted by last vetted time
364403
type sortedFronts []Front
365404

@@ -375,13 +414,6 @@ func (m sortedFronts) Less(i, j int) bool {
375414
}
376415
}
377416

378-
func (m sortedFronts) sortedCopy() sortedFronts {
379-
c := make(sortedFronts, len(m))
380-
copy(c, m)
381-
sort.Sort(c)
382-
return c
383-
}
384-
385417
func (fr *front) markCacheDirty() {
386418
select {
387419
case fr.cacheDirty <- nil:

fronted.go

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ var (
4949
// an implementation of http.RoundTripper for the convenience of callers.
5050
type fronted struct {
5151
certPool atomic.Value
52-
fronts sortedFronts
52+
fronts *threadSafeFronts
5353
maxAllowedCachedAge time.Duration
5454
maxCacheSize int
5555
cacheFile string
@@ -102,7 +102,7 @@ func NewFronted(options ...Option) Fronted {
102102

103103
f := &fronted{
104104
certPool: atomic.Value{},
105-
fronts: make(sortedFronts, 0),
105+
fronts: newThreadSafeFronts(0),
106106
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
107107
maxCacheSize: defaultMaxCacheSize,
108108
cacheSaveInterval: defaultCacheSaveInterval,
@@ -260,7 +260,7 @@ func (f *fronted) onNewFronts(pool *x509.CertPool, providers map[string]*Provide
260260
}
261261
providersCopy := copyProviders(providers, f.countryCode)
262262
f.addProviders(providersCopy)
263-
f.addFronts(loadFronts(providersCopy, f.cacheDirty))
263+
f.fronts.addFronts(loadFronts(providersCopy, f.cacheDirty)...)
264264
f.certPool.Store(pool)
265265

266266
// The goroutine for finding working fronts runs forever, so only start it once.
@@ -320,8 +320,8 @@ func (f *fronted) tryAllFronts() {
320320
pool := pond.NewPool(40)
321321

322322
// Submit all fronts to the worker pool.
323-
for i := range f.frontSize() {
324-
m := f.frontAt(i)
323+
for i := range f.fronts.frontSize() {
324+
m := f.fronts.frontAt(i)
325325
pool.Submit(func() {
326326
if f.isStopped() {
327327
return
@@ -348,18 +348,6 @@ func (f *fronted) hasEnoughWorkingFronts() bool {
348348
return len(f.frontsCh) >= 4
349349
}
350350

351-
func (f *fronted) frontSize() int {
352-
f.frontsMu.RLock()
353-
defer f.frontsMu.RUnlock()
354-
return len(f.fronts)
355-
}
356-
357-
func (f *fronted) frontAt(i int) Front {
358-
f.frontsMu.RLock()
359-
defer f.frontsMu.RUnlock()
360-
return f.fronts[i]
361-
}
362-
363351
func (f *fronted) vetFront(fr Front) bool {
364352
conn, err := f.dialFront(fr)
365353
if err != nil {
@@ -571,7 +559,7 @@ func copyProviders(providers map[string]*Provider, countryCode string) map[strin
571559
return providersCopy
572560
}
573561

574-
func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) sortedFronts {
562+
func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) []Front {
575563
log.Debugf("Loading candidates for %d providers", len(providers))
576564
defer log.Debug("Finished loading candidates")
577565

@@ -581,7 +569,7 @@ func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) sor
581569
size += len(p.Masquerades)
582570
}
583571

584-
fronts := make(sortedFronts, size)
572+
fronts := make([]Front, size)
585573

586574
// Note that map iteration order is random, so the order of the providers is automatically randomized.
587575
index := 0
@@ -616,13 +604,6 @@ func (f *fronted) addProviders(providers map[string]*Provider) {
616604
}
617605
}
618606

619-
func (f *fronted) addFronts(fronts sortedFronts) {
620-
// Add new masquerades to the existing masquerades slice, but add them at the beginning.
621-
f.frontsMu.Lock()
622-
defer f.frontsMu.Unlock()
623-
f.fronts = append(fronts, f.fronts.sortedCopy()...)
624-
}
625-
626607
func (f *fronted) providerFor(m Front) *Provider {
627608
pid := m.getProviderID()
628609
if pid == "" {

fronted_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE
155155
// Check the number of masquerades at the end, waiting until we get the right number
156156
masqueradesAtEnd := 0
157157
for i := 0; i < 1000; i++ {
158-
masqueradesAtEnd = len(d.fronts)
158+
masqueradesAtEnd = len(d.fronts.fronts)
159159
if masqueradesAtEnd == expectedMasqueradesAtEnd {
160160
break
161161
}
@@ -761,9 +761,10 @@ func TestFindWorkingMasquerades(t *testing.T) {
761761
}
762762
f.providers = make(map[string]*Provider)
763763
f.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, "")
764-
f.fronts = make(sortedFronts, len(tt.masquerades))
764+
//f.fronts = make(sortedFronts, len(tt.masquerades))
765+
f.fronts = newThreadSafeFronts(len(tt.masquerades))
765766
for i, m := range tt.masquerades {
766-
f.fronts[i] = m
767+
f.fronts.fronts[i] = m
767768
}
768769

769770
f.tryAllFronts()

0 commit comments

Comments
 (0)