Skip to content

Commit c09f90a

Browse files
committed
💜 add authentication middleware
1 parent 07e3ad8 commit c09f90a

15 files changed

+367
-40
lines changed

api/account.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package api
22

33
import (
44
"database/sql"
5+
"errors"
56
"net/http"
67

78
"github.com/gin-gonic/gin"
89
"github.com/lib/pq"
910
db "github.com/nc-minh/tinybank/db/sqlc"
11+
"github.com/nc-minh/tinybank/token"
1012
)
1113

1214
type createAccountRequest struct {
@@ -21,8 +23,10 @@ func (server *Server) createAccount(ctx *gin.Context) {
2123
return
2224
}
2325

26+
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
27+
2428
arg := db.CreateAccountParams{
25-
Owner: req.Owner,
29+
Owner: authPayload.Username,
2630
Balance: 0,
2731
Currency: req.Currency,
2832
}
@@ -66,6 +70,13 @@ func (server *Server) getAccount(ctx *gin.Context) {
6670
return
6771
}
6872

73+
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
74+
if account.Owner != authPayload.Username {
75+
err := errors.New("you don't have permission to access this account")
76+
ctx.JSON(http.StatusForbidden, errorResponse(err))
77+
return
78+
}
79+
6980
ctx.JSON(http.StatusOK, account)
7081

7182
}
@@ -83,7 +94,10 @@ func (server *Server) listAccount(ctx *gin.Context) {
8394
return
8495
}
8596

97+
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
98+
8699
arg := db.ListAccountsParams{
100+
Owner: authPayload.Username,
87101
Limit: req.Limit,
88102
Offset: (req.Offset - 1) * req.Limit,
89103
}

api/account_test.go

+48-5
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,33 @@ import (
99
"net/http"
1010
"net/http/httptest"
1111
"testing"
12+
"time"
1213

1314
"github.com/golang/mock/gomock"
1415
mockdb "github.com/nc-minh/tinybank/db/mock"
1516
db "github.com/nc-minh/tinybank/db/sqlc"
17+
"github.com/nc-minh/tinybank/token"
1618
"github.com/nc-minh/tinybank/utils"
1719
"github.com/stretchr/testify/require"
1820
)
1921

2022
func TestGetAccount(t *testing.T) {
21-
account := randomAccount()
23+
user, _ := randomUser(t)
24+
account := randomAccount(user.Username)
2225

2326
testCases := []struct {
2427
name string
2528
accountID int64
29+
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
2630
buildStubs func(store *mockdb.MockStore)
2731
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
2832
}{
2933
{
3034
name: "OK",
3135
accountID: account.ID,
36+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
37+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
38+
},
3239
buildStubs: func(store *mockdb.MockStore) {
3340
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(account, nil)
3441
},
@@ -37,9 +44,37 @@ func TestGetAccount(t *testing.T) {
3744
requireBodyMatchAccount(t, recorder.Body, account)
3845
},
3946
},
47+
{
48+
name: "UnauthorizedUser",
49+
accountID: account.ID,
50+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
51+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute)
52+
},
53+
buildStubs: func(store *mockdb.MockStore) {
54+
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(account, nil)
55+
},
56+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
57+
require.Equal(t, recorder.Code, http.StatusForbidden)
58+
},
59+
},
60+
{
61+
name: "NoAuthorization",
62+
accountID: account.ID,
63+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
64+
},
65+
buildStubs: func(store *mockdb.MockStore) {
66+
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
67+
},
68+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
69+
require.Equal(t, recorder.Code, http.StatusUnauthorized)
70+
},
71+
},
4072
{
4173
name: "NotFound",
4274
accountID: account.ID,
75+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
76+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
77+
},
4378
buildStubs: func(store *mockdb.MockStore) {
4479
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
4580
},
@@ -50,6 +85,9 @@ func TestGetAccount(t *testing.T) {
5085
{
5186
name: "InternalError",
5287
accountID: account.ID,
88+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
89+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
90+
},
5391
buildStubs: func(store *mockdb.MockStore) {
5492
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(db.Account{}, sql.ErrConnDone)
5593
},
@@ -60,6 +98,9 @@ func TestGetAccount(t *testing.T) {
6098
{
6199
name: "InvalidID",
62100
accountID: 0,
101+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
102+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
103+
},
63104
buildStubs: func(store *mockdb.MockStore) {
64105
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
65106
},
@@ -80,23 +121,25 @@ func TestGetAccount(t *testing.T) {
80121
store := mockdb.NewMockStore(ctrl)
81122
tc.buildStubs(store)
82123

83-
server := NewServer(store)
124+
server := newTestServer(t, store)
84125
recorder := httptest.NewRecorder()
85126

86127
url := fmt.Sprintf("/accounts/%d", tc.accountID)
87128
request := httptest.NewRequest(http.MethodGet, url, nil)
88-
server.router.ServeHTTP(recorder, request)
89129

130+
tc.setupAuth(t, request, server.tokenMaker)
131+
132+
server.router.ServeHTTP(recorder, request)
90133
tc.checkResponse(t, recorder)
91134
})
92135

93136
}
94137
}
95138

96-
func randomAccount() db.Account {
139+
func randomAccount(owner string) db.Account {
97140
return db.Account{
98141
ID: utils.RandomInt(1, 1000),
99-
Owner: utils.RandomOwner(),
142+
Owner: owner,
100143
Balance: utils.RandomInt(1, 1000),
101144
Currency: utils.RandomCurrency(),
102145
}

api/main_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,26 @@ package api
33
import (
44
"os"
55
"testing"
6+
"time"
67

78
"github.com/gin-gonic/gin"
9+
db "github.com/nc-minh/tinybank/db/sqlc"
10+
"github.com/nc-minh/tinybank/utils"
11+
"github.com/stretchr/testify/require"
812
)
913

14+
func newTestServer(t *testing.T, store db.Store) *Server {
15+
config := utils.Config{
16+
TokenSymmetricKey: utils.RandomString(32),
17+
AccessTokenDuration: time.Minute,
18+
}
19+
20+
server, err := NewServer(config, store)
21+
require.NoError(t, err)
22+
23+
return server
24+
}
25+
1026
func TestMain(m *testing.M) {
1127
gin.SetMode(gin.TestMode)
1228

api/middleware.go

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package api
2+
3+
import (
4+
"errors"
5+
"net/http"
6+
"strings"
7+
8+
"github.com/gin-gonic/gin"
9+
"github.com/nc-minh/tinybank/token"
10+
)
11+
12+
const (
13+
authorizationHeaderKey = "authorization"
14+
authorizationTypeBearer = "bearer"
15+
authorizationPayloadKey = "authorization_payload"
16+
)
17+
18+
func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
19+
return func(ctx *gin.Context) {
20+
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)
21+
if len(authorizationHeader) == 0 {
22+
err := errors.New("authorization header is empty")
23+
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
24+
return
25+
}
26+
27+
fields := strings.Fields(authorizationHeader)
28+
if len(fields) < 2 {
29+
err := errors.New("authorization header is invalid")
30+
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
31+
return
32+
}
33+
34+
authorizationType := strings.ToLower(fields[0])
35+
if authorizationType != authorizationTypeBearer {
36+
err := errors.New("authorization type is not bearer")
37+
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
38+
return
39+
}
40+
41+
accessToken := fields[1]
42+
43+
payload, err := tokenMaker.VerifyToken(accessToken)
44+
if err != nil {
45+
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
46+
return
47+
}
48+
49+
ctx.Set(authorizationPayloadKey, payload)
50+
ctx.Next()
51+
}
52+
}

api/middleware_test.go

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package api
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
"time"
9+
10+
"github.com/gin-gonic/gin"
11+
"github.com/nc-minh/tinybank/token"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func addAuthorization(
16+
t *testing.T,
17+
request *http.Request,
18+
tokenMaker token.Maker,
19+
authorizationType string,
20+
username string,
21+
duration time.Duration,
22+
) {
23+
token, err := tokenMaker.CreateToken(username, duration)
24+
require.NoError(t, err)
25+
26+
authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
27+
request.Header.Set(authorizationHeaderKey, authorizationHeader)
28+
}
29+
30+
func TestAuthMiddleware(t *testing.T) {
31+
testCases := []struct {
32+
name string
33+
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
34+
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
35+
}{
36+
{
37+
name: "OK",
38+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
39+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute)
40+
},
41+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
42+
require.Equal(t, http.StatusOK, recorder.Code)
43+
},
44+
},
45+
{
46+
name: "NoAuthorization",
47+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
48+
},
49+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
50+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
51+
},
52+
},
53+
{
54+
name: "UnsupportedAuthorization",
55+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
56+
addAuthorization(t, request, tokenMaker, "unsupported", "user", time.Minute)
57+
},
58+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
59+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
60+
},
61+
},
62+
{
63+
name: "InvalidAuthorizationFormat",
64+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
65+
addAuthorization(t, request, tokenMaker, "", "user", time.Minute)
66+
},
67+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
68+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
69+
},
70+
},
71+
{
72+
name: "ExpiredToken",
73+
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
74+
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", -time.Minute)
75+
},
76+
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
77+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
78+
},
79+
},
80+
}
81+
82+
for i := range testCases {
83+
tc := testCases[i]
84+
85+
t.Run(tc.name, func(t *testing.T) {
86+
server := newTestServer(t, nil)
87+
88+
authPath := "/auth"
89+
server.router.GET(
90+
authPath,
91+
authMiddleware(server.tokenMaker),
92+
func(ctx *gin.Context) {
93+
ctx.JSON(http.StatusOK, gin.H{})
94+
},
95+
)
96+
97+
recorder := httptest.NewRecorder()
98+
request, err := http.NewRequest(http.MethodGet, authPath, nil)
99+
require.NoError(t, err)
100+
101+
tc.setupAuth(t, request, server.tokenMaker)
102+
server.router.ServeHTTP(recorder, request)
103+
tc.checkResponse(t, recorder)
104+
})
105+
}
106+
}

0 commit comments

Comments
 (0)