diff --git a/.travis.yml b/.travis.yml index 74942d9..80cce77 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,16 @@ language: go sudo: false -go: - - 1.4 - - 1.5 - - tip + +matrix: + include: + - go: 1.5 + - go: tip + install: - go get golang.org/x/tools/cmd/vet + script: - go get -t -v ./... - - diff -u <(echo -n) <(gofmt -d -s .) + - diff -u <(echo -n) <(gofmt -d .) - go tool vet . - go test -v -race ./... diff --git a/README.md b/README.md index f065659..0d866a8 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,11 @@ forgery](http://blog.codinghorror.com/preventing-csrf-and-xsrf-attacks/) (CSRF) templates to replace a `{{ .csrfField }}` template tag with a hidden input field. -This library is designed to work with the [Goji](https://github.com/zenazn/goji) -micro-framework, which is a simple web framework for Go that is broadly -compatible with other parts of the Go ecosystem. It makes use of Goji's `web.C` -request context, which doesn't rely on a global map, and is therefore safe to -attach to your top-level router (if you so wish). +This library is designed to work with not just the the [Goji](https://github.com/goji/goji) +micro-framework, but any framework that accepts the `func(context.Context, w http.ResponseWriter, r *http.Request)` +signature. This makes it compatible with other parts of the Go ecosystem. The +`context.Context` request context doesn't rely on a global map, and is therefore +free from contention in a busy web service. The library also assumes HTTPS by default: sending cookies over vanilla HTTP is risky and you're likely to get hurt. @@ -27,7 +27,7 @@ is risky and you're likely to get hurt. goji/csrf is easy to use: add the middleware to your stack with the below: ```go -goji.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) +goji.UseC(csrf.Protect([]byte("32-byte-long-auth-key"))) ``` ... and then collect the token with `csrf.Token(c, r)` before passing it to the @@ -47,25 +47,27 @@ import ( "html/template" "net/http" + "goji.io" "github.com/goji/csrf" - "github.com/zenazn/goji" + "github.com/zenazn/goji/graceful" ) func main() { + m := goji.NewMux() // Add the middleware to your router. - goji.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) - goji.Get("/signup", ShowSignupForm) + m.UseC(csrf.Protect([]byte("32-byte-long-auth-key"))) + m.HandleFuncC(pat.Get("/signup"), ShowSignupForm) // POST requests without a valid token will return a HTTP 403 Forbidden. - goji.Post("/signup/post", SubmitSignupForm) + m.HandleFuncC(pat.Post("/signup/post"), SubmitSignupForm) - goji.Serve() + graceful.ListenAndServe(":8000", m) } -func ShowSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { +func ShowSignupForm(ctx context.Context, w http.ResponseWriter, r *http.Request) { // signup_form.tmpl just needs a {{ .csrfField }} template tag for // csrf.TemplateField to inject the CSRF token into. Easy! t.ExecuteTemplate(w, "signup_form.tmpl", map[string]interface{ - csrf.TemplateTag: csrf.TemplateField(c, r), + csrf.TemplateTag: csrf.TemplateField(ctx, r), }) // We could also retrieve the token directly from csrf.Token(c, r) and // set it in the request header - w.Header.Set("X-CSRF-Token", token) @@ -73,7 +75,7 @@ func ShowSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { // framework. } -func SubmitSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { +func SubmitSignupForm(ctx context.Context, w http.ResponseWriter, r *http.Request) { // We can trust that requests making it this far have satisfied // our CSRF protection requirements. } @@ -91,38 +93,38 @@ as we don't handle any POST/PUT/DELETE requests with our top-level router. package main import ( + "goji.io" "github.com/goji/csrf" "github.com/zenazn/goji/graceful" - "github.com/zenazn/goji/web" ) func main() { - r := web.New() + m := goji.NewMux() // Our top-level router doesn't need CSRF protection: it's simple. - r.Get("/", ShowIndex) + m.HandleFuncC(pat.Get("/"), ShowIndex) - api := web.New() - r.Handle("/api/*", s) + api := goji.NewMux() + m.HandleC("/api/*", api) // ... but our /api/* routes do, so we add it to the sub-router only. - s.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) + api.UseC(csrf.Protect([]byte("32-byte-long-auth-key"))) - s.Get("/api/user/:id", GetUser) - s.Post("/api/user", PostUser) + api.Get("/api/user/:id", GetUser) + api.Post("/api/user", PostUser) - graceful.ListenAndServe(":8000", r) + graceful.ListenAndServe(":8000", m) } -func GetUser(c web.C, w http.ResponseWriter, r *http.Request) { +func GetUser(ctx context.Context, w http.ResponseWriter, r *http.Request) { // Authenticate the request, get the :id from the route params, // and fetch the user from the DB, etc. // Get the token and pass it in the CSRF header. Our JSON-speaking client // or JavaScript framework can now read the header and return the token in // in its own "X-CSRF-Token" request header on the subsequent POST. - w.Header().Set("X-CSRF-Token", csrf.Token(c, r)) + w.Header().Set("X-CSRF-Token", csrf.Token(ctx, r)) b, err := json.Marshal(user) if err != nil { - http.Error(...) + http.Error(w, http.StatusText(500), 500) return } @@ -138,23 +140,25 @@ goji/csrf provides options for changing these as you see fit: ```go func main() { + m := goji.NewMux() CSRF := csrf.Protect( []byte("a-32-byte-long-key-goes-here"), csrf.RequestHeader("Authenticity-Token"), csrf.FieldName("authenticity_token"), - // Note that csrf.ErrorHandler takes a Goji web.Handler type, else - // your error handler can't retrieve the error reason from the context. - // The signature `func UnauthHandler(c web.C, w http.ResponseWriter, r *http.Request)` - // is a web.Handler, and the simplest to use if you'd like to serve + // Note that csrf.ErrorHandler takes a Goji goji.Handler type, else + // your error handler can't retrieve the error reason from the + // context. + // The signature `func UnauthHandler(ctx context.Context, w http.ResponseWriter, r *http.Request)` + // is a goji.Handler, and the simplest to use if you'd like to serve // "pretty" error pages (who doesn't?). - csrf.ErrorHandler(web.HandlerFunc(serverError(403))), + csrf.ErrorHandler(goji.HandlerFunc(serverError(403))), ) - goji.Use(CSRF) - goji.Get("/signup", GetSignupForm) - goji.Post("/signup", PostSignupForm) + m.UseC(CSRF) + m.HandleFuncC(pat.Get("/signup"), GetSignupForm) + m.HandleFuncC(pat.Post("/signup"), PostSignupForm) - goji.Serve() + graceful.ListenAndServe(":8000", m) } ``` diff --git a/csrf.go b/csrf.go index 1295a40..c1bb0de 100644 --- a/csrf.go +++ b/csrf.go @@ -8,8 +8,11 @@ import ( "net/http" "net/url" + "golang.org/x/net/context" + + "goji.io" + "github.com/gorilla/securecookie" - "github.com/zenazn/goji/web" ) // CSRF token length in bytes. @@ -52,8 +55,7 @@ var ( ) type csrf struct { - c *web.C - h http.Handler + h goji.Handler sc *securecookie.SecureCookie st store opts options @@ -70,7 +72,7 @@ type options struct { Secure bool RequestHeader string FieldName string - ErrorHandler web.Handler + ErrorHandler goji.Handler CookieName string } @@ -115,13 +117,13 @@ type options struct { // // framework. // } // -func Protect(authKey []byte, opts ...Option) func(*web.C, http.Handler) http.Handler { - return func(c *web.C, h http.Handler) http.Handler { +func Protect(authKey []byte, opts ...Option) func(goji.Handler) goji.Handler { + return func(h goji.Handler) goji.Handler { cs := parseOptions(h, opts...) // Set the defaults if no options have been specified if cs.opts.ErrorHandler == nil { - cs.opts.ErrorHandler = web.HandlerFunc(unauthorizedHandler) + cs.opts.ErrorHandler = goji.HandlerFunc(unauthorizedHandler) } if cs.opts.MaxAge < 1 { @@ -163,24 +165,16 @@ func Protect(authKey []byte, opts ...Option) func(*web.C, http.Handler) http.Han } } - // Initialize Goji's request context - cs.c = c - return *cs } } // Implements http.Handler for the csrf type. -func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Create our request context if it does not already exist. - if cs.c.Env == nil { - cs.c.Env = make(map[interface{}]interface{}) - } - +func (cs csrf) ServeHTTPC(ctx context.Context, w http.ResponseWriter, r *http.Request) { // Retrieve the token from the session. // An error represents either a cookie that failed HMAC validation // or that doesn't exist. - realToken, err := cs.st.Get(cs.c, r) + realToken, err := cs.st.Get(r) if err != nil || len(realToken) != tokenLength { // If there was an error retrieving the token, the token doesn't exist // yet, or it's the wrong length, generate a new token. @@ -188,24 +182,24 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // as it will no longer match the request token. realToken, err = generateRandomBytes(tokenLength) if err != nil { - envError(cs.c, err) - cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) + setEnvError(ctx, err) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) return } // Save the new (real) token in the session store. err = cs.st.Save(realToken, w) if err != nil { - envError(cs.c, err) - cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) + setEnvError(ctx, err) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) return } } // Save the masked token to the request context - cs.c.Env[tokenKey] = mask(realToken, cs.c, r) + ctx = context.WithValue(ctx, tokenKey, mask(realToken, r)) // Save the field name to the request context - cs.c.Env[formKey] = cs.opts.FieldName + ctx = context.WithValue(ctx, formKey, cs.opts.FieldName) // HTTP methods not defined as idempotent ("safe") under RFC7231 require // inspection. @@ -218,14 +212,14 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // otherwise fails to parse. referer, err := url.Parse(r.Referer()) if err != nil || referer.String() == "" { - envError(cs.c, ErrNoReferer) - cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) + setEnvError(ctx, ErrNoReferer) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) return } if sameOrigin(r.URL, referer) == false { - envError(cs.c, ErrBadReferer) - cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) + setEnvError(ctx, ErrBadReferer) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) return } } @@ -233,8 +227,16 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If the token returned from the session store is nil for non-idempotent // ("unsafe") methods, call the error handler. if realToken == nil { - envError(cs.c, ErrNoToken) - cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) + setEnvError(ctx, ErrNoToken) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) + return + } + + // If the token returned from the session store is nil for non-idempotent + // ("unsafe") methods, call the error handler. + if realToken == nil { + setEnvError(ctx, ErrNoToken) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) return } @@ -243,8 +245,8 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Compare the request token against the real token if !compareTokens(requestToken, realToken) { - envError(cs.c, ErrBadToken) - cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) + setEnvError(ctx, ErrBadToken) + cs.opts.ErrorHandler.ServeHTTPC(ctx, w, r) return } @@ -254,14 +256,14 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("Vary", "Cookie") // Call the wrapped handler/router on success - cs.h.ServeHTTP(w, r) + cs.h.ServeHTTPC(ctx, w, r) } // unauthorizedhandler sets a HTTP 403 Forbidden status and writes the // CSRF failure reason to the response. -func unauthorizedHandler(c web.C, w http.ResponseWriter, r *http.Request) { +func unauthorizedHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("%s - %s", - http.StatusText(http.StatusForbidden), FailureReason(c, r)), + http.StatusText(http.StatusForbidden), FailureReason(ctx, r)), http.StatusForbidden) return } diff --git a/csrf_test.go b/csrf_test.go index 9a7586c..22da59a 100644 --- a/csrf_test.go +++ b/csrf_test.go @@ -6,19 +6,22 @@ import ( "strings" "testing" - "github.com/zenazn/goji/web" + "goji.io/pat" + + "golang.org/x/net/context" + + "goji.io" ) var testKey = []byte("keep-it-secret-keep-it-safe-----") -var testHandler = web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {}) +var testHandler = goji.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {}) // TestProtect is a high-level test to make sure the middleware returns the // wrapped handler with a 200 OK status. func TestProtect(t *testing.T) { - s := web.New() - s.Use(Protect(testKey)) - - s.Get("/", testHandler) + m := goji.NewMux() + m.UseC(Protect(testKey)) + m.HandleFuncC(pat.Get("/"), testHandler) r, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -26,7 +29,7 @@ func TestProtect(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -41,11 +44,9 @@ func TestProtect(t *testing.T) { // Test that idempotent methods return a 200 OK status and that non-idempotent // methods return a 403 Forbidden status when a CSRF cookie is not present. func TestMethods(t *testing.T) { - s := web.New() - s.Use(Protect(testKey)) - - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - })) + m := goji.NewMux() + m.UseC(Protect(testKey)) + m.HandleFuncC(pat.New("/"), testHandler) // Test idempontent ("safe") methods for _, method := range safeMethods { @@ -55,7 +56,7 @@ func TestMethods(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -76,7 +77,7 @@ func TestMethods(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -90,22 +91,15 @@ func TestMethods(t *testing.T) { } -// Tests for failure if the cookie containing the session is removed from the -// request. -func TestNoCookie(t *testing.T) { - -} - // TestBadCookie tests for failure when a cookie header is modified (malformed). func TestBadCookie(t *testing.T) { - s := web.New() - CSRF := Protect(testKey) - s.Use(CSRF) + m := goji.NewMux() + m.UseC(Protect(testKey)) var token string - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - token = Token(c, r) - })) + m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + token = Token(ctx, r) + }) // Obtain a CSRF cookie via a GET request. r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil) @@ -114,7 +108,7 @@ func TestBadCookie(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) // POST the token back in the header. r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil) @@ -129,7 +123,7 @@ func TestBadCookie(t *testing.T) { r.Header.Set("Referer", "http://www.gorillatoolkit.org/") rr = httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to reject a bad cookie: got %v want %v", @@ -140,10 +134,9 @@ func TestBadCookie(t *testing.T) { // Responses should set a "Vary: Cookie" header to protect client/proxy caching. func TestVaryHeader(t *testing.T) { - - s := web.New() - s.Use(Protect(testKey)) - s.Get("/", testHandler) + m := goji.NewMux() + m.UseC(Protect(testKey)) + m.HandleFuncC(pat.Get("/"), testHandler) r, err := http.NewRequest("HEAD", "https://www.golang.org/", nil) if err != nil { @@ -151,7 +144,7 @@ func TestVaryHeader(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -165,10 +158,9 @@ func TestVaryHeader(t *testing.T) { // Requests with no Referer header should fail. func TestNoReferer(t *testing.T) { - - s := web.New() - s.Use(Protect(testKey)) - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {})) + m := goji.NewMux() + m.UseC(Protect(testKey)) + m.HandleFuncC(pat.Get("/"), testHandler) r, err := http.NewRequest("POST", "https://golang.org/", nil) if err != nil { @@ -176,7 +168,7 @@ func TestNoReferer(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -187,15 +179,13 @@ func TestNoReferer(t *testing.T) { // TestBadReferer checks that HTTPS requests with a Referer that does not // match the request URL correctly fail CSRF validation. func TestBadReferer(t *testing.T) { - - s := web.New() - CSRF := Protect(testKey) - s.Use(CSRF) + m := goji.NewMux() + m.UseC(Protect(testKey)) var token string - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - token = Token(c, r) - })) + m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + token = Token(ctx, r) + }) // Obtain a CSRF cookie via a GET request. r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil) @@ -204,7 +194,7 @@ func TestBadReferer(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) // POST the token back in the header. r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil) @@ -219,7 +209,7 @@ func TestBadReferer(t *testing.T) { r.Header.Set("Referer", "http://goji.io") rr = httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -229,14 +219,13 @@ func TestBadReferer(t *testing.T) { // Requests with a valid Referer should pass. func TestWithReferer(t *testing.T) { - s := web.New() - CSRF := Protect(testKey) - s.Use(CSRF) + m := goji.NewMux() + m.UseC(Protect(testKey)) var token string - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - token = Token(c, r) - })) + m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + token = Token(ctx, r) + }) // Obtain a CSRF cookie via a GET request. r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil) @@ -245,7 +234,7 @@ func TestWithReferer(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) // POST the token back in the header. r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil) @@ -258,7 +247,7 @@ func TestWithReferer(t *testing.T) { r.Header.Set("Referer", "http://www.gorillatoolkit.org/") rr = httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -268,6 +257,7 @@ func TestWithReferer(t *testing.T) { // TestFormField tests that a token in the form field takes precedence over a // token in the HTTP header. +// TODO(matt): Finish this test. func TestFormField(t *testing.T) { } diff --git a/helpers.go b/helpers.go index 8aa0854..b94262c 100644 --- a/helpers.go +++ b/helpers.go @@ -9,26 +9,26 @@ import ( "net/http" "net/url" - "github.com/zenazn/goji/web" + "golang.org/x/net/context" ) // Token returns a masked CSRF token ready for passing into HTML template or // a JSON response body. An empty token will be returned if the middleware // has not been applied (which will fail subsequent validation). -func Token(c web.C, r *http.Request) string { - if maskedToken, ok := c.Env[tokenKey].(string); ok { +func Token(ctx context.Context, r *http.Request) string { + if maskedToken, ok := ctx.Value(tokenKey).(string); ok { return maskedToken } return "" } -// FailureReason makes CSRF validation errors available in Goji's request +// FailureReason makes CSRF validation errors available in the request // context. // This is useful when you want to log the cause of the error or report it to // client. -func FailureReason(c web.C, r *http.Request) error { - if err, ok := c.Env[errorKey].(error); ok { +func FailureReason(ctx context.Context, r *http.Request) error { + if err, ok := ctx.Value(errorKey).(error); ok { return err } @@ -46,9 +46,9 @@ func FailureReason(c web.C, r *http.Request) error { // // ... becomes: // // -func TemplateField(c web.C, r *http.Request) template.HTML { +func TemplateField(ctx context.Context, r *http.Request) template.HTML { fragment := fmt.Sprintf(``, - c.Env[formKey], Token(c, r)) + ctx.Value(formKey), Token(ctx, r)) return template.HTML(fragment) } @@ -60,7 +60,7 @@ func TemplateField(c web.C, r *http.Request) template.HTML { // token and returning them together as a 64-byte slice. This effectively // randomises the token on a per-request basis without breaking multiple browser // tabs/windows. -func mask(realToken []byte, c *web.C, r *http.Request) string { +func mask(realToken []byte, r *http.Request) string { otp, err := generateRandomBytes(tokenLength) if err != nil { return "" @@ -180,7 +180,7 @@ func contains(vals []string, s string) bool { return false } -// envError stores a CSRF error in the request context. -func envError(c *web.C, err error) { - c.Env[errorKey] = err +// setEnvError stores a CSRF error in the request context. +func setEnvError(ctx context.Context, err error) { + ctx = context.WithValue(ctx, errorKey, err) } diff --git a/helpers_test.go b/helpers_test.go index f04505a..d353a84 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -14,7 +14,10 @@ import ( "testing" "text/template" - "github.com/zenazn/goji/web" + "goji.io/pat" + "golang.org/x/net/context" + + "goji.io" ) var testTemplate = ` @@ -31,18 +34,18 @@ var testTemplateField = `` // Test that our form helpers correctly inject a token into the response body. func TestFormToken(t *testing.T) { - s := web.New() - s.Use(Protect(testKey)) + m := goji.NewMux() + m.UseC(Protect(testKey)) // Make the token available outside of the handler for comparison. var token string - s.Get("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - token = Token(c, r) + m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + token = Token(ctx, r) t := template.Must((template.New("base").Parse(testTemplate))) t.Execute(w, map[string]interface{}{ - TemplateTag: TemplateField(c, r), + TemplateTag: TemplateField(ctx, r), }) - })) + }) r, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -50,7 +53,7 @@ func TestFormToken(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -68,18 +71,18 @@ func TestFormToken(t *testing.T) { // Test that we can extract a CSRF token from a multipart form. func TestMultipartFormToken(t *testing.T) { - s := web.New() - s.Use(Protect(testKey)) + m := goji.NewMux() + m.UseC(Protect(testKey)) // Make the token available outside of the handler for comparison. var token string - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - token = Token(c, r) + m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + token = Token(ctx, r) t := template.Must((template.New("base").Parse(testTemplate))) t.Execute(w, map[string]interface{}{ - TemplateTag: TemplateField(c, r), + TemplateTag: TemplateField(ctx, r), }) - })) + }) r, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -87,7 +90,7 @@ func TestMultipartFormToken(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) // Set up our multipart form var b bytes.Buffer @@ -112,7 +115,7 @@ func TestMultipartFormToken(t *testing.T) { setCookie(rr, r) rr = httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("middleware failed to pass to the next handler: got %v want %v", @@ -134,7 +137,7 @@ func TestMaskUnmaskTokens(t *testing.T) { t.Fatal(err) } - issued := mask(realToken, nil, nil) + issued := mask(realToken, nil) decoded, err := base64.StdEncoding.DecodeString(issued) if err != nil { t.Fatal(err) @@ -224,19 +227,16 @@ func TestGenerateRandomBytes(t *testing.T) { } func TestTemplateField(t *testing.T) { - s := web.New() - CSRF := Protect( - testKey, - FieldName(testFieldName), - ) - s.Use(CSRF) + m := goji.NewMux() + CSRF := Protect(testKey, FieldName(testFieldName)) + m.UseC(CSRF) var token string var customTemplateField string - s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { - token = Token(c, r) - customTemplateField = string(TemplateField(c, r)) - })) + m.HandleFuncC(pat.New("/"), func(ctx context.Context, w http.ResponseWriter, r *http.Request) { + token = Token(ctx, r) + customTemplateField = string(TemplateField(ctx, r)) + }) r, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -244,7 +244,7 @@ func TestTemplateField(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) expectedTemplateField := fmt.Sprintf(testTemplateField, testFieldName, token) diff --git a/options.go b/options.go index 319a0e2..2d65f38 100644 --- a/options.go +++ b/options.go @@ -1,10 +1,6 @@ package csrf -import ( - "net/http" - - "github.com/zenazn/goji/web" -) +import "goji.io" // Option describes a functional option for configuring the CSRF handler. type Option func(*csrf) error @@ -70,7 +66,7 @@ func HttpOnly(h bool) Option { // // Note that a custom error handler can also access the csrf.Failure(c, r) // function to retrieve the CSRF validation reason from Goji's request context. -func ErrorHandler(h web.Handler) Option { +func ErrorHandler(h goji.Handler) Option { return func(cs *csrf) error { cs.opts.ErrorHandler = h return nil @@ -117,7 +113,7 @@ func setStore(s store) Option { // parseOptions parses the supplied options functions and returns a configured // csrf handler. -func parseOptions(h http.Handler, opts ...Option) *csrf { +func parseOptions(h goji.Handler, opts ...Option) *csrf { // Set the handler to call after processing. cs := &csrf{ h: h, diff --git a/options_test.go b/options_test.go index 4d27332..c9bcbd7 100644 --- a/options_test.go +++ b/options_test.go @@ -1,16 +1,15 @@ package csrf import ( - "net/http" "reflect" "testing" - "github.com/zenazn/goji/web" + "goji.io" ) // Tests that options functions are applied to the middleware. func TestOptions(t *testing.T) { - var h http.Handler + var h goji.Handler age := 86400 domain := "goji.io" @@ -28,7 +27,7 @@ func TestOptions(t *testing.T) { Secure(false), RequestHeader(header), FieldName(field), - ErrorHandler(web.HandlerFunc(errorHandler)), + ErrorHandler(goji.HandlerFunc(errorHandler)), CookieName(name), } diff --git a/store.go b/store.go index 621aef9..96145df 100644 --- a/store.go +++ b/store.go @@ -5,13 +5,12 @@ import ( "time" "github.com/gorilla/securecookie" - "github.com/zenazn/goji/web" ) // store represents the session storage used for CSRF tokens. type store interface { // Get returns the real CSRF token from the store. - Get(c *web.C, r *http.Request) ([]byte, error) + Get(r *http.Request) ([]byte, error) // Save stores the real CSRF token in the store and writes a // cookie to the http.ResponseWriter. // For non-cookie stores, the cookie should contain a unique (256 bit) ID @@ -33,7 +32,7 @@ type cookieStore struct { // Get retrieves a CSRF token from the session cookie. It returns an empty token // if decoding fails (e.g. HMAC validation fails or the named cookie doesn't exist). -func (cs *cookieStore) Get(c *web.C, r *http.Request) ([]byte, error) { +func (cs *cookieStore) Get(r *http.Request) ([]byte, error) { // Retrieve the cookie from the request cookie, err := r.Cookie(cs.name) if err != nil { diff --git a/store_test.go b/store_test.go index a708562..c4f6981 100644 --- a/store_test.go +++ b/store_test.go @@ -7,8 +7,11 @@ import ( "net/http/httptest" "testing" + "goji.io" + + "goji.io/pat" + "github.com/gorilla/securecookie" - "github.com/zenazn/goji/web" ) // Check Store implementations @@ -19,7 +22,7 @@ type brokenSaveStore struct { store } -func (bs *brokenSaveStore) Get(*web.C, *http.Request) ([]byte, error) { +func (bs *brokenSaveStore) Get(*http.Request) ([]byte, error) { // Generate an invalid token so we can progress to our Save method return generateRandomBytes(24) } @@ -30,10 +33,10 @@ func (bs *brokenSaveStore) Save(realToken []byte, w http.ResponseWriter) error { // Tests for failure if the middleware can't save to the Store. func TestStoreCannotSave(t *testing.T) { - s := web.New() + m := goji.NewMux() bs := &brokenSaveStore{} - s.Use(Protect(testKey, setStore(bs))) - s.Get("/", testHandler) + m.UseC(Protect(testKey, setStore(bs))) + m.HandleFuncC(pat.Get("/"), testHandler) r, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -41,7 +44,7 @@ func TestStoreCannotSave(t *testing.T) { } rr := httptest.NewRecorder() - s.ServeHTTP(rr, r) + m.ServeHTTP(rr, r) if rr.Code != http.StatusForbidden { t.Fatalf("broken store did not set an error status: got %v want %v", @@ -72,7 +75,7 @@ func TestCookieDecode(t *testing.T) { // Set a fake cookie value so r.Cookie passes. r.Header.Set("Cookie", fmt.Sprintf("%s=%s", cookieName, "notacookie")) - _, err = st.Get(&web.C{}, r) + _, err = st.Get(r) if err == nil { t.Fatal("cookiestore did not report an invalid hashkey on decode") }