diff --git a/request/decoder.go b/request/decoder.go index 54c6262..3452444 100644 --- a/request/decoder.go +++ b/request/decoder.go @@ -79,7 +79,10 @@ func decodeValidate(d *form.Decoder, v interface{}, p url.Values, in rest.ParamI func makeDecoder(in rest.ParamIn, formDecoder *form.Decoder, decoderFunc decoderFunc) valueDecoderFunc { return func(r *http.Request, v interface{}, validator rest.Validator) error { ct := r.Header.Get("Content-Type") - if in == rest.ParamInFormData && ct != "" && !strings.HasPrefix(ct, "multipart/form-data") && ct != "application/x-www-form-urlencoded" { + if in == rest.ParamInFormData && ct != "" && !compareContentType( + ct, + "multipart/form-data", + ) && !compareContentType(ct, "application/x-www-form-urlencoded") { return nil } @@ -147,7 +150,7 @@ func formDataToURLValues(r *http.Request) (url.Values, error) { return nil, nil } - if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { + if compareContentType(r.Header.Get("Content-Type"), "multipart/form-data") { err := r.ParseMultipartForm(defaultMaxMemory) if err != nil { return nil, err @@ -168,7 +171,7 @@ func queryToURLValues(r *http.Request) (url.Values, error) { } func formToURLValues(r *http.Request) (url.Values, error) { - if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { + if compareContentType(r.Header.Get("Content-Type"), "multipart/form-data") { err := r.ParseMultipartForm(defaultMaxMemory) if err != nil { return nil, err @@ -201,3 +204,15 @@ func contentTypeBodyToURLValues(r *http.Request) (url.Values, error) { r.Header.Get("Content-Type"): []string{string(b)}, }, nil } + +func compareContentType(contentType, expectedContentType string) bool { + ct := strings.TrimSpace(strings.ToLower(contentType)) + ect := strings.TrimSpace(strings.ToLower(expectedContentType)) + if len(ct) < len(ect) { + return false + } else if len(ct) > len(ect) && ct[len(ect)] == ';' { + ct = ct[:len(ect)] + } + + return ct == ect +} diff --git a/request/jsonbody.go b/request/jsonbody.go index 788cab3..5bc6b1f 100644 --- a/request/jsonbody.go +++ b/request/jsonbody.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "strings" "sync" "github.com/swaggest/rest" @@ -72,8 +71,11 @@ func checkJSONBodyContentType(contentType string, tolerateFormData bool) (ret bo return false, nil } - if len(contentType) < 16 || strings.ToLower(contentType[0:16]) != "application/json" { // allow 'application/json;charset=UTF-8' - if tolerateFormData && (contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data") { + if !compareContentType(contentType, "application/json") { // allow 'application/json;charset=UTF-8' + if tolerateFormData && (compareContentType( + contentType, + "application/x-www-form-urlencoded", + ) || compareContentType(contentType, "multipart/form-data")) { return true, nil } diff --git a/request/jsonbody_test.go b/request/jsonbody_test.go index 3380b62..56f9ccd 100644 --- a/request/jsonbody_test.go +++ b/request/jsonbody_test.go @@ -14,7 +14,8 @@ import ( func Test_decodeJSONBody(t *testing.T) { createBody := bytes.NewReader( - []byte(`{"amount": 123,"customerId": "248df4b7-aa70-47b8-a036-33ac447e668d","type": "withdraw"}`)) + []byte(`{"amount": 123,"customerId": "248df4b7-aa70-47b8-a036-33ac447e668d","type": "withdraw"}`), + ) createReq, err := http.NewRequest(http.MethodPost, "/US/order/348df4b7-aa70-47b8-a036-33ac447e668d", createBody) assert.NoError(t, err) @@ -30,9 +31,11 @@ func Test_decodeJSONBody(t *testing.T) { assert.Equal(t, "248df4b7-aa70-47b8-a036-33ac447e668d", i.CustomerID) assert.Equal(t, "withdraw", i.Type) - vl := rest.ValidatorFunc(func(_ rest.ParamIn, _ map[string]interface{}) error { - return nil - }) + vl := rest.ValidatorFunc( + func(_ rest.ParamIn, _ map[string]interface{}) error { + return nil + }, + ) i = Input{} _, err = createBody.Seek(0, io.SeekStart) @@ -90,9 +93,11 @@ func Test_decodeJSONBody_validateFailed(t *testing.T) { var i []int - vl := rest.ValidatorFunc(func(_ rest.ParamIn, _ map[string]interface{}) error { - return errors.New("failed") - }) + vl := rest.ValidatorFunc( + func(_ rest.ParamIn, _ map[string]interface{}) error { + return errors.New("failed") + }, + ) err = decodeJSONBody(readJSON, false)(req, &i, vl) assert.EqualError(t, err, "failed") @@ -100,22 +105,29 @@ func Test_decodeJSONBody_validateFailed(t *testing.T) { func Test_decodeJSONBody_tolerateFormData(t *testing.T) { createBody := bytes.NewReader( - []byte(`amount=123&customerId=248df4b7-aa70-47b8-a036-33ac447e668d&type=withdraw`)) - createReq, err := http.NewRequest(http.MethodPost, "/US/order/348df4b7-aa70-47b8-a036-33ac447e668d", createBody) - createReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - assert.NoError(t, err) - - type Input struct { - Amount int `json:"amount" formData:"amount"` - CustomerID string `json:"customerId" formData:"customerId"` - Type string `json:"type" formData:"type"` + []byte(`amount=123&customerId=248df4b7-aa70-47b8-a036-33ac447e668d&type=withdraw`), + ) + for _, h := range []string{ + "application/x-www-form-urlencoded", + "multipart/form-data", + "multipart/form-data; boundary=--bound--", + } { + createReq, err := http.NewRequest(http.MethodPost, "/US/order/348df4b7-aa70-47b8-a036-33ac447e668d", createBody) + createReq.Header.Set("Content-Type", h) + assert.NoError(t, err) + + type Input struct { + Amount int `json:"amount" formData:"amount"` + CustomerID string `json:"customerId" formData:"customerId"` + Type string `json:"type" formData:"type"` + } + + i := Input{} + assert.NoError(t, decodeJSONBody(readJSON, true)(createReq, &i, nil)) + assert.Empty(t, i.Amount) + assert.Empty(t, i.CustomerID) + assert.Empty(t, i.Type) } - - i := Input{} - assert.NoError(t, decodeJSONBody(readJSON, true)(createReq, &i, nil)) - assert.Empty(t, i.Amount) - assert.Empty(t, i.CustomerID) - assert.Empty(t, i.Type) } func Test_decodeJSONBody_charset(t *testing.T) {