From 70293cb15ff6ef10221cc634bbdb735198574bc7 Mon Sep 17 00:00:00 2001 From: zookee Date: Fri, 23 Feb 2024 14:28:39 +0000 Subject: [PATCH 1/3] Rename to oapi_validate_request to distinguish for file for response Move shared code into types and utils files Add proper types for context keys to get rid of linter warning --- oapi_validate.go => oapi_validate_request.go | 66 +------------------ ...e_test.go => oapi_validate_request_test.go | 0 types.go | 33 ++++++++++ utils.go | 47 +++++++++++++ 4 files changed, 82 insertions(+), 64 deletions(-) rename oapi_validate.go => oapi_validate_request.go (74%) rename oapi_validate_test.go => oapi_validate_request_test.go (100%) create mode 100644 types.go create mode 100644 utils.go diff --git a/oapi_validate.go b/oapi_validate_request.go similarity index 74% rename from oapi_validate.go rename to oapi_validate_request.go index dce1061..0159b1d 100644 --- a/oapi_validate.go +++ b/oapi_validate_request.go @@ -30,11 +30,6 @@ import ( "github.com/gin-gonic/gin" ) -const ( - GinContextKey = "oapi-codegen/gin-context" - UserDataKey = "oapi-codegen/user-data" -) - // OapiValidatorFromYamlFile creates a validator middleware from a YAML file path func OapiValidatorFromYamlFile(path string) (gin.HandlerFunc, error) { data, err := os.ReadFile(path) @@ -57,24 +52,6 @@ func OapiRequestValidator(swagger *openapi3.T) gin.HandlerFunc { return OapiRequestValidatorWithOptions(swagger, nil) } -// ErrorHandler is called when there is an error in validation -type ErrorHandler func(c *gin.Context, message string, statusCode int) - -// MultiErrorHandler is called when oapi returns a MultiError type -type MultiErrorHandler func(openapi3.MultiError) error - -// Options to customize request validation. These are passed through to -// openapi3filter. -type Options struct { - ErrorHandler ErrorHandler - Options openapi3filter.Options - ParamDecoder openapi3filter.ContentParameterDecoder - UserData interface{} - MultiErrorHandler MultiErrorHandler - // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` - SilenceServersWarning bool -} - // OapiRequestValidatorWithOptions creates a validator from a swagger object, with validation options func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc { if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) { @@ -137,12 +114,12 @@ func ValidateRequestFromContext(c *gin.Context, router routers.Router, options * // Pass the gin context into the request validator, so that any callbacks // which it invokes make it available. - requestContext := context.WithValue(context.Background(), GinContextKey, c) //nolint:staticcheck + requestContext := context.WithValue(context.Background(), GinContextKey, c) if options != nil { validationInput.Options = &options.Options validationInput.ParamDecoder = options.ParamDecoder - requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) //nolint:staticcheck + requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) } err = openapi3filter.ValidateRequest(requestContext, validationInput) @@ -170,42 +147,3 @@ func ValidateRequestFromContext(c *gin.Context, router routers.Router, options * } return nil } - -// GetGinContext gets the echo context from within requests. It returns -// nil if not found or wrong type. -func GetGinContext(c context.Context) *gin.Context { - iface := c.Value(GinContextKey) - if iface == nil { - return nil - } - ginCtx, ok := iface.(*gin.Context) - if !ok { - return nil - } - return ginCtx -} - -func GetUserData(c context.Context) interface{} { - return c.Value(UserDataKey) -} - -// attempt to get the MultiErrorHandler from the options. If it is not set, -// return a default handler -func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler { - if options == nil { - return defaultMultiErrorHandler - } - - if options.MultiErrorHandler == nil { - return defaultMultiErrorHandler - } - - return options.MultiErrorHandler -} - -// defaultMultiErrorHandler returns a StatusBadRequest (400) and a list -// of all of the errors. This method is called if there are no other -// methods defined on the options. -func defaultMultiErrorHandler(me openapi3.MultiError) error { - return fmt.Errorf("multiple errors encountered: %s", me) -} diff --git a/oapi_validate_test.go b/oapi_validate_request_test.go similarity index 100% rename from oapi_validate_test.go rename to oapi_validate_request_test.go diff --git a/types.go b/types.go new file mode 100644 index 0000000..c13b706 --- /dev/null +++ b/types.go @@ -0,0 +1,33 @@ +package ginmiddleware + +import ( + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/gin-gonic/gin" +) + +type GinContextKeyType string +type UserDataKeyType string + +const ( + GinContextKey GinContextKeyType = "oapi-codegen/gin-context" + UserDataKey UserDataKeyType = "oapi-codegen/user-data" +) + +// ErrorHandler is called when there is an error in validation +type ErrorHandler func(c *gin.Context, message string, statusCode int) + +// MultiErrorHandler is called when oapi returns a MultiError type +type MultiErrorHandler func(openapi3.MultiError) error + +// Options to customize request validation. These are passed through to +// openapi3filter. +type Options struct { + ErrorHandler ErrorHandler + Options openapi3filter.Options + ParamDecoder openapi3filter.ContentParameterDecoder + UserData interface{} + MultiErrorHandler MultiErrorHandler + // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` + SilenceServersWarning bool +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..702efee --- /dev/null +++ b/utils.go @@ -0,0 +1,47 @@ +package ginmiddleware + +import ( + "context" + "fmt" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/gin-gonic/gin" +) + +// GetGinContext gets the echo context from within requests. It returns +// nil if not found or wrong type. +func GetGinContext(c context.Context) *gin.Context { + iface := c.Value(GinContextKey) + if iface == nil { + return nil + } + ginCtx, ok := iface.(*gin.Context) + if !ok { + return nil + } + return ginCtx +} + +func GetUserData(c context.Context) interface{} { + return c.Value(UserDataKey) +} + +// attempt to get the MultiErrorHandler from the options. If it is not set, +// return a default handler +func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler { + if options == nil { + return defaultMultiErrorHandler + } + + if options.MultiErrorHandler == nil { + return defaultMultiErrorHandler + } + + return options.MultiErrorHandler +} + +// defaultMultiErrorHandler returns a list of all of the errors. +// This method is called if there are no other methods defined on the options. +func defaultMultiErrorHandler(me openapi3.MultiError) error { + return fmt.Errorf("multiple errors encountered: %s", me) +} From f9f18ff668326c95f530d0a5b43bd99c76176bbf Mon Sep 17 00:00:00 2001 From: zookee Date: Fri, 1 Mar 2024 17:44:28 +0000 Subject: [PATCH 2/3] Add OAPI response validation middleware --- oapi_validate_request_test.go | 60 +----- oapi_validate_response.go | 183 ++++++++++++++++ oapi_validate_response_test.go | 263 +++++++++++++++++++++++ request_doers_test.go | 56 +++++ test_spec.yaml => test_request_spec.yaml | 0 test_response_spec.yaml | 77 +++++++ 6 files changed, 584 insertions(+), 55 deletions(-) create mode 100644 oapi_validate_response.go create mode 100644 oapi_validate_response_test.go create mode 100644 request_doers_test.go rename test_spec.yaml => test_request_spec.yaml (100%) create mode 100644 test_response_spec.yaml diff --git a/oapi_validate_request_test.go b/oapi_validate_request_test.go index 1b5fe81..feb88ef 100644 --- a/oapi_validate_request_test.go +++ b/oapi_validate_request_test.go @@ -15,16 +15,12 @@ package ginmiddleware import ( - "bytes" "context" _ "embed" - "encoding/json" "errors" "fmt" "io" "net/http" - "net/http/httptest" - "net/url" "testing" "github.com/getkin/kin-openapi/openapi3" @@ -34,57 +30,11 @@ import ( "github.com/stretchr/testify/require" ) -//go:embed test_spec.yaml -var testSchema []byte - -func doGet(t *testing.T, handler http.Handler, rawURL string) *httptest.ResponseRecorder { - u, err := url.Parse(rawURL) - if err != nil { - t.Fatalf("Invalid url: %s", rawURL) - } - - r, err := http.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - t.Fatalf("Could not construct a request: %s", rawURL) - } - r.Header.Set("accept", "application/json") - r.Header.Set("host", u.Host) - - tt := httptest.NewRecorder() - - handler.ServeHTTP(tt, r) - - return tt -} - -func doPost(t *testing.T, handler http.Handler, rawURL string, jsonBody interface{}) *httptest.ResponseRecorder { - u, err := url.Parse(rawURL) - if err != nil { - t.Fatalf("Invalid url: %s", rawURL) - } - - body, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Could not marshal request body: %v", err) - } - - r, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) - if err != nil { - t.Fatalf("Could not construct a request for URL %s: %v", rawURL, err) - } - r.Header.Set("accept", "application/json") - r.Header.Set("content-type", "application/json") - r.Header.Set("host", u.Host) - - tt := httptest.NewRecorder() - - handler.ServeHTTP(tt, r) - - return tt -} +//go:embed test_request_spec.yaml +var testRequestSchema []byte func TestOapiRequestValidator(t *testing.T) { - swagger, err := openapi3.NewLoader().LoadFromData(testSchema) + swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema) require.NoError(t, err, "Error initializing swagger") // Create a new echo router @@ -232,7 +182,7 @@ func TestOapiRequestValidator(t *testing.T) { } func TestOapiRequestValidatorWithOptionsMultiError(t *testing.T) { - swagger, err := openapi3.NewLoader().LoadFromData(testSchema) + swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema) require.NoError(t, err, "Error initializing swagger") g := gin.New() @@ -335,7 +285,7 @@ func TestOapiRequestValidatorWithOptionsMultiError(t *testing.T) { } func TestOapiRequestValidatorWithOptionsMultiErrorAndCustomHandler(t *testing.T) { - swagger, err := openapi3.NewLoader().LoadFromData(testSchema) + swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema) require.NoError(t, err, "Error initializing swagger") g := gin.New() diff --git a/oapi_validate_response.go b/oapi_validate_response.go new file mode 100644 index 0000000..5da09ea --- /dev/null +++ b/oapi_validate_response.go @@ -0,0 +1,183 @@ +// Copyright 2021 DeepMap, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ginmiddleware + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/gin-gonic/gin" +) + +// OapiResponseValidatorFromYamlFile creates a validator middleware from a YAML file path +func OapiResponseValidatorFromYamlFile(path string) (gin.HandlerFunc, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("error reading %s: %s", path, err) + } + + swagger, err := openapi3.NewLoader().LoadFromData(data) + if err != nil { + return nil, fmt.Errorf("error parsing %s as Swagger YAML: %s", + path, err) + } + return OapiRequestValidator(swagger), nil +} + +// OapiRequestValidator is an gin middleware function which validates incoming HTTP requests +// to make sure that they conform to the given OAPI 3.0 specification. When +// OAPI validation fails on the request, we return an HTTP/400 with error message +func OapiResponseValidator(swagger *openapi3.T) gin.HandlerFunc { + return OapiResponseValidatorWithOptions(swagger, nil) +} + +// OapiResponseValidatorWithOptions creates a validator from a swagger object, with validation options +func OapiResponseValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc { + if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) { + log.Println("WARN: OapiResponseValidatorWithOptions called with an OpenAPI spec that has `Servers` set. This may lead to an HTTP 400 with `no matching operation was found` when sending a valid request, as the validator performs `Host` header validation. If you're expecting `Host` header validation, you can silence this warning by setting `Options.SilenceServersWarning = true`. See https://github.com/deepmap/oapi-codegen/issues/882 for more information.") + } + + router, err := gorillamux.NewRouter(swagger) + if err != nil { + panic(err) + } + return func(c *gin.Context) { + err := ValidateResponseFromContext(c, router, options) + if err != nil { + if options != nil && options.ErrorHandler != nil { + options.ErrorHandler(c, err.Error(), http.StatusInternalServerError) + // in case the handler didn't internally call Abort, stop the chain + c.Abort() + } else { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + } + } +} + +type responseInterceptor struct { + gin.ResponseWriter + body *bytes.Buffer +} + +func (w *responseInterceptor) Write(b []byte) (int, error) { + return w.body.Write(b) +} + +// ValidateResponseFromContext is called from the middleware above and actually does the work +// of validating a response. +func ValidateResponseFromContext(c *gin.Context, router routers.Router, options *Options) error { + req := c.Request + route, pathParams, err := router.FindRoute(req) + + // We failed to find a matching route for the request. + if err != nil { + switch e := err.(type) { + case *routers.RouteError: + // We've got a bad request, the path requested doesn't match + // either server, or path, or something. + return errors.New(e.Reason) + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + return fmt.Errorf("error validating route: %s", err.Error()) + } + } + + reqValidationInput := &openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + } + + // Pass the gin context into the request validator, so that any callbacks + // which it invokes make it available. + requestContext := context.WithValue(context.Background(), GinContextKey, c) + + if options != nil { + reqValidationInput.Options = &options.Options + reqValidationInput.ParamDecoder = options.ParamDecoder + requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) + } + + // wrap the response writer in a bodyWriter so we can capture the response body + bw := &responseInterceptor{ResponseWriter: c.Writer, body: bytes.NewBufferString("")} + c.Writer = bw + + // Call the next handler in the chain, which will actually process the request + c.Next() + + // capture the response status and body + status := c.Writer.Status() + body := io.NopCloser(bytes.NewReader(bw.body.Bytes())) + + rspValidationInput := &openapi3filter.ResponseValidationInput{ + RequestValidationInput: reqValidationInput, + Status: status, + Header: c.Writer.Header(), + Body: body, + } + + if options != nil { + rspValidationInput.Options = &options.Options + } + + err = openapi3filter.ValidateResponse(requestContext, rspValidationInput) + + if err != nil { + // restore the original response writer + c.Writer = bw.ResponseWriter + + me := openapi3.MultiError{} + if errors.As(err, &me) { + errFunc := getMultiErrorHandlerFromOptions(options) + return errFunc(me) + } + + switch e := err.(type) { + case *openapi3filter.ResponseError: + // We've got a bad request + // Split up the verbose error by lines and return the first one + // openapi errors seem to be multi-line with a decent message on the first + errorLines := strings.Split(e.Error(), "\n") + return fmt.Errorf("error in openapi3filter.ResponseError: %s", errorLines[0]) + case *openapi3filter.SecurityRequirementsError: + return fmt.Errorf("error in openapi3filter.SecurityRequirementsError: %s", e.Error()) + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + return fmt.Errorf("error validating response: %w", err) + } + } + + // the response is valid, so write the captured response body to the original response writer + _, err = bw.ResponseWriter.Write(bw.body.Bytes()) + if err != nil { + return fmt.Errorf("error writing response body: %w", err) + } + + return nil +} diff --git a/oapi_validate_response_test.go b/oapi_validate_response_test.go new file mode 100644 index 0000000..66d8188 --- /dev/null +++ b/oapi_validate_response_test.go @@ -0,0 +1,263 @@ +// Copyright 2021 DeepMap, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ginmiddleware + +import ( + _ "embed" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//go:embed test_response_spec.yaml +var testResponseSchema []byte + +func TestOapiResponseValidator(t *testing.T) { + gin.SetMode(gin.TestMode) + + swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema) + require.NoError(t, err, "Error initializing swagger") + + // Create a new echo router + g := gin.New() + + // Set up an authenticator to check authenticated function. It will allow + // access to "someScope", but disallow others. + options := Options{ + ErrorHandler: func(c *gin.Context, message string, statusCode int) { + c.String(statusCode, "test: "+message) + }, + Options: openapi3filter.Options{ + AuthenticationFunc: openapi3filter.NoopAuthenticationFunc, + IncludeResponseStatus: true, + }, + UserData: "hi!", + } + + // Install our OpenApi based request validator + g.Use(OapiResponseValidatorWithOptions(swagger, &options)) + + tests := []struct { + name string + operationID string + }{ + { + name: "GET /resource", + operationID: "getResource", + }, + } + + // getResource + testGetResource := func(t *testing.T, g *gin.Engine) { + var body string + var statusCode int + var contentType string + + // Install a request handler for /resource. + g.GET("/resource", func(c *gin.Context) { + c.Data(statusCode, contentType, []byte(body)) + }) + + tests := []struct { + name string + body string + status int + contentType string + wantRsp string + wantStatus int + }{ + // Let's send a good response, it should pass + { + name: "good response: good status: 200", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusOK, + contentType: "application/json", + wantRsp: `{"name":"Wilhelm Scream", "id":7}`, + wantStatus: http.StatusOK, + }, + // And for 404, it should pass + { + name: "good response: good status: 404", + body: `{"message": "couldn't find the resource"}`, + status: http.StatusNotFound, + contentType: "application/json", + wantRsp: `{"message": "couldn't find the resource"}`, + wantStatus: http.StatusNotFound, + }, + // And for 500, it should pass + { + name: "good response: good status: 500", + body: `{"message": "internal server error"}`, + status: http.StatusInternalServerError, + contentType: "application/json", + wantRsp: `{"message": "internal server error"}`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a bad response, it should fail + { + name: "bad response: good status", + body: `{"name": "Wilhelm Scream", "id": "not a number"}`, + status: http.StatusOK, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: response body doesn't match schema: Error at "/id": value must be an integer`, + wantStatus: http.StatusInternalServerError, + }, + // And for 404, it should fail + { + name: "bad response: missing required property: good status: 404", + body: `{}`, + status: http.StatusNotFound, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: response body doesn't match schema: Error at "/message": property "message" is missing`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a good response, but with a bad status, it should fail + { + name: "good response: bad status", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusCreated, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: status is not supported`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a good response, but with a bad content type, it should fail + { + name: "good response: bad content type", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusOK, + contentType: "text/plain", + wantRsp: `test: error in openapi3filter.ResponseError: response header Content-Type has unexpected value: "text/plain"`, + wantStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body = tt.body + statusCode = tt.status + contentType = tt.contentType + + rec := doGet(t, g, "http://deepmap.ai/resource") + assert.Equal(t, tt.wantStatus, rec.Code) + if tt.wantStatus == http.StatusOK { + switch tt.contentType { + case "application/json": + assert.JSONEq(t, tt.wantRsp, rec.Body.String()) + default: + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + } else { + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + }) + } + } + + // createResource + testCreateResource := func(t *testing.T, g *gin.Engine) { + var body string + var statusCode int + var contentType string + + // Install a request handler for /resource. + g.POST("/resource", func(c *gin.Context) { + c.Data(statusCode, contentType, []byte(body)) + }) + + tests := []struct { + name string + body string + status int + contentType string + wantRsp string + wantStatus int + }{ + // Let's send a good response, it should pass + { + name: "good response: good status: 201", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusCreated, + contentType: "application/json", + wantRsp: `{"name":"Wilhelm Scream", "id":7}`, + wantStatus: http.StatusCreated, + }, + // Let's send a good response, but with a bad status, it should fail + { + name: "good response: bad status: 200", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusOK, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: status is not supported`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a good response, with different content type, it should pass + { + name: "good response: good status: 504", + body: "Gateway Timeout", + status: http.StatusGatewayTimeout, + contentType: "text/plain", + wantRsp: "Gateway Timeout", + wantStatus: http.StatusGatewayTimeout, + }, + // Let's send a good response, but with a bad content type, it should fail + { + name: "good response: bad content type", + body: `{"message":"timed out waiting for upstream server to respond"}`, + status: http.StatusGatewayTimeout, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: response header Content-Type has unexpected value: "application/json"`, + wantStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body = tt.body + statusCode = tt.status + contentType = tt.contentType + + rec := doPost(t, g, "http://deepmap.ai/resource", gin.H{"name": "Wilhelm Scream"}) + assert.Equal(t, tt.wantStatus, rec.Code) + if tt.wantStatus == http.StatusOK { + switch tt.contentType { + case "application/json": + assert.JSONEq(t, tt.wantRsp, rec.Body.String()) + default: + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + } else { + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + }) + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + switch tt.operationID { + case "getResource": + testGetResource(t, g) + case "createResource": + testCreateResource(t, g) + } + }) + } + +} diff --git a/request_doers_test.go b/request_doers_test.go new file mode 100644 index 0000000..b8b08e4 --- /dev/null +++ b/request_doers_test.go @@ -0,0 +1,56 @@ +package ginmiddleware + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func doGet(t *testing.T, handler http.Handler, rawURL string) *httptest.ResponseRecorder { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + r, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + t.Fatalf("Could not construct a request: %s", rawURL) + } + r.Header.Set("accept", "application/json") + r.Header.Set("host", u.Host) + + tt := httptest.NewRecorder() + + handler.ServeHTTP(tt, r) + + return tt +} + +func doPost(t *testing.T, handler http.Handler, rawURL string, jsonBody interface{}) *httptest.ResponseRecorder { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + body, err := json.Marshal(jsonBody) + if err != nil { + t.Fatalf("Could not marshal request body: %v", err) + } + + r, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + t.Fatalf("Could not construct a request for URL %s: %v", rawURL, err) + } + r.Header.Set("accept", "application/json") + r.Header.Set("content-type", "application/json") + r.Header.Set("host", u.Host) + + tt := httptest.NewRecorder() + + handler.ServeHTTP(tt, r) + + return tt +} diff --git a/test_spec.yaml b/test_request_spec.yaml similarity index 100% rename from test_spec.yaml rename to test_request_spec.yaml diff --git a/test_response_spec.yaml b/test_response_spec.yaml new file mode 100644 index 0000000..fb650f4 --- /dev/null +++ b/test_response_spec.yaml @@ -0,0 +1,77 @@ +openapi: '3.0.0' +info: + version: 1.0.0 + title: TestServer +servers: + - url: http://deepmap.ai/ +paths: + /resource: + get: + operationId: getResource + parameters: + - name: id + in: query + schema: + type: integer + responses: + '200': + description: success + content: + application/json: + schema: + required: + - name + - id + properties: + name: + type: string + id: + type: integer + '404': + description: not found + content: + application/json: + schema: + required: + - message + properties: + message: + type: string + '500': + description: internal server error + content: + application/json: + schema: + properties: + message: + type: string + post: + operationId: createResource + responses: + '201': + description: created + content: + application/json: + schema: + required: + - name + - id + properties: + name: + type: string + id: + type: integer + '504': + description: gateway timeout + content: + text/plain: + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + properties: + name: + type: string From 2577745d0063f78051bf36298d21c956d461f26c Mon Sep 17 00:00:00 2001 From: zookee Date: Fri, 15 Mar 2024 18:50:23 +0000 Subject: [PATCH 3/3] Consolidate shared code across middlewares Remove case for SecurityRequirementsError Improve test coverage --- context.go | 19 +++++++++ error.go | 35 +++++++++++++++++ oapi_validate_request.go | 50 +++-------------------- oapi_validate_response.go | 62 ++++++----------------------- oapi_validate_response_test.go | 72 ++++++++++++++++++++++++++++------ response_interceptor.go | 25 ++++++++++++ route.go | 44 +++++++++++++++++++++ 7 files changed, 198 insertions(+), 109 deletions(-) create mode 100644 context.go create mode 100644 error.go create mode 100644 response_interceptor.go create mode 100644 route.go diff --git a/context.go b/context.go new file mode 100644 index 0000000..94bd40f --- /dev/null +++ b/context.go @@ -0,0 +1,19 @@ +package ginmiddleware + +import ( + "context" + + "github.com/gin-gonic/gin" +) + +func getRequestContext( + c *gin.Context, + options *Options, +) context.Context { + requestContext := context.WithValue(context.Background(), GinContextKey, c) + if options != nil { + requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) + } + + return requestContext +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..4d02a1d --- /dev/null +++ b/error.go @@ -0,0 +1,35 @@ +package ginmiddleware + +import ( + "errors" + "net/http" + + "github.com/getkin/kin-openapi/routers" + "github.com/gin-gonic/gin" +) + +func handleValidationError( + c *gin.Context, + err error, + options *Options, + generalStatusCode int, +) { + var errorHandler ErrorHandler + // if an error handler is provided, use that + if options != nil && options.ErrorHandler != nil { + errorHandler = options.ErrorHandler + } else { + errorHandler = func(c *gin.Context, message string, statusCode int) { + c.AbortWithStatusJSON(statusCode, gin.H{"error": message}) + } + } + + if errors.Is(err, routers.ErrPathNotFound) { + errorHandler(c, err.Error(), http.StatusNotFound) + } else { + errorHandler(c, err.Error(), generalStatusCode) + } + + // in case the handler didn't internally call Abort, stop the chain + c.Abort() +} diff --git a/oapi_validate_request.go b/oapi_validate_request.go index 0159b1d..99b2a40 100644 --- a/oapi_validate_request.go +++ b/oapi_validate_request.go @@ -15,7 +15,6 @@ package ginmiddleware import ( - "context" "errors" "fmt" "log" @@ -65,22 +64,7 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin. return func(c *gin.Context) { err := ValidateRequestFromContext(c, router, options) if err != nil { - // using errors.Is did not work - if options != nil && options.ErrorHandler != nil && err.Error() == routers.ErrPathNotFound.Error() { - options.ErrorHandler(c, err.Error(), http.StatusNotFound) - // in case the handler didn't internally call Abort, stop the chain - c.Abort() - } else if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(c, err.Error(), http.StatusBadRequest) - // in case the handler didn't internally call Abort, stop the chain - c.Abort() - } else if err.Error() == routers.ErrPathNotFound.Error() { - // note: i am not sure if this is the best way to handle this - c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()}) - } else { - // note: i am not sure if this is the best way to handle this - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } + handleValidationError(c, err, options, http.StatusBadRequest) } c.Next() } @@ -89,38 +73,14 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin. // ValidateRequestFromContext is called from the middleware above and actually does the work // of validating a request. func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *Options) error { - req := c.Request - route, pathParams, err := router.FindRoute(req) - - // We failed to find a matching route for the request. + validationInput, err := getRequestValidationInput(c.Request, router, options) if err != nil { - switch e := err.(type) { - case *routers.RouteError: - // We've got a bad request, the path requested doesn't match - // either server, or path, or something. - return errors.New(e.Reason) - default: - // This should never happen today, but if our upstream code changes, - // we don't want to crash the server, so handle the unexpected error. - return fmt.Errorf("error validating route: %s", err.Error()) - } - } - - validationInput := &openapi3filter.RequestValidationInput{ - Request: req, - PathParams: pathParams, - Route: route, + return fmt.Errorf("error getting request validation input from route: %w", err) } - // Pass the gin context into the request validator, so that any callbacks + // Pass the gin context into the response validator, so that any callbacks // which it invokes make it available. - requestContext := context.WithValue(context.Background(), GinContextKey, c) - - if options != nil { - validationInput.Options = &options.Options - validationInput.ParamDecoder = options.ParamDecoder - requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) - } + requestContext := getRequestContext(c, options) err = openapi3filter.ValidateRequest(requestContext, validationInput) if err != nil { diff --git a/oapi_validate_response.go b/oapi_validate_response.go index 5da09ea..b5649da 100644 --- a/oapi_validate_response.go +++ b/oapi_validate_response.go @@ -16,7 +16,6 @@ package ginmiddleware import ( "bytes" - "context" "errors" "fmt" "io" @@ -47,9 +46,9 @@ func OapiResponseValidatorFromYamlFile(path string) (gin.HandlerFunc, error) { return OapiRequestValidator(swagger), nil } -// OapiRequestValidator is an gin middleware function which validates incoming HTTP requests +// OapiResponseValidator is an gin middleware function which validates outgoing HTTP responses // to make sure that they conform to the given OAPI 3.0 specification. When -// OAPI validation fails on the request, we return an HTTP/400 with error message +// OAPI validation fails on the request, we return an HTTP/500 with error message func OapiResponseValidator(swagger *openapi3.T) gin.HandlerFunc { return OapiResponseValidatorWithOptions(swagger, nil) } @@ -67,64 +66,28 @@ func OapiResponseValidatorWithOptions(swagger *openapi3.T, options *Options) gin return func(c *gin.Context) { err := ValidateResponseFromContext(c, router, options) if err != nil { - if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(c, err.Error(), http.StatusInternalServerError) - // in case the handler didn't internally call Abort, stop the chain - c.Abort() - } else { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } + handleValidationError(c, err, options, http.StatusInternalServerError) } - } -} -type responseInterceptor struct { - gin.ResponseWriter - body *bytes.Buffer -} - -func (w *responseInterceptor) Write(b []byte) (int, error) { - return w.body.Write(b) + // in case an error was encountered before Next() was called, call it here + c.Next() + } } // ValidateResponseFromContext is called from the middleware above and actually does the work // of validating a response. func ValidateResponseFromContext(c *gin.Context, router routers.Router, options *Options) error { - req := c.Request - route, pathParams, err := router.FindRoute(req) - - // We failed to find a matching route for the request. + reqValidationInput, err := getRequestValidationInput(c.Request, router, options) if err != nil { - switch e := err.(type) { - case *routers.RouteError: - // We've got a bad request, the path requested doesn't match - // either server, or path, or something. - return errors.New(e.Reason) - default: - // This should never happen today, but if our upstream code changes, - // we don't want to crash the server, so handle the unexpected error. - return fmt.Errorf("error validating route: %s", err.Error()) - } - } - - reqValidationInput := &openapi3filter.RequestValidationInput{ - Request: req, - PathParams: pathParams, - Route: route, + return fmt.Errorf("error getting request validation input from route: %w", err) } - // Pass the gin context into the request validator, so that any callbacks + // Pass the gin context into the response validator, so that any callbacks // which it invokes make it available. - requestContext := context.WithValue(context.Background(), GinContextKey, c) - - if options != nil { - reqValidationInput.Options = &options.Options - reqValidationInput.ParamDecoder = options.ParamDecoder - requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) - } + requestContext := getRequestContext(c, options) // wrap the response writer in a bodyWriter so we can capture the response body - bw := &responseInterceptor{ResponseWriter: c.Writer, body: bytes.NewBufferString("")} + bw := newResponseInterceptor(c.Writer) c.Writer = bw // Call the next handler in the chain, which will actually process the request @@ -146,7 +109,6 @@ func ValidateResponseFromContext(c *gin.Context, router routers.Router, options } err = openapi3filter.ValidateResponse(requestContext, rspValidationInput) - if err != nil { // restore the original response writer c.Writer = bw.ResponseWriter @@ -164,8 +126,6 @@ func ValidateResponseFromContext(c *gin.Context, router routers.Router, options // openapi errors seem to be multi-line with a decent message on the first errorLines := strings.Split(e.Error(), "\n") return fmt.Errorf("error in openapi3filter.ResponseError: %s", errorLines[0]) - case *openapi3filter.SecurityRequirementsError: - return fmt.Errorf("error in openapi3filter.SecurityRequirementsError: %s", e.Error()) default: // This should never happen today, but if our upstream code changes, // we don't want to crash the server, so handle the unexpected error. diff --git a/oapi_validate_response_test.go b/oapi_validate_response_test.go index 66d8188..e70888a 100644 --- a/oapi_validate_response_test.go +++ b/oapi_validate_response_test.go @@ -35,11 +35,9 @@ func TestOapiResponseValidator(t *testing.T) { swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema) require.NoError(t, err, "Error initializing swagger") - // Create a new echo router + // Create a new gin router g := gin.New() - // Set up an authenticator to check authenticated function. It will allow - // access to "someScope", but disallow others. options := Options{ ErrorHandler: func(c *gin.Context, message string, statusCode int) { c.String(statusCode, "test: "+message) @@ -51,17 +49,21 @@ func TestOapiResponseValidator(t *testing.T) { UserData: "hi!", } - // Install our OpenApi based request validator + // Install our OpenApi based response validator g.Use(OapiResponseValidatorWithOptions(swagger, &options)) - tests := []struct { - name string - operationID string - }{ - { - name: "GET /resource", - operationID: "getResource", - }, + // Test an incorrect route + { + rec := doGet(t, g, "http://deepmap.ai/incorrect") + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Contains(t, rec.Body.String(), "no matching operation was found") + } + + // Test wrong server + { + rec := doGet(t, g, "http://wrongserver.ai/resource") + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Contains(t, rec.Body.String(), "no matching operation was found") } // getResource @@ -235,7 +237,7 @@ func TestOapiResponseValidator(t *testing.T) { rec := doPost(t, g, "http://deepmap.ai/resource", gin.H{"name": "Wilhelm Scream"}) assert.Equal(t, tt.wantStatus, rec.Code) - if tt.wantStatus == http.StatusOK { + if tt.wantStatus == http.StatusCreated { switch tt.contentType { case "application/json": assert.JSONEq(t, tt.wantRsp, rec.Body.String()) @@ -249,6 +251,20 @@ func TestOapiResponseValidator(t *testing.T) { } } + tests := []struct { + name string + operationID string + }{ + { + name: "GET /resource", + operationID: "getResource", + }, + { + name: "POST /resource", + operationID: "createResource", + }, + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { switch tt.operationID { @@ -259,5 +275,35 @@ func TestOapiResponseValidator(t *testing.T) { } }) } +} +func TestOapiResponseValidatorNoOptions(t *testing.T) { + swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema) + require.NoError(t, err, "Error initializing swagger") + + mw := OapiResponseValidator(swagger) + assert.NotNil(t, mw, "Response validator is nil") +} + +func TestOapiResponseValidatorFromYamlFile(t *testing.T) { + // Test that we can load a response validator from a yaml file + { + mw, err := OapiResponseValidatorFromYamlFile("test_response_spec.yaml") + assert.NoError(t, err, "Error initializing response validator") + assert.NotNil(t, mw, "Response validator is nil") + } + + // Test that we get an error when the file does not exist + { + mw, err := OapiResponseValidatorFromYamlFile("nonexistent.yaml") + assert.Error(t, err, "Expected error initializing response validator") + assert.Nil(t, mw, "Response validator is not nil") + } + + // Test that we get an error when the file is not a valid yaml file + { + mw, err := OapiResponseValidatorFromYamlFile("README.md") + assert.Error(t, err, "Expected error initializing response validator") + assert.Nil(t, mw, "Response validator is not nil") + } } diff --git a/response_interceptor.go b/response_interceptor.go new file mode 100644 index 0000000..558bea6 --- /dev/null +++ b/response_interceptor.go @@ -0,0 +1,25 @@ +package ginmiddleware + +import ( + "bytes" + + "github.com/gin-gonic/gin" +) + +type responseInterceptor struct { + gin.ResponseWriter + body *bytes.Buffer +} + +var _ gin.ResponseWriter = &responseInterceptor{} + +func newResponseInterceptor(w gin.ResponseWriter) *responseInterceptor { + return &responseInterceptor{ + ResponseWriter: w, + body: bytes.NewBufferString(""), + } +} + +func (w *responseInterceptor) Write(b []byte) (int, error) { + return w.body.Write(b) +} diff --git a/route.go b/route.go new file mode 100644 index 0000000..eae4f35 --- /dev/null +++ b/route.go @@ -0,0 +1,44 @@ +package ginmiddleware + +import ( + "fmt" + "net/http" + + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" +) + +func getRequestValidationInput( + req *http.Request, + router routers.Router, + options *Options, +) (*openapi3filter.RequestValidationInput, error) { + route, pathParams, err := router.FindRoute(req) + + // We failed to find a matching route for the request. + if err != nil { + switch e := err.(type) { + case *routers.RouteError: + // We've got a bad request, the path requested doesn't match + // either server, or path, or something. + return nil, fmt.Errorf("error validating route: %w", e) + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + return nil, fmt.Errorf("error validating route: %s", err.Error()) + } + } + + reqValidationInput := openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + } + + if options != nil { + reqValidationInput.Options = &options.Options + reqValidationInput.ParamDecoder = options.ParamDecoder + } + + return &reqValidationInput, nil +}