Skip to content

Commit

Permalink
[feature] multipart form support.
Browse files Browse the repository at this point in the history
  • Loading branch information
elithrar committed Nov 30, 2015
1 parent a3f3add commit e7faaa5
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
2 changes: 0 additions & 2 deletions csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,6 @@ func TestWithReferer(t *testing.T) {
rr = httptest.NewRecorder()
s.ServeHTTP(rr, r)

t.Log(r.Header)

if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
rr.Code, http.StatusOK)
Expand Down
12 changes: 12 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,23 @@ func unmask(issued []byte) []byte {
// requestToken returns the issued token (pad + masked token) from the HTTP POST
// body or HTTP header. It will return nil if the token fails to decode.
func (cs *csrf) requestToken(r *http.Request) []byte {
// 1. Check the HTTP header first.
issued := r.Header.Get(cs.opts.RequestHeader)

// 2. Fall back to the POST (form) value.
if issued == "" {
issued = r.PostFormValue(cs.opts.FieldName)
}

// 3. Finally, fall back to the multipart form (if set).
if issued == "" && r.MultipartForm != nil {
vals := r.MultipartForm.Value[cs.opts.FieldName]

if len(vals) > 0 {
issued = vals[0]
}
}

// Decode the "issued" (pad + masked) token sent in the request. Return a
// nil byte slice on a decoding error (this will fail upstream).
decoded, err := base64.StdEncoding.DecodeString(issued)
Expand Down
59 changes: 59 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/base64"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -62,6 +63,64 @@ func TestFormToken(t *testing.T) {
}
}

// Test that we can extract a CSRF token from a multipart form.
func TestMultipartFormToken(t *testing.T) {
s := web.New()
s.Use(Protect(testKey))

// Make the token available outside of the handler for comparison.
var token string
s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {
token = Token(c, r)
t := template.Must((template.New("base").Parse(testTemplate)))
t.Execute(w, map[string]interface{}{
TemplateTag: TemplateField(c, r),
})
}))

r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
s.ServeHTTP(rr, r)

// Set up our multipart form
var b bytes.Buffer
mp := multipart.NewWriter(&b)
wr, err := mp.CreateFormField(fieldName)
if err != nil {
t.Fatal(err)
}

wr.Write([]byte(token))
mp.Close()

r, err = http.NewRequest("POST", "/", &b)
if err != nil {
t.Fatal(err)
}

// Add the multipart header.
r.Header.Set("Content-Type", mp.FormDataContentType())

// Send back the issued cookie.
setCookie(rr, r)

rr = httptest.NewRecorder()
s.ServeHTTP(rr, r)

if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
rr.Code, http.StatusOK)
}

if body := rr.Body.String(); !strings.Contains(body, token) {
t.Fatalf("token not in response body: got %v want %v", body, token)
}
}

// TestMaskUnmaskTokens tests that a token traversing the mask -> unmask process
// is correctly unmasked to the original 'real' token.
func TestMaskUnmaskTokens(t *testing.T) {
Expand Down

0 comments on commit e7faaa5

Please sign in to comment.