Skip to content

Commit b2f68ba

Browse files
authored
Merge pull request #20 from seatgeek/zh-user-ctx
add ctx passing to user.Store interface
2 parents 856dcca + f542cab commit b2f68ba

File tree

8 files changed

+69
-59
lines changed

8 files changed

+69
-59
lines changed

pkg/notifier/default.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type DefaultNotifier struct {
2424
func (d *DefaultNotifier) Push(ctx context.Context, notification common.Notification) error {
2525
var errs []error
2626

27-
recipientUser, err := d.userStore.Find(notification.Recipient())
27+
recipientUser, err := d.userStore.Find(ctx, notification.Recipient())
2828
if err != nil {
2929
slog.Debug("failed to find user", "id", notification.Context().ID, "user", notification.Recipient().String(), "error", err)
3030
return fmt.Errorf("failed to find recipient user: %w", err)

pkg/user/handler.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (ph *PreferencesHandler) GetPreferences(writer http.ResponseWriter, request
4141
vars := mux.Vars(request)
4242
key := vars["key"]
4343

44-
u, err := ph.userStore.Get(key)
44+
u, err := ph.userStore.Get(request.Context(), key)
4545
if err != nil {
4646
if errors.Is(err, ErrUserNotFound) {
4747
slog.Info("user not found", "key", key)
@@ -72,7 +72,7 @@ func (ph *PreferencesHandler) UpdatePreferences(writer http.ResponseWriter, requ
7272
return
7373
}
7474

75-
err := ph.userStore.SetPreferences(key, req.Preferences)
75+
err := ph.userStore.SetPreferences(request.Context(), key, req.Preferences)
7676
if err != nil {
7777
if errors.Is(err, ErrUserNotFound) {
7878
slog.Info("user not found", "key", key)

pkg/user/handler_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package user
66

77
import (
88
"bytes"
9+
"context"
910
"net/http/httptest"
1011
"testing"
1112

@@ -210,7 +211,7 @@ func TestPreferencesHandler_UpdatePreferences(t *testing.T) {
210211
}
211212
}`, writer.Body.String())
212213

213-
user, err := handler.userStore.Get("rufus")
214+
user, err := handler.userStore.Get(context.Background(), "rufus")
214215
assert.NoError(t, err)
215216

216217
assert.False(t, user.Wants("com.gitlab.push", "slack"))

pkg/user/postgres/store.go

+13-12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package postgres
77

88
import (
9+
"context"
910
"fmt"
1011
"time"
1112

@@ -52,15 +53,15 @@ func NewPostgresStore(db *gorm.DB) *Store {
5253
}
5354

5455
// Add upserts a user to the postgres store
55-
func (s *Store) Add(u *user.User) error {
56+
func (s *Store) Add(ctx context.Context, u *user.User) error {
5657
var emails []string
5758
for _, id := range u.Identifiers.ToList() {
5859
if id.Kind() == identifier.KindEmail {
5960
emails = append(emails, id.Value)
6061
}
6162
}
6263

63-
result := s.db.Save(&UserModel{
64+
result := s.db.WithContext(ctx).Save(&UserModel{
6465
Key: u.Key,
6566
Preferences: u.Preferences,
6667
Identifiers: u.Identifiers.ToMap(),
@@ -70,12 +71,12 @@ func (s *Store) Add(u *user.User) error {
7071
}
7172

7273
// Find implements user.Store.
73-
func (s *Store) Find(possibleIdentifiers identifier.Set) (*user.User, error) {
74+
func (s *Store) Find(ctx context.Context, possibleIdentifiers identifier.Set) (*user.User, error) {
7475
if possibleIdentifiers.Len() == 0 {
7576
return nil, fmt.Errorf("%w: no identifiers provided", user.ErrUserNotFound)
7677
}
7778

78-
query := s.db.Model(&UserModel{})
79+
query := s.db.WithContext(ctx).Model(&UserModel{})
7980
for _, id := range possibleIdentifiers.ToList() {
8081
query = query.Or("identifiers @> ?", fmt.Sprintf(`{"%s": "%s"}`, id.NamespaceAndKind, id.Value))
8182
}
@@ -105,7 +106,7 @@ func (s *Store) Find(possibleIdentifiers identifier.Set) (*user.User, error) {
105106
return nil, fmt.Errorf("%w: no identifiers matched and no fallback emails were available", user.ErrUserNotFound)
106107
}
107108

108-
query = s.db.Model(&UserModel{})
109+
query = s.db.WithContext(ctx).Model(&UserModel{})
109110
for email := range possibleEmails {
110111
query = query.Or("emails @> ?", fmt.Sprintf(`"%s"`, email))
111112
}
@@ -126,25 +127,25 @@ func (s *Store) Find(possibleIdentifiers identifier.Set) (*user.User, error) {
126127
}
127128

128129
// Get implements user.Store.
129-
func (s *Store) Get(key string) (*user.User, error) {
130+
func (s *Store) Get(ctx context.Context, key string) (*user.User, error) {
130131
var u UserModel
131-
if err := s.db.Where("key = ?", key).First(&u).Error; err != nil {
132+
if err := s.db.WithContext(ctx).Where("key = ?", key).First(&u).Error; err != nil {
132133
return nil, err
133134
}
134135

135136
return u.ToUser(), nil
136137
}
137138

138139
// GetByIdentifier implements user.Store.
139-
func (s *Store) GetByIdentifier(id identifier.Identifier) (*user.User, error) {
140+
func (s *Store) GetByIdentifier(ctx context.Context, id identifier.Identifier) (*user.User, error) {
140141
var u UserModel
141-
if err := s.db.Where("identifiers @> ?", fmt.Sprintf(`{"%s": "%s"}`, id.NamespaceAndKind, id.Value)).First(&u).Error; err == nil {
142+
if err := s.db.WithContext(ctx).Where("identifiers @> ?", fmt.Sprintf(`{"%s": "%s"}`, id.NamespaceAndKind, id.Value)).First(&u).Error; err == nil {
142143
return u.ToUser(), nil
143144
}
144145

145146
// Fall back to any email identifier
146147
if id.Kind() == identifier.KindEmail {
147-
if err := s.db.Where("emails @> ?", fmt.Sprintf(`"%s"`, id.Value)).First(&u).Error; err == nil {
148+
if err := s.db.WithContext(ctx).Where("emails @> ?", fmt.Sprintf(`"%s"`, id.Value)).First(&u).Error; err == nil {
148149
return u.ToUser(), nil
149150
}
150151
}
@@ -153,8 +154,8 @@ func (s *Store) GetByIdentifier(id identifier.Identifier) (*user.User, error) {
153154
}
154155

155156
// SetPreferences implements user.Store.
156-
func (s *Store) SetPreferences(key string, prefs user.Preferences) error {
157-
return s.db.Model(&UserModel{}).Where("key = ?", key).Update("preferences", prefs).Error
157+
func (s *Store) SetPreferences(ctx context.Context, key string, prefs user.Preferences) error {
158+
return s.db.WithContext(ctx).Model(&UserModel{}).Where("key = ?", key).Update("preferences", prefs).Error
158159
}
159160

160161
var _ user.Store = &Store{}

pkg/user/postgres/store_test.go

+23-23
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestPostgresStore_Add(t *testing.T) {
2727
store := createDatastore(t)
2828

2929
// Prove that our user doesn't exist in the database yet
30-
_, err := store.GetByIdentifier(identifier.New("email", "[email protected]"))
30+
_, err := store.GetByIdentifier(context.Background(), identifier.New("email", "[email protected]"))
3131
assert.ErrorIs(t, err, user.ErrUserNotFound)
3232

3333
// Insert a new user
@@ -36,53 +36,53 @@ func TestPostgresStore_Add(t *testing.T) {
3636
user.WithIdentifier(identifier.New("email", "[email protected]")),
3737
)
3838

39-
err = store.Add(u)
39+
err = store.Add(context.Background(), u)
4040
assert.NoError(t, err)
4141

4242
// Check if inserted (by key)
43-
got, err := store.Get(u.Key)
43+
got, err := store.Get(context.Background(), u.Key)
4444
assert.NoError(t, err)
4545
assert.Equal(t, u, got)
4646

4747
// Check if inserted (by identifier)
48-
got, err = store.GetByIdentifier(identifier.New("email", "[email protected]"))
48+
got, err = store.GetByIdentifier(context.Background(), identifier.New("email", "[email protected]"))
4949
assert.NoError(t, err)
5050
assert.Equal(t, u, got)
5151

5252
// Update that same user object
5353
u.Identifiers.Add(identifier.New("gitlab.com/email", "[email protected]"))
5454

55-
err = store.Add(u)
55+
err = store.Add(context.Background(), u)
5656
assert.NoError(t, err)
5757

5858
// Check if updated (by key)
59-
got, err = store.Get(u.Key)
59+
got, err = store.Get(context.Background(), u.Key)
6060
assert.NoError(t, err)
6161
assert.Equal(t, u, got)
6262

6363
// Check if updated (by identifier)
64-
got, err = store.GetByIdentifier(identifier.New("email", "[email protected]"))
64+
got, err = store.GetByIdentifier(context.Background(), identifier.New("email", "[email protected]"))
6565
assert.NoError(t, err)
6666
assert.Equal(t, u, got)
6767

6868
// Update that user using a completely different object with the same key and different identifier
6969
u = user.New("codell", user.WithIdentifier(identifier.New("email", "[email protected]")))
7070

71-
err = store.Add(u)
71+
err = store.Add(context.Background(), u)
7272
assert.NoError(t, err)
7373

7474
// Check if updated (by key)
75-
got, err = store.Get(u.Key)
75+
got, err = store.Get(context.Background(), u.Key)
7676
assert.NoError(t, err)
7777
assert.Equal(t, u, got)
7878

7979
// Check if updated (by identifier)
80-
got, err = store.GetByIdentifier(identifier.New("email", "[email protected]"))
80+
got, err = store.GetByIdentifier(context.Background(), identifier.New("email", "[email protected]"))
8181
assert.NoError(t, err)
8282
assert.Equal(t, u, got)
8383

8484
// Check if old identifier is gone
85-
_, err = store.GetByIdentifier(identifier.New("email", "[email protected]"))
85+
_, err = store.GetByIdentifier(context.Background(), identifier.New("email", "[email protected]"))
8686
assert.ErrorIs(t, err, user.ErrUserNotFound)
8787
}
8888

@@ -115,10 +115,10 @@ func TestPostgresStore_Get(t *testing.T) {
115115
t.Run(tc.name, func(t *testing.T) {
116116
t.Parallel()
117117

118-
err := store.Add(tc.expected)
118+
err := store.Add(context.Background(), tc.expected)
119119
assert.NoError(t, err)
120120

121-
got, err := store.Get(tc.key)
121+
got, err := store.Get(context.Background(), tc.key)
122122
assert.NoError(t, err)
123123
assert.Equal(t, tc.expected, got)
124124
})
@@ -178,11 +178,11 @@ func TestPostgresStore_Find(t *testing.T) {
178178
t.Parallel()
179179

180180
if tc.expected != nil {
181-
err := store.Add(tc.expected)
181+
err := store.Add(context.Background(), tc.expected)
182182
assert.NoError(t, err)
183183
}
184184

185-
got, err := store.Find(tc.arg)
185+
got, err := store.Find(context.Background(), tc.arg)
186186

187187
if tc.wantErr != nil {
188188
assert.ErrorIs(t, err, tc.wantErr)
@@ -202,19 +202,19 @@ func TestPostgresStore_Find_duplicate(t *testing.T) {
202202

203203
duplicateIdentifier := identifier.New("email", "[email protected]")
204204

205-
err := store.Add(user.New(
205+
err := store.Add(context.Background(), user.New(
206206
"duplicateA",
207207
user.WithIdentifier(duplicateIdentifier),
208208
))
209209
assert.NoError(t, err)
210210

211-
err = store.Add(user.New(
211+
err = store.Add(context.Background(), user.New(
212212
"duplicateb",
213213
user.WithIdentifier(duplicateIdentifier),
214214
))
215215
assert.NoError(t, err)
216216

217-
got, err := store.Find(identifier.NewSet(duplicateIdentifier))
217+
got, err := store.Find(context.Background(), identifier.NewSet(duplicateIdentifier))
218218

219219
wantErr := errors.New("found multiple users with identifiers: [email:[email protected]]")
220220
//nolint:testifylint
@@ -264,11 +264,11 @@ func TestPostgresStore_GetByIdentifier(t *testing.T) {
264264
t.Parallel()
265265

266266
if tc.expected != nil {
267-
err := store.Add(tc.expected)
267+
err := store.Add(context.Background(), tc.expected)
268268
assert.NoError(t, err)
269269
}
270270

271-
got, err := store.GetByIdentifier(tc.arg)
271+
got, err := store.GetByIdentifier(context.Background(), tc.arg)
272272

273273
if tc.wantErr != nil {
274274
assert.ErrorIs(t, err, tc.wantErr)
@@ -291,7 +291,7 @@ func TestPostgresStore_SetPreferences(t *testing.T) {
291291
"zach",
292292
user.WithIdentifier(identifier.New("email", "[email protected]")),
293293
)
294-
err := store.Add(userWithNoPreferences)
294+
err := store.Add(context.Background(), userWithNoPreferences)
295295
assert.NoError(t, err)
296296

297297
// Set preferences
@@ -301,7 +301,7 @@ func TestPostgresStore_SetPreferences(t *testing.T) {
301301
"slack": false,
302302
},
303303
}
304-
err = store.SetPreferences(userWithNoPreferences.Key, expectedPreferences)
304+
err = store.SetPreferences(context.Background(), userWithNoPreferences.Key, expectedPreferences)
305305
assert.NoError(t, err)
306306

307307
// Check if set
@@ -310,7 +310,7 @@ func TestPostgresStore_SetPreferences(t *testing.T) {
310310
user.WithIdentifiers(userWithNoPreferences.Identifiers),
311311
user.WithPreferences(expectedPreferences),
312312
)
313-
got, err := store.Get(userWithNoPreferences.Key)
313+
got, err := store.Get(context.Background(), userWithNoPreferences.Key)
314314
assert.NoError(t, err)
315315
assert.Equal(t, expectedUser, got)
316316
}

pkg/user/store.go

+12-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package user
66

77
import (
8+
"context"
89
"errors"
910

1011
"github.com/seatgeek/mailroom/pkg/identifier"
@@ -24,16 +25,16 @@ var ErrUserNotFound = errors.New("user not found")
2425
// new integrations that utilize email identifiers without having to update all existing user information in the store.
2526
type Store interface {
2627
// Get returns a user by its key, or an error if the user is not found
27-
Get(key string) (*User, error)
28+
Get(ctx context.Context, key string) (*User, error)
2829
// GetByIdentifier returns a user by a given identifier, or an error if the user is not found
29-
GetByIdentifier(identifier identifier.Identifier) (*User, error)
30+
GetByIdentifier(ctx context.Context, identifier identifier.Identifier) (*User, error)
3031

3132
// Find searches for a user matching any of the given identifiers
3233
// (The user is not required to match all of them, just one is enough)
33-
Find(possibleIdentifiers identifier.Set) (*User, error)
34+
Find(ctx context.Context, possibleIdentifiers identifier.Set) (*User, error)
3435

3536
// SetPreferences replaces the preferences for a user by key
36-
SetPreferences(key string, prefs Preferences) error
37+
SetPreferences(ctx context.Context, key string, prefs Preferences) error
3738
}
3839

3940
// InMemoryStore is a simple in-memory implementation of the Store interface
@@ -50,12 +51,12 @@ func NewInMemoryStore(users ...*User) *InMemoryStore {
5051
}
5152

5253
// Add adds a user to the in-memory store
53-
func (s *InMemoryStore) Add(u *User) error {
54+
func (s *InMemoryStore) Add(ctx context.Context, u *User) error {
5455
s.users = append(s.users, u)
5556
return nil
5657
}
5758

58-
func (s *InMemoryStore) Get(key string) (*User, error) {
59+
func (s *InMemoryStore) Get(ctx context.Context, key string) (*User, error) {
5960
for _, u := range s.users {
6061
if u.Key == key {
6162
return u, nil
@@ -65,7 +66,7 @@ func (s *InMemoryStore) Get(key string) (*User, error) {
6566
return nil, ErrUserNotFound
6667
}
6768

68-
func (s *InMemoryStore) GetByIdentifier(identifier identifier.Identifier) (*User, error) {
69+
func (s *InMemoryStore) GetByIdentifier(ctx context.Context, identifier identifier.Identifier) (*User, error) {
6970
isEmail := identifier.Kind() == "email"
7071

7172
for _, u := range s.users {
@@ -85,9 +86,9 @@ func (s *InMemoryStore) GetByIdentifier(identifier identifier.Identifier) (*User
8586
return nil, ErrUserNotFound
8687
}
8788

88-
func (s *InMemoryStore) Find(possibleIdentifiers identifier.Set) (*User, error) {
89+
func (s *InMemoryStore) Find(ctx context.Context, possibleIdentifiers identifier.Set) (*User, error) {
8990
for _, i := range possibleIdentifiers.ToList() {
90-
u, err := s.GetByIdentifier(i)
91+
u, err := s.GetByIdentifier(ctx, i)
9192
if err == nil {
9293
return u, nil
9394
}
@@ -96,8 +97,8 @@ func (s *InMemoryStore) Find(possibleIdentifiers identifier.Set) (*User, error)
9697
return nil, ErrUserNotFound
9798
}
9899

99-
func (s *InMemoryStore) SetPreferences(key string, prefs Preferences) error {
100-
u, err := s.Get(key)
100+
func (s *InMemoryStore) SetPreferences(ctx context.Context, key string, prefs Preferences) error {
101+
u, err := s.Get(ctx, key)
101102
if err != nil {
102103
return err
103104
}

0 commit comments

Comments
 (0)