@@ -2,15 +2,279 @@ package reloader
2
2
3
3
import (
4
4
"bytes"
5
- "log"
5
+ "context"
6
+ "errors"
6
7
"os"
8
+ "path"
7
9
"path/filepath"
8
10
"testing"
9
11
"time"
10
12
13
+ "github.com/fsnotify/fsnotify"
11
14
"github.com/stretchr/testify/assert"
15
+ "github.com/supabase/auth/internal/conf"
16
+ "golang.org/x/sync/errgroup"
12
17
)
13
18
19
+ func TestWatch (t * testing.T ) {
20
+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
21
+ defer cancel ()
22
+
23
+ dir , cleanup := helpTestDir (t )
24
+ defer cleanup ()
25
+
26
+ // test broken watcher
27
+ {
28
+ sentinelErr := errors .New ("sentinel" )
29
+ rr := mockReloadRecorder ()
30
+ rl := NewReloader (dir )
31
+ rl .watchFn = func () (watcher , error ) { return nil , sentinelErr }
32
+
33
+ err := rl .Watch (ctx , rr .configFn )
34
+ if exp , got := sentinelErr , err ; exp != got {
35
+ assert .Equal (t , exp , got )
36
+ }
37
+ }
38
+
39
+ // test watch invalid dir
40
+ {
41
+ doneCtx , doneCancel := context .WithCancel (ctx )
42
+ doneCancel ()
43
+
44
+ rr := mockReloadRecorder ()
45
+ rl := NewReloader (path .Join (dir , "__not_found__" ))
46
+ err := rl .Watch (doneCtx , rr .configFn )
47
+ if exp , got := context .Canceled , err ; exp != got {
48
+ assert .Equal (t , exp , got )
49
+ }
50
+ }
51
+
52
+ // test watch error chan closed
53
+ {
54
+ rr := mockReloadRecorder ()
55
+ wr := newMockWatcher (nil )
56
+ wr .errorCh <- errors .New ("sentinel" )
57
+ close (wr .errorCh )
58
+
59
+ rl := NewReloader (dir )
60
+ rl .watchFn = func () (watcher , error ) { return wr , nil }
61
+
62
+ err := rl .Watch (ctx , rr .configFn )
63
+ assert .NotNil (t , err )
64
+
65
+ msg := "reloader: fsnotify error channel was closed"
66
+ if exp , got := msg , err .Error (); exp != got {
67
+ assert .Equal (t , exp , got )
68
+ }
69
+ }
70
+
71
+ // test watch event chan closed
72
+ {
73
+ rr := mockReloadRecorder ()
74
+ wr := newMockWatcher (nil )
75
+ close (wr .eventCh )
76
+
77
+ rl := NewReloader (dir )
78
+ rl .reloadIval = time .Second / 100
79
+ rl .watchFn = func () (watcher , error ) { return wr , nil }
80
+
81
+ err := rl .Watch (ctx , rr .configFn )
82
+ if err == nil {
83
+ assert .NotNil (t , err )
84
+ }
85
+
86
+ msg := "reloader: fsnotify event channel was closed"
87
+ if exp , got := msg , err .Error (); exp != got {
88
+ assert .Equal (t , exp , got )
89
+ }
90
+ }
91
+
92
+ // test watch error chan
93
+ {
94
+ rr := mockReloadRecorder ()
95
+ wr := newMockWatcher (nil )
96
+ wr .errorCh <- errors .New ("sentinel" )
97
+
98
+ rl := NewReloader (dir )
99
+ rl .watchFn = func () (watcher , error ) { return wr , nil }
100
+
101
+ egCtx , egCancel := context .WithCancel (ctx )
102
+ defer egCancel ()
103
+
104
+ var eg errgroup.Group
105
+ eg .Go (func () error {
106
+ return rl .Watch (egCtx , rr .configFn )
107
+ })
108
+
109
+ // need to ensure errorCh drains so test isn't racey
110
+ eg .Go (func () error {
111
+ defer egCancel ()
112
+
113
+ tr := time .NewTicker (time .Second / 100 )
114
+ defer tr .Stop ()
115
+
116
+ for {
117
+ select {
118
+ case <- egCtx .Done ():
119
+ return egCtx .Err ()
120
+ case <- tr .C :
121
+ if len (wr .errorCh ) == 0 {
122
+ return nil
123
+ }
124
+ }
125
+ }
126
+ })
127
+
128
+ err := eg .Wait ()
129
+ if exp , got := context .Canceled , err ; exp != got {
130
+ assert .Equal (t , exp , got )
131
+ }
132
+ }
133
+
134
+ // test an end to end config reload
135
+ {
136
+ rr := mockReloadRecorder ()
137
+ wr := newMockWatcher (nil )
138
+ rl := NewReloader (dir )
139
+ rl .watchFn = func () (watcher , error ) { return wr , wr .getErr () }
140
+ rl .reloadFn = rr .reloadFn
141
+
142
+ // Need to lower reload ival to pickup config write quicker.
143
+ rl .reloadIval = time .Second / 10
144
+ rl .tickerIval = rl .reloadIval / 10
145
+
146
+ egCtx , egCancel := context .WithCancel (ctx )
147
+ defer egCancel ()
148
+
149
+ var eg errgroup.Group
150
+ eg .Go (func () error {
151
+ return rl .Watch (egCtx , rr .configFn )
152
+ })
153
+
154
+ // Copy a full and valid example configuration to trigger Watch
155
+ {
156
+ select {
157
+ case <- egCtx .Done ():
158
+ assert .Nil (t , egCtx .Err ())
159
+ case v := <- wr .addCh :
160
+ assert .Equal (t , v , dir )
161
+ }
162
+
163
+ name := helpCopyEnvFile (t , dir , "01_example.env" , "testdata/50_example.env" )
164
+ wr .eventCh <- fsnotify.Event {
165
+ Name : name ,
166
+ Op : fsnotify .Create ,
167
+ }
168
+ select {
169
+ case <- egCtx .Done ():
170
+ assert .Nil (t , egCtx .Err ())
171
+ case cfg := <- rr .configCh :
172
+ assert .NotNil (t , cfg )
173
+ assert .Equal (t , cfg .External .Apple .Enabled , false )
174
+ }
175
+ }
176
+
177
+ {
178
+ drain (rr .configCh )
179
+ drain (rr .reloadCh )
180
+
181
+ name := helpWriteEnvFile (t , dir , "02_example.env" , map [string ]string {
182
+ "GOTRUE_EXTERNAL_APPLE_ENABLED" : "true" ,
183
+ })
184
+ wr .eventCh <- fsnotify.Event {
185
+ Name : name ,
186
+ Op : fsnotify .Create ,
187
+ }
188
+ select {
189
+ case <- egCtx .Done ():
190
+ assert .Nil (t , egCtx .Err ())
191
+ case cfg := <- rr .configCh :
192
+ assert .NotNil (t , cfg )
193
+ assert .Equal (t , cfg .External .Apple .Enabled , true )
194
+ }
195
+ }
196
+
197
+ {
198
+ name := helpWriteEnvFile (t , dir , "03_example.env.bak" , map [string ]string {
199
+ "GOTRUE_EXTERNAL_APPLE_ENABLED" : "false" ,
200
+ })
201
+ wr .eventCh <- fsnotify.Event {
202
+ Name : name ,
203
+ Op : fsnotify .Create ,
204
+ }
205
+ }
206
+
207
+ {
208
+ // empty the reload ch
209
+ drain (rr .reloadCh )
210
+
211
+ name := helpWriteEnvFile (t , dir , "04_example.env" , map [string ]string {
212
+ "GOTRUE_SMTP_PORT" : "ABC" ,
213
+ })
214
+ wr .eventCh <- fsnotify.Event {
215
+ Name : name ,
216
+ Op : fsnotify .Create ,
217
+ }
218
+
219
+ select {
220
+ case <- egCtx .Done ():
221
+ assert .Nil (t , egCtx .Err ())
222
+ case p := <- rr .reloadCh :
223
+ if exp , got := dir , p ; exp != got {
224
+ assert .Equal (t , exp , got )
225
+ }
226
+ }
227
+ }
228
+
229
+ {
230
+ name := helpWriteEnvFile (t , dir , "05_example.env" , map [string ]string {
231
+ "GOTRUE_SMTP_PORT" : "2222" ,
232
+ })
233
+ wr .eventCh <- fsnotify.Event {
234
+ Name : name ,
235
+ Op : fsnotify .Create ,
236
+ }
237
+ select {
238
+ case <- egCtx .Done ():
239
+ assert .Nil (t , egCtx .Err ())
240
+ case cfg := <- rr .configCh :
241
+ assert .NotNil (t , cfg )
242
+ assert .Equal (t , cfg .SMTP .Port , 2222 )
243
+ }
244
+ }
245
+
246
+ // test the wr.Add doesn't exit if bad watch dir is given during tick
247
+ {
248
+ // set the error on watcher
249
+ sentinelErr := errors .New ("sentinel" )
250
+ wr .setErr (sentinelErr )
251
+
252
+ name := helpWriteEnvFile (t , dir , "05_example.env" , map [string ]string {
253
+ "GOTRUE_SMTP_PORT" : "2222" ,
254
+ })
255
+ wr .eventCh <- fsnotify.Event {
256
+ Name : name ,
257
+ Op : fsnotify .Create ,
258
+ }
259
+ select {
260
+ case <- egCtx .Done ():
261
+ assert .Nil (t , egCtx .Err ())
262
+ case cfg := <- rr .configCh :
263
+ assert .NotNil (t , cfg )
264
+ assert .Equal (t , cfg .SMTP .Port , 2222 )
265
+ }
266
+ }
267
+
268
+ // test cases ran, end context to unblock Wait()
269
+ egCancel ()
270
+
271
+ err := eg .Wait ()
272
+ if exp , got := context .Canceled , err ; exp != got {
273
+ assert .Equal (t , exp , got )
274
+ }
275
+ }
276
+ }
277
+
14
278
func TestReloadConfig (t * testing.T ) {
15
279
dir , cleanup := helpTestDir (t )
16
280
defer cleanup ()
@@ -21,9 +285,7 @@ func TestReloadConfig(t *testing.T) {
21
285
helpCopyEnvFile (t , dir , "01_example.env" , "testdata/50_example.env" )
22
286
{
23
287
cfg , err := rl .reload ()
24
- if err != nil {
25
- t .Fatal (err )
26
- }
288
+ assert .Nil (t , err )
27
289
assert .NotNil (t , cfg )
28
290
assert .Equal (t , cfg .External .Apple .Enabled , false )
29
291
}
@@ -33,9 +295,7 @@ func TestReloadConfig(t *testing.T) {
33
295
})
34
296
{
35
297
cfg , err := rl .reload ()
36
- if err != nil {
37
- t .Fatal (err )
38
- }
298
+ assert .Nil (t , err )
39
299
assert .NotNil (t , cfg )
40
300
assert .Equal (t , cfg .External .Apple .Enabled , true )
41
301
}
@@ -45,12 +305,30 @@ func TestReloadConfig(t *testing.T) {
45
305
})
46
306
{
47
307
cfg , err := rl .reload ()
48
- if err != nil {
49
- t .Fatal (err )
50
- }
308
+ assert .Nil (t , err )
51
309
assert .NotNil (t , cfg )
52
310
assert .Equal (t , cfg .External .Apple .Enabled , true )
53
311
}
312
+
313
+ // test cfg reload failure
314
+ helpWriteEnvFile (t , dir , "04_example.env" , map [string ]string {
315
+ "PORT" : "INVALIDPORT" ,
316
+ "GOTRUE_SMTP_PORT" : "ABC" ,
317
+ })
318
+ {
319
+ cfg , err := rl .reload ()
320
+ assert .NotNil (t , err )
321
+ assert .Nil (t , cfg )
322
+ }
323
+
324
+ // test directory loading failure
325
+ {
326
+ cleanup ()
327
+
328
+ cfg , err := rl .reload ()
329
+ assert .NotNil (t , err )
330
+ assert .Nil (t , cfg )
331
+ }
54
332
}
55
333
56
334
func TestReloadCheckAt (t * testing.T ) {
@@ -136,21 +414,21 @@ func helpTestDir(t testing.TB) (dir string, cleanup func()) {
136
414
dir = filepath .Join ("testdata" , t .Name ())
137
415
err := os .MkdirAll (dir , 0750 )
138
416
if err != nil && ! os .IsExist (err ) {
139
- t . Fatal ( err )
417
+ assert . Nil ( t , err )
140
418
}
141
419
return dir , func () { os .RemoveAll (dir ) }
142
420
}
143
421
144
422
func helpCopyEnvFile (t testing.TB , dir , name , src string ) string {
145
423
data , err := os .ReadFile (src ) // #nosec G304
146
424
if err != nil {
147
- log . Fatal ( err )
425
+ assert . Nil ( t , err )
148
426
}
149
427
150
428
dst := filepath .Join (dir , name )
151
429
err = os .WriteFile (dst , data , 0600 )
152
430
if err != nil {
153
- t . Fatal ( err )
431
+ assert . Nil ( t , err )
154
432
}
155
433
return dst
156
434
}
@@ -166,8 +444,47 @@ func helpWriteEnvFile(t testing.TB, dir, name string, values map[string]string)
166
444
167
445
dst := filepath .Join (dir , name )
168
446
err := os .WriteFile (dst , buf .Bytes (), 0600 )
169
- if err != nil {
170
- t .Fatal (err )
171
- }
447
+ assert .Nil (t , err )
172
448
return dst
173
449
}
450
+
451
+ func mockReloadRecorder () * reloadRecorder {
452
+ rr := & reloadRecorder {
453
+ configCh : make (chan * conf.GlobalConfiguration , 1024 ),
454
+ reloadCh : make (chan string , 1024 ),
455
+ }
456
+ return rr
457
+ }
458
+
459
+ func drain [C ~ chan T , T any ](ch C ) (out []T ) {
460
+ for {
461
+ select {
462
+ case v := <- ch :
463
+ out = append (out , v )
464
+ default :
465
+ return out
466
+ }
467
+ }
468
+ }
469
+
470
+ type reloadRecorder struct {
471
+ configCh chan * conf.GlobalConfiguration
472
+ reloadCh chan string
473
+ }
474
+
475
+ func (o * reloadRecorder ) reloadFn (dir string ) (* conf.GlobalConfiguration , error ) {
476
+ defer func () {
477
+ select {
478
+ case o .reloadCh <- dir :
479
+ default :
480
+ }
481
+ }()
482
+ return defaultReloadFn (dir )
483
+ }
484
+
485
+ func (o * reloadRecorder ) configFn (gc * conf.GlobalConfiguration ) {
486
+ select {
487
+ case o .configCh <- gc :
488
+ default :
489
+ }
490
+ }
0 commit comments