Skip to content

Commit

Permalink
[feature] Added custom fieldname support
Browse files Browse the repository at this point in the history
  • Loading branch information
molivier authored and elithrar committed Nov 30, 2015
1 parent e7faaa5 commit d639a61
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
3 changes: 3 additions & 0 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const tokenLength = 32
// Context/session keys & prefixes
const (
tokenKey string = "goji.csrf.Token"
formKey string = "goji.csrf.Form"
errorKey string = "goji.csrf.Error"
cookieName string = "_goji_csrf"
errorPrefix string = "goji/csrf: "
Expand Down Expand Up @@ -203,6 +204,8 @@ func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Save the masked token to the request context
cs.c.Env[tokenKey] = mask(realToken, cs.c, r)
// Save the field name to the request context
cs.c.Env[formKey] = cs.opts.FieldName

// HTTP methods not defined as idempotent ("safe") under RFC7231 require
// inspection.
Expand Down
2 changes: 1 addition & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func FailureReason(c web.C, r *http.Request) error {
//
func TemplateField(c web.C, r *http.Request) template.HTML {
fragment := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
fieldName, Token(c, r))
c.Env[formKey], Token(c, r))

return template.HTML(fragment)
}
Expand Down
34 changes: 34 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"mime/multipart"
"net/http"
Expand All @@ -25,6 +26,8 @@ var testTemplate = `
</body>
</html>
`
var testFieldName = "custom_csrf_field_name"
var testTemplateField = `<input type="hidden" name="%s" value="%s">`

// Test that our form helpers correctly inject a token into the response body.
func TestFormToken(t *testing.T) {
Expand Down Expand Up @@ -219,3 +222,34 @@ func TestGenerateRandomBytes(t *testing.T) {
t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b))
}
}

func TestTemplateField(t *testing.T) {
s := web.New()
CSRF := Protect(
testKey,
FieldName(testFieldName),
)
s.Use(CSRF)

var token string
var customTemplateField string
s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
token = Token(c, r)
customTemplateField = string(TemplateField(c, r))
}))

r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
s.ServeHTTP(rr, r)

expectedTemplateField := fmt.Sprintf(testTemplateField, testFieldName, token)

if customTemplateField != expectedTemplateField {
t.Fatalf("templateField not set correctly: got %v want %v",
customTemplateField, expectedTemplateField)
}
}

0 comments on commit d639a61

Please sign in to comment.