Skip to content

Commit 21c2256

Browse files
cstocktonChris Stockton
and
Chris Stockton
authoredFeb 4, 2025··
feat: improvements to config reloader, 100% coverage (#1933)
Increased test coverage of reloader to 100%. --------- Co-authored-by: Chris Stockton <[email protected]>
1 parent fbbebcc commit 21c2256

File tree

5 files changed

+456
-45
lines changed

5 files changed

+456
-45
lines changed
 

‎hack/coverage.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
FAIL=false
22

3-
for PKG in "crypto"
3+
for PKG in "crypto" "reloader"
44
do
55
UNCOVERED_FUNCS=$(go tool cover -func=coverage.out | grep "^github.com/supabase/auth/internal/$PKG/" | grep -v '100.0%$')
66
UNCOVERED_FUNCS_COUNT=$(echo "$UNCOVERED_FUNCS" | wc -l)

‎internal/reloader/handler_race_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ func TestAtomicHandlerRaces(t *testing.T) {
5050

5151
hr.Store(hrFunc)
5252

53+
// Calling string should be safe
54+
hr.String()
55+
5356
got := hr.load()
5457
_, ok := hrFuncMap[got]
5558
if !ok {

‎internal/reloader/handler_test.go

+15-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package reloader
22

33
import (
44
"net/http"
5+
"sync/atomic"
56
"testing"
67

78
"github.com/stretchr/testify/assert"
@@ -11,23 +12,31 @@ func TestAtomicHandler(t *testing.T) {
1112
// for ptr identity
1213
type testHandler struct{ http.Handler }
1314

15+
var calls atomic.Int64
1416
hrFn := func() http.Handler {
15-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
17+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18+
calls.Add(1)
19+
})
1620
}
1721

1822
hrFunc1 := &testHandler{hrFn()}
1923
hrFunc2 := &testHandler{hrFn()}
2024
assert.NotEqual(t, hrFunc1, hrFunc2)
2125

2226
// a new AtomicHandler should be non-nil
23-
hr := NewAtomicHandler(nil)
27+
hr := NewAtomicHandler(hrFunc1)
2428
assert.NotNil(t, hr)
29+
assert.Equal(t, "reloader.AtomicHandler", hr.String())
2530

26-
// should have no stored handler
31+
// should implement http.Handler
2732
{
28-
hrCur := hr.load()
29-
assert.Nil(t, hrCur)
30-
assert.Equal(t, true, hrCur == nil)
33+
v := (http.Handler)(hr)
34+
before := calls.Load()
35+
v.ServeHTTP(nil, nil)
36+
after := calls.Load()
37+
if exp, got := before+1, after; exp != got {
38+
t.Fatalf("exp %v to be %v after handler was called", got, exp)
39+
}
3140
}
3241

3342
// should be non-nil after store

‎internal/reloader/reloader.go

+104-22
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ package reloader
33

44
import (
55
"context"
6-
"log"
6+
"errors"
77
"strings"
8+
"sync"
89
"time"
910

1011
"github.com/fsnotify/fsnotify"
@@ -27,28 +28,24 @@ type Reloader struct {
2728
watchDir string
2829
reloadIval time.Duration
2930
tickerIval time.Duration
31+
watchFn func() (watcher, error)
32+
reloadFn func(dir string) (*conf.GlobalConfiguration, error)
3033
}
3134

3235
func NewReloader(watchDir string) *Reloader {
3336
return &Reloader{
3437
watchDir: watchDir,
3538
reloadIval: reloadInterval,
3639
tickerIval: tickerInterval,
40+
watchFn: newFSWatcher,
41+
reloadFn: defaultReloadFn,
3742
}
3843
}
3944

4045
// reload attempts to create a new *conf.GlobalConfiguration after loading the
4146
// currently configured watchDir.
4247
func (rl *Reloader) reload() (*conf.GlobalConfiguration, error) {
43-
if err := conf.LoadDirectory(rl.watchDir); err != nil {
44-
return nil, err
45-
}
46-
47-
cfg, err := conf.LoadGlobalFromEnv()
48-
if err != nil {
49-
return nil, err
50-
}
51-
return cfg, nil
48+
return rl.reloadFn(rl.watchDir)
5249
}
5350

5451
// reloadCheckAt checks if reloadConfig should be called, returns true if config
@@ -66,9 +63,10 @@ func (rl *Reloader) reloadCheckAt(at, lastUpdate time.Time) bool {
6663
}
6764

6865
func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
69-
wr, err := fsnotify.NewWatcher()
66+
wr, err := rl.watchFn()
7067
if err != nil {
71-
log.Fatal(err)
68+
logrus.WithError(err).Error("reloader: error creating fsnotify Watcher")
69+
return err
7270
}
7371
defer wr.Close()
7472

@@ -77,7 +75,7 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
7775

7876
// Ignore errors, if watch dir doesn't exist we can add it later.
7977
if err := wr.Add(rl.watchDir); err != nil {
80-
logrus.WithError(err).Error("watch dir failed")
78+
logrus.WithError(err).Error("reloader: error watching config directory")
8179
}
8280

8381
var lastUpdate time.Time
@@ -92,7 +90,7 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
9290
// scenarios and wr.WatchList() does not grow which aligns with
9391
// the documented behavior.
9492
if err := wr.Add(rl.watchDir); err != nil {
95-
logrus.WithError(err).Error("watch dir failed")
93+
logrus.WithError(err).Error(err)
9694
}
9795

9896
// Check to see if the config is ready to be relaoded.
@@ -105,17 +103,18 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
105103

106104
cfg, err := rl.reload()
107105
if err != nil {
108-
logrus.WithError(err).Error("config reload failed")
106+
logrus.WithError(err).Error("reloader: error loading config")
109107
continue
110108
}
111109

112110
// Call the callback function with the latest cfg.
113111
fn(cfg)
114112

115-
case evt, ok := <-wr.Events:
113+
case evt, ok := <-wr.Events():
116114
if !ok {
117-
logrus.WithError(err).Error("fsnotify has exited")
118-
return nil
115+
err := errors.New("reloader: fsnotify event channel was closed")
116+
logrus.WithError(err).Error(err)
117+
return err
119118
}
120119

121120
// We only read files ending in .env
@@ -130,12 +129,95 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
130129
evt.Op.Has(fsnotify.Write):
131130
lastUpdate = time.Now()
132131
}
133-
case err, ok := <-wr.Errors:
132+
case err, ok := <-wr.Errors():
134133
if !ok {
135-
logrus.Error("fsnotify has exited")
136-
return nil
134+
err := errors.New("reloader: fsnotify error channel was closed")
135+
logrus.WithError(err).Error(err)
136+
return err
137137
}
138-
logrus.WithError(err).Error("fsnotify has reported an error")
138+
logrus.WithError(err).Error(
139+
"reloader: fsnotify has reported an error")
139140
}
140141
}
141142
}
143+
144+
func defaultReloadFn(dir string) (*conf.GlobalConfiguration, error) {
145+
if err := conf.LoadDirectory(dir); err != nil {
146+
return nil, err
147+
}
148+
149+
cfg, err := conf.LoadGlobalFromEnv()
150+
if err != nil {
151+
return nil, err
152+
}
153+
return cfg, nil
154+
}
155+
156+
type watcher interface {
157+
Add(path string) error
158+
Close() error
159+
Events() chan fsnotify.Event
160+
Errors() chan error
161+
}
162+
163+
type fsNotifyWatcher struct {
164+
wr *fsnotify.Watcher
165+
}
166+
167+
func newFSWatcher() (watcher, error) {
168+
wr, err := fsnotify.NewWatcher()
169+
return &fsNotifyWatcher{wr}, err
170+
}
171+
172+
func (o *fsNotifyWatcher) Add(path string) error { return o.wr.Add(path) }
173+
func (o *fsNotifyWatcher) Close() error { return o.wr.Close() }
174+
func (o *fsNotifyWatcher) Errors() chan error { return o.wr.Errors }
175+
func (o *fsNotifyWatcher) Events() chan fsnotify.Event { return o.wr.Events }
176+
177+
type mockWatcher struct {
178+
mu sync.Mutex
179+
err error
180+
eventCh chan fsnotify.Event
181+
errorCh chan error
182+
addCh chan string
183+
}
184+
185+
func newMockWatcher(err error) *mockWatcher {
186+
wr := &mockWatcher{
187+
err: err,
188+
eventCh: make(chan fsnotify.Event, 1024),
189+
errorCh: make(chan error, 1024),
190+
addCh: make(chan string, 1024),
191+
}
192+
return wr
193+
}
194+
195+
func (o *mockWatcher) getErr() error {
196+
o.mu.Lock()
197+
defer o.mu.Unlock()
198+
err := o.err
199+
return err
200+
}
201+
202+
func (o *mockWatcher) setErr(err error) {
203+
o.mu.Lock()
204+
defer o.mu.Unlock()
205+
o.err = err
206+
}
207+
208+
func (o *mockWatcher) Add(path string) error {
209+
o.mu.Lock()
210+
defer o.mu.Unlock()
211+
if err := o.err; err != nil {
212+
return err
213+
}
214+
215+
select {
216+
case o.addCh <- path:
217+
default:
218+
}
219+
return nil
220+
}
221+
func (o *mockWatcher) Close() error { return o.getErr() }
222+
func (o *mockWatcher) Events() chan fsnotify.Event { return o.eventCh }
223+
func (o *mockWatcher) Errors() chan error { return o.errorCh }

‎internal/reloader/reloader_test.go

+333-16
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,279 @@ package reloader
22

33
import (
44
"bytes"
5-
"log"
5+
"context"
6+
"errors"
67
"os"
8+
"path"
79
"path/filepath"
810
"testing"
911
"time"
1012

13+
"github.com/fsnotify/fsnotify"
1114
"github.com/stretchr/testify/assert"
15+
"github.com/supabase/auth/internal/conf"
16+
"golang.org/x/sync/errgroup"
1217
)
1318

19+
func TestWatch(t *testing.T) {
20+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
21+
defer cancel()
22+
23+
dir, cleanup := helpTestDir(t)
24+
defer cleanup()
25+
26+
// test broken watcher
27+
{
28+
sentinelErr := errors.New("sentinel")
29+
rr := mockReloadRecorder()
30+
rl := NewReloader(dir)
31+
rl.watchFn = func() (watcher, error) { return nil, sentinelErr }
32+
33+
err := rl.Watch(ctx, rr.configFn)
34+
if exp, got := sentinelErr, err; exp != got {
35+
assert.Equal(t, exp, got)
36+
}
37+
}
38+
39+
// test watch invalid dir
40+
{
41+
doneCtx, doneCancel := context.WithCancel(ctx)
42+
doneCancel()
43+
44+
rr := mockReloadRecorder()
45+
rl := NewReloader(path.Join(dir, "__not_found__"))
46+
err := rl.Watch(doneCtx, rr.configFn)
47+
if exp, got := context.Canceled, err; exp != got {
48+
assert.Equal(t, exp, got)
49+
}
50+
}
51+
52+
// test watch error chan closed
53+
{
54+
rr := mockReloadRecorder()
55+
wr := newMockWatcher(nil)
56+
wr.errorCh <- errors.New("sentinel")
57+
close(wr.errorCh)
58+
59+
rl := NewReloader(dir)
60+
rl.watchFn = func() (watcher, error) { return wr, nil }
61+
62+
err := rl.Watch(ctx, rr.configFn)
63+
assert.NotNil(t, err)
64+
65+
msg := "reloader: fsnotify error channel was closed"
66+
if exp, got := msg, err.Error(); exp != got {
67+
assert.Equal(t, exp, got)
68+
}
69+
}
70+
71+
// test watch event chan closed
72+
{
73+
rr := mockReloadRecorder()
74+
wr := newMockWatcher(nil)
75+
close(wr.eventCh)
76+
77+
rl := NewReloader(dir)
78+
rl.reloadIval = time.Second / 100
79+
rl.watchFn = func() (watcher, error) { return wr, nil }
80+
81+
err := rl.Watch(ctx, rr.configFn)
82+
if err == nil {
83+
assert.NotNil(t, err)
84+
}
85+
86+
msg := "reloader: fsnotify event channel was closed"
87+
if exp, got := msg, err.Error(); exp != got {
88+
assert.Equal(t, exp, got)
89+
}
90+
}
91+
92+
// test watch error chan
93+
{
94+
rr := mockReloadRecorder()
95+
wr := newMockWatcher(nil)
96+
wr.errorCh <- errors.New("sentinel")
97+
98+
rl := NewReloader(dir)
99+
rl.watchFn = func() (watcher, error) { return wr, nil }
100+
101+
egCtx, egCancel := context.WithCancel(ctx)
102+
defer egCancel()
103+
104+
var eg errgroup.Group
105+
eg.Go(func() error {
106+
return rl.Watch(egCtx, rr.configFn)
107+
})
108+
109+
// need to ensure errorCh drains so test isn't racey
110+
eg.Go(func() error {
111+
defer egCancel()
112+
113+
tr := time.NewTicker(time.Second / 100)
114+
defer tr.Stop()
115+
116+
for {
117+
select {
118+
case <-egCtx.Done():
119+
return egCtx.Err()
120+
case <-tr.C:
121+
if len(wr.errorCh) == 0 {
122+
return nil
123+
}
124+
}
125+
}
126+
})
127+
128+
err := eg.Wait()
129+
if exp, got := context.Canceled, err; exp != got {
130+
assert.Equal(t, exp, got)
131+
}
132+
}
133+
134+
// test an end to end config reload
135+
{
136+
rr := mockReloadRecorder()
137+
wr := newMockWatcher(nil)
138+
rl := NewReloader(dir)
139+
rl.watchFn = func() (watcher, error) { return wr, wr.getErr() }
140+
rl.reloadFn = rr.reloadFn
141+
142+
// Need to lower reload ival to pickup config write quicker.
143+
rl.reloadIval = time.Second / 10
144+
rl.tickerIval = rl.reloadIval / 10
145+
146+
egCtx, egCancel := context.WithCancel(ctx)
147+
defer egCancel()
148+
149+
var eg errgroup.Group
150+
eg.Go(func() error {
151+
return rl.Watch(egCtx, rr.configFn)
152+
})
153+
154+
// Copy a full and valid example configuration to trigger Watch
155+
{
156+
select {
157+
case <-egCtx.Done():
158+
assert.Nil(t, egCtx.Err())
159+
case v := <-wr.addCh:
160+
assert.Equal(t, v, dir)
161+
}
162+
163+
name := helpCopyEnvFile(t, dir, "01_example.env", "testdata/50_example.env")
164+
wr.eventCh <- fsnotify.Event{
165+
Name: name,
166+
Op: fsnotify.Create,
167+
}
168+
select {
169+
case <-egCtx.Done():
170+
assert.Nil(t, egCtx.Err())
171+
case cfg := <-rr.configCh:
172+
assert.NotNil(t, cfg)
173+
assert.Equal(t, cfg.External.Apple.Enabled, false)
174+
}
175+
}
176+
177+
{
178+
drain(rr.configCh)
179+
drain(rr.reloadCh)
180+
181+
name := helpWriteEnvFile(t, dir, "02_example.env", map[string]string{
182+
"GOTRUE_EXTERNAL_APPLE_ENABLED": "true",
183+
})
184+
wr.eventCh <- fsnotify.Event{
185+
Name: name,
186+
Op: fsnotify.Create,
187+
}
188+
select {
189+
case <-egCtx.Done():
190+
assert.Nil(t, egCtx.Err())
191+
case cfg := <-rr.configCh:
192+
assert.NotNil(t, cfg)
193+
assert.Equal(t, cfg.External.Apple.Enabled, true)
194+
}
195+
}
196+
197+
{
198+
name := helpWriteEnvFile(t, dir, "03_example.env.bak", map[string]string{
199+
"GOTRUE_EXTERNAL_APPLE_ENABLED": "false",
200+
})
201+
wr.eventCh <- fsnotify.Event{
202+
Name: name,
203+
Op: fsnotify.Create,
204+
}
205+
}
206+
207+
{
208+
// empty the reload ch
209+
drain(rr.reloadCh)
210+
211+
name := helpWriteEnvFile(t, dir, "04_example.env", map[string]string{
212+
"GOTRUE_SMTP_PORT": "ABC",
213+
})
214+
wr.eventCh <- fsnotify.Event{
215+
Name: name,
216+
Op: fsnotify.Create,
217+
}
218+
219+
select {
220+
case <-egCtx.Done():
221+
assert.Nil(t, egCtx.Err())
222+
case p := <-rr.reloadCh:
223+
if exp, got := dir, p; exp != got {
224+
assert.Equal(t, exp, got)
225+
}
226+
}
227+
}
228+
229+
{
230+
name := helpWriteEnvFile(t, dir, "05_example.env", map[string]string{
231+
"GOTRUE_SMTP_PORT": "2222",
232+
})
233+
wr.eventCh <- fsnotify.Event{
234+
Name: name,
235+
Op: fsnotify.Create,
236+
}
237+
select {
238+
case <-egCtx.Done():
239+
assert.Nil(t, egCtx.Err())
240+
case cfg := <-rr.configCh:
241+
assert.NotNil(t, cfg)
242+
assert.Equal(t, cfg.SMTP.Port, 2222)
243+
}
244+
}
245+
246+
// test the wr.Add doesn't exit if bad watch dir is given during tick
247+
{
248+
// set the error on watcher
249+
sentinelErr := errors.New("sentinel")
250+
wr.setErr(sentinelErr)
251+
252+
name := helpWriteEnvFile(t, dir, "05_example.env", map[string]string{
253+
"GOTRUE_SMTP_PORT": "2222",
254+
})
255+
wr.eventCh <- fsnotify.Event{
256+
Name: name,
257+
Op: fsnotify.Create,
258+
}
259+
select {
260+
case <-egCtx.Done():
261+
assert.Nil(t, egCtx.Err())
262+
case cfg := <-rr.configCh:
263+
assert.NotNil(t, cfg)
264+
assert.Equal(t, cfg.SMTP.Port, 2222)
265+
}
266+
}
267+
268+
// test cases ran, end context to unblock Wait()
269+
egCancel()
270+
271+
err := eg.Wait()
272+
if exp, got := context.Canceled, err; exp != got {
273+
assert.Equal(t, exp, got)
274+
}
275+
}
276+
}
277+
14278
func TestReloadConfig(t *testing.T) {
15279
dir, cleanup := helpTestDir(t)
16280
defer cleanup()
@@ -21,9 +285,7 @@ func TestReloadConfig(t *testing.T) {
21285
helpCopyEnvFile(t, dir, "01_example.env", "testdata/50_example.env")
22286
{
23287
cfg, err := rl.reload()
24-
if err != nil {
25-
t.Fatal(err)
26-
}
288+
assert.Nil(t, err)
27289
assert.NotNil(t, cfg)
28290
assert.Equal(t, cfg.External.Apple.Enabled, false)
29291
}
@@ -33,9 +295,7 @@ func TestReloadConfig(t *testing.T) {
33295
})
34296
{
35297
cfg, err := rl.reload()
36-
if err != nil {
37-
t.Fatal(err)
38-
}
298+
assert.Nil(t, err)
39299
assert.NotNil(t, cfg)
40300
assert.Equal(t, cfg.External.Apple.Enabled, true)
41301
}
@@ -45,12 +305,30 @@ func TestReloadConfig(t *testing.T) {
45305
})
46306
{
47307
cfg, err := rl.reload()
48-
if err != nil {
49-
t.Fatal(err)
50-
}
308+
assert.Nil(t, err)
51309
assert.NotNil(t, cfg)
52310
assert.Equal(t, cfg.External.Apple.Enabled, true)
53311
}
312+
313+
// test cfg reload failure
314+
helpWriteEnvFile(t, dir, "04_example.env", map[string]string{
315+
"PORT": "INVALIDPORT",
316+
"GOTRUE_SMTP_PORT": "ABC",
317+
})
318+
{
319+
cfg, err := rl.reload()
320+
assert.NotNil(t, err)
321+
assert.Nil(t, cfg)
322+
}
323+
324+
// test directory loading failure
325+
{
326+
cleanup()
327+
328+
cfg, err := rl.reload()
329+
assert.NotNil(t, err)
330+
assert.Nil(t, cfg)
331+
}
54332
}
55333

56334
func TestReloadCheckAt(t *testing.T) {
@@ -136,21 +414,21 @@ func helpTestDir(t testing.TB) (dir string, cleanup func()) {
136414
dir = filepath.Join("testdata", t.Name())
137415
err := os.MkdirAll(dir, 0750)
138416
if err != nil && !os.IsExist(err) {
139-
t.Fatal(err)
417+
assert.Nil(t, err)
140418
}
141419
return dir, func() { os.RemoveAll(dir) }
142420
}
143421

144422
func helpCopyEnvFile(t testing.TB, dir, name, src string) string {
145423
data, err := os.ReadFile(src) // #nosec G304
146424
if err != nil {
147-
log.Fatal(err)
425+
assert.Nil(t, err)
148426
}
149427

150428
dst := filepath.Join(dir, name)
151429
err = os.WriteFile(dst, data, 0600)
152430
if err != nil {
153-
t.Fatal(err)
431+
assert.Nil(t, err)
154432
}
155433
return dst
156434
}
@@ -166,8 +444,47 @@ func helpWriteEnvFile(t testing.TB, dir, name string, values map[string]string)
166444

167445
dst := filepath.Join(dir, name)
168446
err := os.WriteFile(dst, buf.Bytes(), 0600)
169-
if err != nil {
170-
t.Fatal(err)
171-
}
447+
assert.Nil(t, err)
172448
return dst
173449
}
450+
451+
func mockReloadRecorder() *reloadRecorder {
452+
rr := &reloadRecorder{
453+
configCh: make(chan *conf.GlobalConfiguration, 1024),
454+
reloadCh: make(chan string, 1024),
455+
}
456+
return rr
457+
}
458+
459+
func drain[C ~chan T, T any](ch C) (out []T) {
460+
for {
461+
select {
462+
case v := <-ch:
463+
out = append(out, v)
464+
default:
465+
return out
466+
}
467+
}
468+
}
469+
470+
type reloadRecorder struct {
471+
configCh chan *conf.GlobalConfiguration
472+
reloadCh chan string
473+
}
474+
475+
func (o *reloadRecorder) reloadFn(dir string) (*conf.GlobalConfiguration, error) {
476+
defer func() {
477+
select {
478+
case o.reloadCh <- dir:
479+
default:
480+
}
481+
}()
482+
return defaultReloadFn(dir)
483+
}
484+
485+
func (o *reloadRecorder) configFn(gc *conf.GlobalConfiguration) {
486+
select {
487+
case o.configCh <- gc:
488+
default:
489+
}
490+
}

0 commit comments

Comments
 (0)
Please sign in to comment.