@@ -15,11 +15,14 @@ const (
15
15
csrfTokenLength = 32
16
16
csrfCookieName = "csrf_token"
17
17
csrfHeaderName = "X-CSRF-Token"
18
+ csrfContextKey = "CSRFToken"
19
+ cleanupInterval = 1 * time .Hour
20
+ tokenExpiryTime = 12 * time .Hour
18
21
)
19
22
20
23
var (
21
24
tokenStore = make (map [string ]time.Time )
22
- tokenStoreMu sync.Mutex
25
+ tokenStoreMu sync.RWMutex
23
26
timeNow = time .Now
24
27
csrfLogger * slog.Logger
25
28
)
@@ -28,47 +31,54 @@ func initCSRFLogger(logger *slog.Logger) {
28
31
csrfLogger = logger
29
32
}
30
33
31
- func generateCSRFToken () ( string , error ) {
34
+ func generateCSRFToken () string {
32
35
b := make ([]byte , csrfTokenLength )
33
- _ , err := rand .Read (b )
34
- if err != nil {
35
- return "" , err
36
+ if _ , err := rand .Read (b ); err != nil {
37
+ if csrfLogger != nil {
38
+ csrfLogger .Error ("Failed to generate CSRF token" , "error" , err )
39
+ }
40
+ return ""
36
41
}
37
- return base64 .StdEncoding .EncodeToString (b ), nil
42
+ return base64 .StdEncoding .EncodeToString (b )
38
43
}
39
44
40
- func setCSRFToken (r * http.Request , w http.ResponseWriter ) (string , error ) {
41
- token , err := generateCSRFToken ()
42
- if err != nil {
43
- return "" , err
45
+ func setCSRFToken (w http.ResponseWriter , r * http.Request ) (string , error ) {
46
+ token := generateCSRFToken ()
47
+ if token == "" {
48
+ return "" , errors . New ( "Failed to generate CSRF token" )
44
49
}
45
50
46
51
isSecure := r .TLS != nil || r .Header .Get ("X-Forwarded-Proto" ) == "https"
47
52
48
53
http .SetCookie (w , & http.Cookie {
49
54
Name : csrfCookieName ,
50
55
Value : token ,
56
+ Path : "/" ,
51
57
HttpOnly : true ,
52
- Secure : isSecure , // Set to true if using HTTPS
58
+ Secure : isSecure ,
53
59
SameSite : http .SameSiteStrictMode ,
60
+ MaxAge : int (tokenExpiryTime .Seconds ()),
54
61
})
55
62
56
- if csrfLogger != nil {
57
- csrfLogger .DebugContext (r .Context (), "set csrf cookie" )
58
- }
59
-
60
63
tokenStoreMu .Lock ()
61
- tokenStore [token ] = timeNow ().Add (24 * time . Hour ) // Token expires in 24 hours
64
+ tokenStore [token ] = timeNow ().Add (tokenExpiryTime )
62
65
tokenStoreMu .Unlock ()
63
66
64
67
return token , nil
65
68
}
66
69
70
+ func GetCSRFToken (r * http.Request ) string {
71
+ if token , ok := r .Context ().Value (csrfContextKey ).(string ); ok {
72
+ return token
73
+ }
74
+ return ""
75
+ }
76
+
67
77
func validateCSRFToken (r * http.Request ) error {
68
78
cookie , err := r .Cookie (csrfCookieName )
69
79
if err != nil {
70
80
if csrfLogger != nil {
71
- csrfLogger .DebugContext ( r . Context (), "csrf error " , slog . String ( "message " , err . Error ()) )
81
+ csrfLogger .Debug ( "CSRF cookie not found " , "error " , err )
72
82
}
73
83
return errors .New ("CSRF cookie not found" )
74
84
}
@@ -78,59 +88,99 @@ func validateCSRFToken(r *http.Request) error {
78
88
token = r .FormValue ("csrf_token" )
79
89
if token == "" {
80
90
if csrfLogger != nil {
81
- csrfLogger .DebugContext ( r . Context (), "csrf token not found in header or form" )
91
+ csrfLogger .Debug ( "CSRF token not found in header or form" )
82
92
}
83
- return errors .New ("CSRF token not found in header or form " )
93
+ return errors .New ("CSRF token not found" )
84
94
}
85
95
}
86
96
87
97
if cookie .Value != token {
88
98
if csrfLogger != nil {
89
- csrfLogger .DebugContext ( r . Context (), "csrf token mismatch" )
99
+ csrfLogger .Debug ( "CSRF token mismatch" )
90
100
}
91
101
return errors .New ("CSRF token mismatch" )
92
102
}
93
103
94
- tokenStoreMu .Lock ()
95
- defer tokenStoreMu .Unlock ()
96
-
104
+ tokenStoreMu .RLock ()
97
105
expiry , exists := tokenStore [token ]
106
+ tokenStoreMu .RUnlock ()
107
+
98
108
if ! exists {
99
109
if csrfLogger != nil {
100
- csrfLogger .DebugContext ( r . Context (), "csrf token not found in store" )
110
+ csrfLogger .Debug ( "CSRF token not found in store" )
101
111
}
102
112
return errors .New ("CSRF token not found in store" )
103
113
}
104
114
105
115
if timeNow ().After (expiry ) {
106
- delete (tokenStore , token )
107
116
if csrfLogger != nil {
108
- csrfLogger .DebugContext ( r . Context (), "csrf token expired" )
117
+ csrfLogger .Debug ( "CSRF token expired" )
109
118
}
110
119
return errors .New ("CSRF token expired" )
111
120
}
112
121
122
+ if csrfLogger != nil {
123
+ csrfLogger .Debug ("CSRF validation successful" )
124
+ }
113
125
return nil
114
126
}
115
127
116
128
func CSRFMiddleware (next http.HandlerFunc ) http.HandlerFunc {
117
129
return func (w http.ResponseWriter , r * http.Request ) {
118
- if r .Method == "GET" || r .Method == "HEAD" || r .Method == "OPTIONS" {
119
- token , err := setCSRFToken (r , w )
130
+ var token string
131
+ var err error
132
+
133
+ if r .Method == http .MethodGet || r .Method == http .MethodHead {
134
+ token , err = setCSRFToken (w , r )
120
135
if err != nil {
121
- http .Error (w , "Failed to set CSRF token" , http .StatusInternalServerError )
136
+ http .Error (w , http . StatusText ( http . StatusInternalServerError ) , http .StatusInternalServerError )
122
137
return
123
138
}
124
139
w .Header ().Set (csrfHeaderName , token )
125
-
126
- ctx := context .WithValue (r .Context (), "CSRFToken" , token )
127
- next .ServeHTTP (w , r .WithContext (ctx ))
128
140
} else {
129
141
if err := validateCSRFToken (r ); err != nil {
130
142
http .Error (w , "CSRF validation failed" , http .StatusForbidden )
131
143
return
132
144
}
145
+ token = r .Header .Get (csrfHeaderName )
146
+ if token == "" {
147
+ token = r .FormValue ("csrf_token" )
148
+ }
149
+ }
150
+
151
+ ctx := context .WithValue (r .Context (), csrfContextKey , token )
152
+ next .ServeHTTP (w , r .WithContext (ctx ))
153
+ }
154
+ }
155
+
156
+ func cleanupExpiredTokens () {
157
+ tokenStoreMu .Lock ()
158
+ defer tokenStoreMu .Unlock ()
159
+ now := timeNow ()
160
+ for token , expiry := range tokenStore {
161
+ if now .After (expiry ) {
162
+ delete (tokenStore , token )
133
163
}
134
- next .ServeHTTP (w , r )
135
164
}
165
+
166
+ if csrfLogger != nil {
167
+ csrfLogger .Debug ("Cleaned up expired CSRF tokens" )
168
+ }
169
+ }
170
+
171
+ func startCleanupRoutine (ctx context.Context ) {
172
+ ticker := time .NewTicker (cleanupInterval )
173
+ defer ticker .Stop ()
174
+ for {
175
+ select {
176
+ case <- ticker .C :
177
+ cleanupExpiredTokens ()
178
+ case <- ctx .Done ():
179
+ return
180
+ }
181
+ }
182
+ }
183
+
184
+ func init () {
185
+ go startCleanupRoutine (context .Background ())
136
186
}
0 commit comments