diff --git a/.travis.yml b/.travis.yml
index 74942d9..80cce77 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,13 +1,16 @@
language: go
sudo: false
-go:
- - 1.4
- - 1.5
- - tip
+
+matrix:
+ include:
+ - go: 1.5
+ - go: tip
+
install:
- go get golang.org/x/tools/cmd/vet
+
script:
- go get -t -v ./...
- - diff -u <(echo -n) <(gofmt -d -s .)
+ - diff -u <(echo -n) <(gofmt -d .)
- go tool vet .
- go test -v -race ./...
diff --git a/README.md b/README.md
index f065659..0d866a8 100644
--- a/README.md
+++ b/README.md
@@ -13,11 +13,11 @@ forgery](http://blog.codinghorror.com/preventing-csrf-and-xsrf-attacks/) (CSRF)
templates to replace a `{{ .csrfField }}` template tag with a hidden input
field.
-This library is designed to work with the [Goji](https://github.com/zenazn/goji)
-micro-framework, which is a simple web framework for Go that is broadly
-compatible with other parts of the Go ecosystem. It makes use of Goji's `web.C`
-request context, which doesn't rely on a global map, and is therefore safe to
-attach to your top-level router (if you so wish).
+This library is designed to work with not just the the [Goji](https://github.com/goji/goji)
+micro-framework, but any framework that accepts the `func(context.Context, w http.ResponseWriter, r *http.Request)`
+signature. This makes it compatible with other parts of the Go ecosystem. The
+`context.Context` request context doesn't rely on a global map, and is therefore
+free from contention in a busy web service.
The library also assumes HTTPS by default: sending cookies over vanilla HTTP
is risky and you're likely to get hurt.
@@ -27,7 +27,7 @@ is risky and you're likely to get hurt.
goji/csrf is easy to use: add the middleware to your stack with the below:
```go
-goji.Use(csrf.Protect([]byte("32-byte-long-auth-key")))
+goji.UseC(csrf.Protect([]byte("32-byte-long-auth-key")))
```
... and then collect the token with `csrf.Token(c, r)` before passing it to the
@@ -47,25 +47,27 @@ import (
"html/template"
"net/http"
+ "goji.io"
"github.com/goji/csrf"
- "github.com/zenazn/goji"
+ "github.com/zenazn/goji/graceful"
)
func main() {
+ m := goji.NewMux()
// Add the middleware to your router.
- goji.Use(csrf.Protect([]byte("32-byte-long-auth-key")))
- goji.Get("/signup", ShowSignupForm)
+ m.UseC(csrf.Protect([]byte("32-byte-long-auth-key")))
+ m.HandleFuncC(pat.Get("/signup"), ShowSignupForm)
// POST requests without a valid token will return a HTTP 403 Forbidden.
- goji.Post("/signup/post", SubmitSignupForm)
+ m.HandleFuncC(pat.Post("/signup/post"), SubmitSignupForm)
- goji.Serve()
+ graceful.ListenAndServe(":8000", m)
}
-func ShowSignupForm(c web.C, w http.ResponseWriter, r *http.Request) {
+func ShowSignupForm(ctx context.Context, w http.ResponseWriter, r *http.Request) {
// signup_form.tmpl just needs a {{ .csrfField }} template tag for
// csrf.TemplateField to inject the CSRF token into. Easy!
t.ExecuteTemplate(w, "signup_form.tmpl", map[string]interface{
- csrf.TemplateTag: csrf.TemplateField(c, r),
+ csrf.TemplateTag: csrf.TemplateField(ctx, r),
})
// We could also retrieve the token directly from csrf.Token(c, r) and
// set it in the request header - w.Header.Set("X-CSRF-Token", token)
@@ -73,7 +75,7 @@ func ShowSignupForm(c web.C, w http.ResponseWriter, r *http.Request) {
// framework.
}
-func SubmitSignupForm(c web.C, w http.ResponseWriter, r *http.Request) {
+func SubmitSignupForm(ctx context.Context, w http.ResponseWriter, r *http.Request) {
// We can trust that requests making it this far have satisfied
// our CSRF protection requirements.
}
@@ -91,38 +93,38 @@ as we don't handle any POST/PUT/DELETE requests with our top-level router.
package main
import (
+ "goji.io"
"github.com/goji/csrf"
"github.com/zenazn/goji/graceful"
- "github.com/zenazn/goji/web"
)
func main() {
- r := web.New()
+ m := goji.NewMux()
// Our top-level router doesn't need CSRF protection: it's simple.
- r.Get("/", ShowIndex)
+ m.HandleFuncC(pat.Get("/"), ShowIndex)
- api := web.New()
- r.Handle("/api/*", s)
+ api := goji.NewMux()
+ m.HandleC("/api/*", api)
// ... but our /api/* routes do, so we add it to the sub-router only.
- s.Use(csrf.Protect([]byte("32-byte-long-auth-key")))
+ api.UseC(csrf.Protect([]byte("32-byte-long-auth-key")))
- s.Get("/api/user/:id", GetUser)
- s.Post("/api/user", PostUser)
+ api.Get("/api/user/:id", GetUser)
+ api.Post("/api/user", PostUser)
- graceful.ListenAndServe(":8000", r)
+ graceful.ListenAndServe(":8000", m)
}
-func GetUser(c web.C, w http.ResponseWriter, r *http.Request) {
+func GetUser(ctx context.Context, w http.ResponseWriter, r *http.Request) {
// Authenticate the request, get the :id from the route params,
// and fetch the user from the DB, etc.
// Get the token and pass it in the CSRF header. Our JSON-speaking client
// or JavaScript framework can now read the header and return the token in
// in its own "X-CSRF-Token" request header on the subsequent POST.
- w.Header().Set("X-CSRF-Token", csrf.Token(c, r))
+ w.Header().Set("X-CSRF-Token", csrf.Token(ctx, r))
b, err := json.Marshal(user)
if err != nil {
- http.Error(...)
+ http.Error(w, http.StatusText(500), 500)
return
}
@@ -138,23 +140,25 @@ goji/csrf provides options for changing these as you see fit:
```go
func main() {
+ m := goji.NewMux()
CSRF := csrf.Protect(
[]byte("a-32-byte-long-key-goes-here"),
csrf.RequestHeader("Authenticity-Token"),
csrf.FieldName("authenticity_token"),
- // Note that csrf.ErrorHandler takes a Goji web.Handler type, else
- // your error handler can't retrieve the error reason from the context.
- // The signature `func UnauthHandler(c web.C, w http.ResponseWriter, r *http.Request)`
- // is a web.Handler, and the simplest to use if you'd like to serve
+ // Note that csrf.ErrorHandler takes a Goji goji.Handler type, else
+ // your error handler can't retrieve the error reason from the
+ // context.
+ // The signature `func UnauthHandler(ctx context.Context, w http.ResponseWriter, r *http.Request)`
+ // is a goji.Handler, and the simplest to use if you'd like to serve
// "pretty" error pages (who doesn't?).
- csrf.ErrorHandler(web.HandlerFunc(serverError(403))),
+ csrf.ErrorHandler(goji.HandlerFunc(serverError(403))),
)
- goji.Use(CSRF)
- goji.Get("/signup", GetSignupForm)
- goji.Post("/signup", PostSignupForm)
+ m.UseC(CSRF)
+ m.HandleFuncC(pat.Get("/signup"), GetSignupForm)
+ m.HandleFuncC(pat.Post("/signup"), PostSignupForm)
- goji.Serve()
+ graceful.ListenAndServe(":8000", m)
}
```
diff --git a/csrf.go b/csrf.go
index 1295a40..c1bb0de 100644
--- a/csrf.go
+++ b/csrf.go
@@ -8,8 +8,11 @@ import (
"net/http"
"net/url"
+ "golang.org/x/net/context"
+
+ "goji.io"
+
"github.com/gorilla/securecookie"
- "github.com/zenazn/goji/web"
)
// CSRF token length in bytes.
@@ -52,8 +55,7 @@ var (
)
type csrf struct {
- c *web.C
- h http.Handler
+ h goji.Handler
sc *securecookie.SecureCookie
st store
opts options
@@ -70,7 +72,7 @@ type options struct {
Secure bool
RequestHeader string
FieldName string
- ErrorHandler web.Handler
+ ErrorHandler goji.Handler
CookieName string
}
@@ -115,13 +117,13 @@ type options struct {
// // framework.
// }
//
-func Protect(authKey []byte, opts ...Option) func(*web.C, http.Handler) http.Handler {
- return func(c *web.C, h http.Handler) http.Handler {
+func Protect(authKey []byte, opts ...Option) func(goji.Handler) goji.Handler {
+ return func(h goji.Handler) goji.Handler {
cs := parseOptions(h, opts...)
// Set the defaults if no options have been specified
if cs.opts.ErrorHandler == nil {
- cs.opts.ErrorHandler = web.HandlerFunc(unauthorizedHandler)
+ cs.opts.ErrorHandler = goji.HandlerFunc(unauthorizedHandler)
}
if cs.opts.MaxAge < 1 {
@@ -163,24 +165,16 @@ func Protect(authKey []byte, opts ...Option) func(*web.C, http.Handler) http.Han
}
}
- // Initialize Goji's request context
- cs.c = c
-
return *cs
}
}
// Implements http.Handler for the csrf type.
-func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- // Create our request context if it does not already exist.
- if cs.c.Env == nil {
- cs.c.Env = make(map[interface{}]interface{})
- }
-
+func (cs csrf) ServeHTTPC(ctx context.Context, w http.ResponseWriter, r *http.Request) {
// Retrieve the token from the session.
// An error represents either a cookie that failed HMAC validation
// or that doesn't exist.
- realToken, err := cs.st.Get(cs.c, r)
+ realToken, err := cs.st.Get(r)
if err != nil || len(realToken) != tokenLength {
// If there was an error retrieving the token, the token doesn't exist
// yet, or it's the wrong length, generate a new token.
@@ -188,24 +182,24 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// as it will no longer match the request token.
realToken, err = generateRandomBytes(tokenLength)
if err != nil {
- envError(cs.c, err)
- cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r)
+ setEnvError(ctx, err)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
return
}
// Save the new (real) token in the session store.
err = cs.st.Save(realToken, w)
if err != nil {
- envError(cs.c, err)
- cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r)
+ setEnvError(ctx, err)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
return
}
}
// Save the masked token to the request context
- cs.c.Env[tokenKey] = mask(realToken, cs.c, r)
+ ctx = context.WithValue(ctx, tokenKey, mask(realToken, r))
// Save the field name to the request context
- cs.c.Env[formKey] = cs.opts.FieldName
+ ctx = context.WithValue(ctx, formKey, cs.opts.FieldName)
// HTTP methods not defined as idempotent ("safe") under RFC7231 require
// inspection.
@@ -218,14 +212,14 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// otherwise fails to parse.
referer, err := url.Parse(r.Referer())
if err != nil || referer.String() == "" {
- envError(cs.c, ErrNoReferer)
- cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r)
+ setEnvError(ctx, ErrNoReferer)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
return
}
if sameOrigin(r.URL, referer) == false {
- envError(cs.c, ErrBadReferer)
- cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r)
+ setEnvError(ctx, ErrBadReferer)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
return
}
}
@@ -233,8 +227,16 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If the token returned from the session store is nil for non-idempotent
// ("unsafe") methods, call the error handler.
if realToken == nil {
- envError(cs.c, ErrNoToken)
- cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r)
+ setEnvError(ctx, ErrNoToken)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
+ return
+ }
+
+ // If the token returned from the session store is nil for non-idempotent
+ // ("unsafe") methods, call the error handler.
+ if realToken == nil {
+ setEnvError(ctx, ErrNoToken)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
return
}
@@ -243,8 +245,8 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Compare the request token against the real token
if !compareTokens(requestToken, realToken) {
- envError(cs.c, ErrBadToken)
- cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r)
+ setEnvError(ctx, ErrBadToken)
+ cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r)
return
}
@@ -254,14 +256,14 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Vary", "Cookie")
// Call the wrapped handler/router on success
- cs.h.ServeHTTP(w, r)
+ cs.h.ServeHTTPC(ctx, w, r)
}
// unauthorizedhandler sets a HTTP 403 Forbidden status and writes the
// CSRF failure reason to the response.
-func unauthorizedHandler(c web.C, w http.ResponseWriter, r *http.Request) {
+func unauthorizedHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("%s - %s",
- http.StatusText(http.StatusForbidden), FailureReason(c, r)),
+ http.StatusText(http.StatusForbidden), FailureReason(ctx, r)),
http.StatusForbidden)
return
}
diff --git a/csrf_test.go b/csrf_test.go
index 9a7586c..22da59a 100644
--- a/csrf_test.go
+++ b/csrf_test.go
@@ -6,19 +6,22 @@ import (
"strings"
"testing"
- "github.com/zenazn/goji/web"
+ "goji.io/pat"
+
+ "golang.org/x/net/context"
+
+ "goji.io"
)
var testKey = []byte("keep-it-secret-keep-it-safe-----")
-var testHandler = web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {})
+var testHandler = goji.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {})
// TestProtect is a high-level test to make sure the middleware returns the
// wrapped handler with a 200 OK status.
func TestProtect(t *testing.T) {
- s := web.New()
- s.Use(Protect(testKey))
-
- s.Get("/", testHandler)
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
+ m.HandleFuncC(pat.Get("/"), testHandler)
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
@@ -26,7 +29,7 @@ func TestProtect(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -41,11 +44,9 @@ func TestProtect(t *testing.T) {
// Test that idempotent methods return a 200 OK status and that non-idempotent
// methods return a 403 Forbidden status when a CSRF cookie is not present.
func TestMethods(t *testing.T) {
- s := web.New()
- s.Use(Protect(testKey))
-
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- }))
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
+ m.HandleFuncC(pat.New("/"), testHandler)
// Test idempontent ("safe") methods
for _, method := range safeMethods {
@@ -55,7 +56,7 @@ func TestMethods(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -76,7 +77,7 @@ func TestMethods(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusForbidden {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -90,22 +91,15 @@ func TestMethods(t *testing.T) {
}
-// Tests for failure if the cookie containing the session is removed from the
-// request.
-func TestNoCookie(t *testing.T) {
-
-}
-
// TestBadCookie tests for failure when a cookie header is modified (malformed).
func TestBadCookie(t *testing.T) {
- s := web.New()
- CSRF := Protect(testKey)
- s.Use(CSRF)
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
var token string
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- token = Token(c, r)
- }))
+ m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ token = Token(ctx, r)
+ })
// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
@@ -114,7 +108,7 @@ func TestBadCookie(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
// POST the token back in the header.
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
@@ -129,7 +123,7 @@ func TestBadCookie(t *testing.T) {
r.Header.Set("Referer", "http://www.gorillatoolkit.org/")
rr = httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusForbidden {
t.Fatalf("middleware failed to reject a bad cookie: got %v want %v",
@@ -140,10 +134,9 @@ func TestBadCookie(t *testing.T) {
// Responses should set a "Vary: Cookie" header to protect client/proxy caching.
func TestVaryHeader(t *testing.T) {
-
- s := web.New()
- s.Use(Protect(testKey))
- s.Get("/", testHandler)
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
+ m.HandleFuncC(pat.Get("/"), testHandler)
r, err := http.NewRequest("HEAD", "https://www.golang.org/", nil)
if err != nil {
@@ -151,7 +144,7 @@ func TestVaryHeader(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -165,10 +158,9 @@ func TestVaryHeader(t *testing.T) {
// Requests with no Referer header should fail.
func TestNoReferer(t *testing.T) {
-
- s := web.New()
- s.Use(Protect(testKey))
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {}))
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
+ m.HandleFuncC(pat.Get("/"), testHandler)
r, err := http.NewRequest("POST", "https://golang.org/", nil)
if err != nil {
@@ -176,7 +168,7 @@ func TestNoReferer(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusForbidden {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -187,15 +179,13 @@ func TestNoReferer(t *testing.T) {
// TestBadReferer checks that HTTPS requests with a Referer that does not
// match the request URL correctly fail CSRF validation.
func TestBadReferer(t *testing.T) {
-
- s := web.New()
- CSRF := Protect(testKey)
- s.Use(CSRF)
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
var token string
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- token = Token(c, r)
- }))
+ m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ token = Token(ctx, r)
+ })
// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil)
@@ -204,7 +194,7 @@ func TestBadReferer(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
// POST the token back in the header.
r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil)
@@ -219,7 +209,7 @@ func TestBadReferer(t *testing.T) {
r.Header.Set("Referer", "http://goji.io")
rr = httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusForbidden {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -229,14 +219,13 @@ func TestBadReferer(t *testing.T) {
// Requests with a valid Referer should pass.
func TestWithReferer(t *testing.T) {
- s := web.New()
- CSRF := Protect(testKey)
- s.Use(CSRF)
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
var token string
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- token = Token(c, r)
- }))
+ m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ token = Token(ctx, r)
+ })
// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
@@ -245,7 +234,7 @@ func TestWithReferer(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
// POST the token back in the header.
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
@@ -258,7 +247,7 @@ func TestWithReferer(t *testing.T) {
r.Header.Set("Referer", "http://www.gorillatoolkit.org/")
rr = httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -268,6 +257,7 @@ func TestWithReferer(t *testing.T) {
// TestFormField tests that a token in the form field takes precedence over a
// token in the HTTP header.
+// TODO(matt): Finish this test.
func TestFormField(t *testing.T) {
}
diff --git a/helpers.go b/helpers.go
index 8aa0854..b94262c 100644
--- a/helpers.go
+++ b/helpers.go
@@ -9,26 +9,26 @@ import (
"net/http"
"net/url"
- "github.com/zenazn/goji/web"
+ "golang.org/x/net/context"
)
// Token returns a masked CSRF token ready for passing into HTML template or
// a JSON response body. An empty token will be returned if the middleware
// has not been applied (which will fail subsequent validation).
-func Token(c web.C, r *http.Request) string {
- if maskedToken, ok := c.Env[tokenKey].(string); ok {
+func Token(ctx context.Context, r *http.Request) string {
+ if maskedToken, ok := ctx.Value(tokenKey).(string); ok {
return maskedToken
}
return ""
}
-// FailureReason makes CSRF validation errors available in Goji's request
+// FailureReason makes CSRF validation errors available in the request
// context.
// This is useful when you want to log the cause of the error or report it to
// client.
-func FailureReason(c web.C, r *http.Request) error {
- if err, ok := c.Env[errorKey].(error); ok {
+func FailureReason(ctx context.Context, r *http.Request) error {
+ if err, ok := ctx.Value(errorKey).(error); ok {
return err
}
@@ -46,9 +46,9 @@ func FailureReason(c web.C, r *http.Request) error {
// // ... becomes:
//
//
-func TemplateField(c web.C, r *http.Request) template.HTML {
+func TemplateField(ctx context.Context, r *http.Request) template.HTML {
fragment := fmt.Sprintf(``,
- c.Env[formKey], Token(c, r))
+ ctx.Value(formKey), Token(ctx, r))
return template.HTML(fragment)
}
@@ -60,7 +60,7 @@ func TemplateField(c web.C, r *http.Request) template.HTML {
// token and returning them together as a 64-byte slice. This effectively
// randomises the token on a per-request basis without breaking multiple browser
// tabs/windows.
-func mask(realToken []byte, c *web.C, r *http.Request) string {
+func mask(realToken []byte, r *http.Request) string {
otp, err := generateRandomBytes(tokenLength)
if err != nil {
return ""
@@ -180,7 +180,7 @@ func contains(vals []string, s string) bool {
return false
}
-// envError stores a CSRF error in the request context.
-func envError(c *web.C, err error) {
- c.Env[errorKey] = err
+// setEnvError stores a CSRF error in the request context.
+func setEnvError(ctx context.Context, err error) {
+ ctx = context.WithValue(ctx, errorKey, err)
}
diff --git a/helpers_test.go b/helpers_test.go
index f04505a..d353a84 100644
--- a/helpers_test.go
+++ b/helpers_test.go
@@ -14,7 +14,10 @@ import (
"testing"
"text/template"
- "github.com/zenazn/goji/web"
+ "goji.io/pat"
+ "golang.org/x/net/context"
+
+ "goji.io"
)
var testTemplate = `
@@ -31,18 +34,18 @@ var testTemplateField = ``
// Test that our form helpers correctly inject a token into the response body.
func TestFormToken(t *testing.T) {
- s := web.New()
- s.Use(Protect(testKey))
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
// Make the token available outside of the handler for comparison.
var token string
- s.Get("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- token = Token(c, r)
+ m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ token = Token(ctx, r)
t := template.Must((template.New("base").Parse(testTemplate)))
t.Execute(w, map[string]interface{}{
- TemplateTag: TemplateField(c, r),
+ TemplateTag: TemplateField(ctx, r),
})
- }))
+ })
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
@@ -50,7 +53,7 @@ func TestFormToken(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -68,18 +71,18 @@ func TestFormToken(t *testing.T) {
// Test that we can extract a CSRF token from a multipart form.
func TestMultipartFormToken(t *testing.T) {
- s := web.New()
- s.Use(Protect(testKey))
+ m := goji.NewMux()
+ m.UseC(Protect(testKey))
// Make the token available outside of the handler for comparison.
var token string
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- token = Token(c, r)
+ m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ token = Token(ctx, r)
t := template.Must((template.New("base").Parse(testTemplate)))
t.Execute(w, map[string]interface{}{
- TemplateTag: TemplateField(c, r),
+ TemplateTag: TemplateField(ctx, r),
})
- }))
+ })
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
@@ -87,7 +90,7 @@ func TestMultipartFormToken(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
// Set up our multipart form
var b bytes.Buffer
@@ -112,7 +115,7 @@ func TestMultipartFormToken(t *testing.T) {
setCookie(rr, r)
rr = httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
@@ -134,7 +137,7 @@ func TestMaskUnmaskTokens(t *testing.T) {
t.Fatal(err)
}
- issued := mask(realToken, nil, nil)
+ issued := mask(realToken, nil)
decoded, err := base64.StdEncoding.DecodeString(issued)
if err != nil {
t.Fatal(err)
@@ -224,19 +227,16 @@ func TestGenerateRandomBytes(t *testing.T) {
}
func TestTemplateField(t *testing.T) {
- s := web.New()
- CSRF := Protect(
- testKey,
- FieldName(testFieldName),
- )
- s.Use(CSRF)
+ m := goji.NewMux()
+ CSRF := Protect(testKey, FieldName(testFieldName))
+ m.UseC(CSRF)
var token string
var customTemplateField string
- s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
- token = Token(c, r)
- customTemplateField = string(TemplateField(c, r))
- }))
+ m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ token = Token(ctx, r)
+ customTemplateField = string(TemplateField(ctx, r))
+ })
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
@@ -244,7 +244,7 @@ func TestTemplateField(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
expectedTemplateField := fmt.Sprintf(testTemplateField, testFieldName, token)
diff --git a/options.go b/options.go
index 319a0e2..2d65f38 100644
--- a/options.go
+++ b/options.go
@@ -1,10 +1,6 @@
package csrf
-import (
- "net/http"
-
- "github.com/zenazn/goji/web"
-)
+import "goji.io"
// Option describes a functional option for configuring the CSRF handler.
type Option func(*csrf) error
@@ -70,7 +66,7 @@ func HttpOnly(h bool) Option {
//
// Note that a custom error handler can also access the csrf.Failure(c, r)
// function to retrieve the CSRF validation reason from Goji's request context.
-func ErrorHandler(h web.Handler) Option {
+func ErrorHandler(h goji.Handler) Option {
return func(cs *csrf) error {
cs.opts.ErrorHandler = h
return nil
@@ -117,7 +113,7 @@ func setStore(s store) Option {
// parseOptions parses the supplied options functions and returns a configured
// csrf handler.
-func parseOptions(h http.Handler, opts ...Option) *csrf {
+func parseOptions(h goji.Handler, opts ...Option) *csrf {
// Set the handler to call after processing.
cs := &csrf{
h: h,
diff --git a/options_test.go b/options_test.go
index 4d27332..c9bcbd7 100644
--- a/options_test.go
+++ b/options_test.go
@@ -1,16 +1,15 @@
package csrf
import (
- "net/http"
"reflect"
"testing"
- "github.com/zenazn/goji/web"
+ "goji.io"
)
// Tests that options functions are applied to the middleware.
func TestOptions(t *testing.T) {
- var h http.Handler
+ var h goji.Handler
age := 86400
domain := "goji.io"
@@ -28,7 +27,7 @@ func TestOptions(t *testing.T) {
Secure(false),
RequestHeader(header),
FieldName(field),
- ErrorHandler(web.HandlerFunc(errorHandler)),
+ ErrorHandler(goji.HandlerFunc(errorHandler)),
CookieName(name),
}
diff --git a/store.go b/store.go
index 621aef9..96145df 100644
--- a/store.go
+++ b/store.go
@@ -5,13 +5,12 @@ import (
"time"
"github.com/gorilla/securecookie"
- "github.com/zenazn/goji/web"
)
// store represents the session storage used for CSRF tokens.
type store interface {
// Get returns the real CSRF token from the store.
- Get(c *web.C, r *http.Request) ([]byte, error)
+ Get(r *http.Request) ([]byte, error)
// Save stores the real CSRF token in the store and writes a
// cookie to the http.ResponseWriter.
// For non-cookie stores, the cookie should contain a unique (256 bit) ID
@@ -33,7 +32,7 @@ type cookieStore struct {
// Get retrieves a CSRF token from the session cookie. It returns an empty token
// if decoding fails (e.g. HMAC validation fails or the named cookie doesn't exist).
-func (cs *cookieStore) Get(c *web.C, r *http.Request) ([]byte, error) {
+func (cs *cookieStore) Get(r *http.Request) ([]byte, error) {
// Retrieve the cookie from the request
cookie, err := r.Cookie(cs.name)
if err != nil {
diff --git a/store_test.go b/store_test.go
index a708562..c4f6981 100644
--- a/store_test.go
+++ b/store_test.go
@@ -7,8 +7,11 @@ import (
"net/http/httptest"
"testing"
+ "goji.io"
+
+ "goji.io/pat"
+
"github.com/gorilla/securecookie"
- "github.com/zenazn/goji/web"
)
// Check Store implementations
@@ -19,7 +22,7 @@ type brokenSaveStore struct {
store
}
-func (bs *brokenSaveStore) Get(*web.C, *http.Request) ([]byte, error) {
+func (bs *brokenSaveStore) Get(*http.Request) ([]byte, error) {
// Generate an invalid token so we can progress to our Save method
return generateRandomBytes(24)
}
@@ -30,10 +33,10 @@ func (bs *brokenSaveStore) Save(realToken []byte, w http.ResponseWriter) error {
// Tests for failure if the middleware can't save to the Store.
func TestStoreCannotSave(t *testing.T) {
- s := web.New()
+ m := goji.NewMux()
bs := &brokenSaveStore{}
- s.Use(Protect(testKey, setStore(bs)))
- s.Get("/", testHandler)
+ m.UseC(Protect(testKey, setStore(bs)))
+ m.HandleFuncC(pat.Get("/"), testHandler)
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
@@ -41,7 +44,7 @@ func TestStoreCannotSave(t *testing.T) {
}
rr := httptest.NewRecorder()
- s.ServeHTTP(rr, r)
+ m.ServeHTTP(rr, r)
if rr.Code != http.StatusForbidden {
t.Fatalf("broken store did not set an error status: got %v want %v",
@@ -72,7 +75,7 @@ func TestCookieDecode(t *testing.T) {
// Set a fake cookie value so r.Cookie passes.
r.Header.Set("Cookie", fmt.Sprintf("%s=%s", cookieName, "notacookie"))
- _, err = st.Get(&web.C{}, r)
+ _, err = st.Get(r)
if err == nil {
t.Fatal("cookiestore did not report an invalid hashkey on decode")
}