diff --git a/csrf.go b/csrf.go index a4c8b0e..1295a40 100644 --- a/csrf.go +++ b/csrf.go @@ -18,6 +18,7 @@ const tokenLength = 32 // Context/session keys & prefixes const ( tokenKey string = "goji.csrf.Token" + formKey string = "goji.csrf.Form" errorKey string = "goji.csrf.Error" cookieName string = "_goji_csrf" errorPrefix string = "goji/csrf: " @@ -203,6 +204,8 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Save the masked token to the request context cs.c.Env[tokenKey] = mask(realToken, cs.c, r) + // Save the field name to the request context + cs.c.Env[formKey] = cs.opts.FieldName // HTTP methods not defined as idempotent ("safe") under RFC7231 require // inspection. diff --git a/helpers.go b/helpers.go index 30a98cc..8aa0854 100644 --- a/helpers.go +++ b/helpers.go @@ -48,7 +48,7 @@ func FailureReason(c web.C, r *http.Request) error { // func TemplateField(c web.C, r *http.Request) template.HTML { fragment := fmt.Sprintf(``, - fieldName, Token(c, r)) + c.Env[formKey], Token(c, r)) return template.HTML(fragment) } diff --git a/helpers_test.go b/helpers_test.go index 80e6e9a..f04505a 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "encoding/base64" + "fmt" "io" "mime/multipart" "net/http" @@ -25,6 +26,8 @@ var testTemplate = ` ` +var testFieldName = "custom_csrf_field_name" +var testTemplateField = `` // Test that our form helpers correctly inject a token into the response body. func TestFormToken(t *testing.T) { @@ -219,3 +222,34 @@ func TestGenerateRandomBytes(t *testing.T) { t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b)) } } + +func TestTemplateField(t *testing.T) { + s := web.New() + CSRF := Protect( + testKey, + FieldName(testFieldName), + ) + s.Use(CSRF) + + var token string + var customTemplateField string + s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { + token = Token(c, r) + customTemplateField = string(TemplateField(c, r)) + })) + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, r) + + expectedTemplateField := fmt.Sprintf(testTemplateField, testFieldName, token) + + if customTemplateField != expectedTemplateField { + t.Fatalf("templateField not set correctly: got %v want %v", + customTemplateField, expectedTemplateField) + } +}