diff --git a/context.go b/context.go new file mode 100644 index 0000000..94bd40f --- /dev/null +++ b/context.go @@ -0,0 +1,19 @@ +package ginmiddleware + +import ( + "context" + + "github.com/gin-gonic/gin" +) + +func getRequestContext( + c *gin.Context, + options *Options, +) context.Context { + requestContext := context.WithValue(context.Background(), GinContextKey, c) + if options != nil { + requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) + } + + return requestContext +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..4d02a1d --- /dev/null +++ b/error.go @@ -0,0 +1,35 @@ +package ginmiddleware + +import ( + "errors" + "net/http" + + "github.com/getkin/kin-openapi/routers" + "github.com/gin-gonic/gin" +) + +func handleValidationError( + c *gin.Context, + err error, + options *Options, + generalStatusCode int, +) { + var errorHandler ErrorHandler + // if an error handler is provided, use that + if options != nil && options.ErrorHandler != nil { + errorHandler = options.ErrorHandler + } else { + errorHandler = func(c *gin.Context, message string, statusCode int) { + c.AbortWithStatusJSON(statusCode, gin.H{"error": message}) + } + } + + if errors.Is(err, routers.ErrPathNotFound) { + errorHandler(c, err.Error(), http.StatusNotFound) + } else { + errorHandler(c, err.Error(), generalStatusCode) + } + + // in case the handler didn't internally call Abort, stop the chain + c.Abort() +} diff --git a/oapi_validate.go b/oapi_validate_request.go similarity index 50% rename from oapi_validate.go rename to oapi_validate_request.go index dce1061..99b2a40 100644 --- a/oapi_validate.go +++ b/oapi_validate_request.go @@ -15,7 +15,6 @@ package ginmiddleware import ( - "context" "errors" "fmt" "log" @@ -30,11 +29,6 @@ import ( "github.com/gin-gonic/gin" ) -const ( - GinContextKey = "oapi-codegen/gin-context" - UserDataKey = "oapi-codegen/user-data" -) - // OapiValidatorFromYamlFile creates a validator middleware from a YAML file path func OapiValidatorFromYamlFile(path string) (gin.HandlerFunc, error) { data, err := os.ReadFile(path) @@ -57,24 +51,6 @@ func OapiRequestValidator(swagger *openapi3.T) gin.HandlerFunc { return OapiRequestValidatorWithOptions(swagger, nil) } -// ErrorHandler is called when there is an error in validation -type ErrorHandler func(c *gin.Context, message string, statusCode int) - -// MultiErrorHandler is called when oapi returns a MultiError type -type MultiErrorHandler func(openapi3.MultiError) error - -// Options to customize request validation. These are passed through to -// openapi3filter. -type Options struct { - ErrorHandler ErrorHandler - Options openapi3filter.Options - ParamDecoder openapi3filter.ContentParameterDecoder - UserData interface{} - MultiErrorHandler MultiErrorHandler - // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` - SilenceServersWarning bool -} - // OapiRequestValidatorWithOptions creates a validator from a swagger object, with validation options func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc { if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) { @@ -88,22 +64,7 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin. return func(c *gin.Context) { err := ValidateRequestFromContext(c, router, options) if err != nil { - // using errors.Is did not work - if options != nil && options.ErrorHandler != nil && err.Error() == routers.ErrPathNotFound.Error() { - options.ErrorHandler(c, err.Error(), http.StatusNotFound) - // in case the handler didn't internally call Abort, stop the chain - c.Abort() - } else if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(c, err.Error(), http.StatusBadRequest) - // in case the handler didn't internally call Abort, stop the chain - c.Abort() - } else if err.Error() == routers.ErrPathNotFound.Error() { - // note: i am not sure if this is the best way to handle this - c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()}) - } else { - // note: i am not sure if this is the best way to handle this - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } + handleValidationError(c, err, options, http.StatusBadRequest) } c.Next() } @@ -112,38 +73,14 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin. // ValidateRequestFromContext is called from the middleware above and actually does the work // of validating a request. func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *Options) error { - req := c.Request - route, pathParams, err := router.FindRoute(req) - - // We failed to find a matching route for the request. + validationInput, err := getRequestValidationInput(c.Request, router, options) if err != nil { - switch e := err.(type) { - case *routers.RouteError: - // We've got a bad request, the path requested doesn't match - // either server, or path, or something. - return errors.New(e.Reason) - default: - // This should never happen today, but if our upstream code changes, - // we don't want to crash the server, so handle the unexpected error. - return fmt.Errorf("error validating route: %s", err.Error()) - } + return fmt.Errorf("error getting request validation input from route: %w", err) } - validationInput := &openapi3filter.RequestValidationInput{ - Request: req, - PathParams: pathParams, - Route: route, - } - - // Pass the gin context into the request validator, so that any callbacks + // Pass the gin context into the response validator, so that any callbacks // which it invokes make it available. - requestContext := context.WithValue(context.Background(), GinContextKey, c) //nolint:staticcheck - - if options != nil { - validationInput.Options = &options.Options - validationInput.ParamDecoder = options.ParamDecoder - requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) //nolint:staticcheck - } + requestContext := getRequestContext(c, options) err = openapi3filter.ValidateRequest(requestContext, validationInput) if err != nil { @@ -170,42 +107,3 @@ func ValidateRequestFromContext(c *gin.Context, router routers.Router, options * } return nil } - -// GetGinContext gets the echo context from within requests. It returns -// nil if not found or wrong type. -func GetGinContext(c context.Context) *gin.Context { - iface := c.Value(GinContextKey) - if iface == nil { - return nil - } - ginCtx, ok := iface.(*gin.Context) - if !ok { - return nil - } - return ginCtx -} - -func GetUserData(c context.Context) interface{} { - return c.Value(UserDataKey) -} - -// attempt to get the MultiErrorHandler from the options. If it is not set, -// return a default handler -func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler { - if options == nil { - return defaultMultiErrorHandler - } - - if options.MultiErrorHandler == nil { - return defaultMultiErrorHandler - } - - return options.MultiErrorHandler -} - -// defaultMultiErrorHandler returns a StatusBadRequest (400) and a list -// of all of the errors. This method is called if there are no other -// methods defined on the options. -func defaultMultiErrorHandler(me openapi3.MultiError) error { - return fmt.Errorf("multiple errors encountered: %s", me) -} diff --git a/oapi_validate_test.go b/oapi_validate_request_test.go similarity index 89% rename from oapi_validate_test.go rename to oapi_validate_request_test.go index 1b5fe81..feb88ef 100644 --- a/oapi_validate_test.go +++ b/oapi_validate_request_test.go @@ -15,16 +15,12 @@ package ginmiddleware import ( - "bytes" "context" _ "embed" - "encoding/json" "errors" "fmt" "io" "net/http" - "net/http/httptest" - "net/url" "testing" "github.com/getkin/kin-openapi/openapi3" @@ -34,57 +30,11 @@ import ( "github.com/stretchr/testify/require" ) -//go:embed test_spec.yaml -var testSchema []byte - -func doGet(t *testing.T, handler http.Handler, rawURL string) *httptest.ResponseRecorder { - u, err := url.Parse(rawURL) - if err != nil { - t.Fatalf("Invalid url: %s", rawURL) - } - - r, err := http.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - t.Fatalf("Could not construct a request: %s", rawURL) - } - r.Header.Set("accept", "application/json") - r.Header.Set("host", u.Host) - - tt := httptest.NewRecorder() - - handler.ServeHTTP(tt, r) - - return tt -} - -func doPost(t *testing.T, handler http.Handler, rawURL string, jsonBody interface{}) *httptest.ResponseRecorder { - u, err := url.Parse(rawURL) - if err != nil { - t.Fatalf("Invalid url: %s", rawURL) - } - - body, err := json.Marshal(jsonBody) - if err != nil { - t.Fatalf("Could not marshal request body: %v", err) - } - - r, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) - if err != nil { - t.Fatalf("Could not construct a request for URL %s: %v", rawURL, err) - } - r.Header.Set("accept", "application/json") - r.Header.Set("content-type", "application/json") - r.Header.Set("host", u.Host) - - tt := httptest.NewRecorder() - - handler.ServeHTTP(tt, r) - - return tt -} +//go:embed test_request_spec.yaml +var testRequestSchema []byte func TestOapiRequestValidator(t *testing.T) { - swagger, err := openapi3.NewLoader().LoadFromData(testSchema) + swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema) require.NoError(t, err, "Error initializing swagger") // Create a new echo router @@ -232,7 +182,7 @@ func TestOapiRequestValidator(t *testing.T) { } func TestOapiRequestValidatorWithOptionsMultiError(t *testing.T) { - swagger, err := openapi3.NewLoader().LoadFromData(testSchema) + swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema) require.NoError(t, err, "Error initializing swagger") g := gin.New() @@ -335,7 +285,7 @@ func TestOapiRequestValidatorWithOptionsMultiError(t *testing.T) { } func TestOapiRequestValidatorWithOptionsMultiErrorAndCustomHandler(t *testing.T) { - swagger, err := openapi3.NewLoader().LoadFromData(testSchema) + swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema) require.NoError(t, err, "Error initializing swagger") g := gin.New() diff --git a/oapi_validate_response.go b/oapi_validate_response.go new file mode 100644 index 0000000..b5649da --- /dev/null +++ b/oapi_validate_response.go @@ -0,0 +1,143 @@ +// Copyright 2021 DeepMap, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ginmiddleware + +import ( + "bytes" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" + "github.com/getkin/kin-openapi/routers/gorillamux" + "github.com/gin-gonic/gin" +) + +// OapiResponseValidatorFromYamlFile creates a validator middleware from a YAML file path +func OapiResponseValidatorFromYamlFile(path string) (gin.HandlerFunc, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("error reading %s: %s", path, err) + } + + swagger, err := openapi3.NewLoader().LoadFromData(data) + if err != nil { + return nil, fmt.Errorf("error parsing %s as Swagger YAML: %s", + path, err) + } + return OapiRequestValidator(swagger), nil +} + +// OapiResponseValidator is an gin middleware function which validates outgoing HTTP responses +// to make sure that they conform to the given OAPI 3.0 specification. When +// OAPI validation fails on the request, we return an HTTP/500 with error message +func OapiResponseValidator(swagger *openapi3.T) gin.HandlerFunc { + return OapiResponseValidatorWithOptions(swagger, nil) +} + +// OapiResponseValidatorWithOptions creates a validator from a swagger object, with validation options +func OapiResponseValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc { + if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) { + log.Println("WARN: OapiResponseValidatorWithOptions called with an OpenAPI spec that has `Servers` set. This may lead to an HTTP 400 with `no matching operation was found` when sending a valid request, as the validator performs `Host` header validation. If you're expecting `Host` header validation, you can silence this warning by setting `Options.SilenceServersWarning = true`. See https://github.com/deepmap/oapi-codegen/issues/882 for more information.") + } + + router, err := gorillamux.NewRouter(swagger) + if err != nil { + panic(err) + } + return func(c *gin.Context) { + err := ValidateResponseFromContext(c, router, options) + if err != nil { + handleValidationError(c, err, options, http.StatusInternalServerError) + } + + // in case an error was encountered before Next() was called, call it here + c.Next() + } +} + +// ValidateResponseFromContext is called from the middleware above and actually does the work +// of validating a response. +func ValidateResponseFromContext(c *gin.Context, router routers.Router, options *Options) error { + reqValidationInput, err := getRequestValidationInput(c.Request, router, options) + if err != nil { + return fmt.Errorf("error getting request validation input from route: %w", err) + } + + // Pass the gin context into the response validator, so that any callbacks + // which it invokes make it available. + requestContext := getRequestContext(c, options) + + // wrap the response writer in a bodyWriter so we can capture the response body + bw := newResponseInterceptor(c.Writer) + c.Writer = bw + + // Call the next handler in the chain, which will actually process the request + c.Next() + + // capture the response status and body + status := c.Writer.Status() + body := io.NopCloser(bytes.NewReader(bw.body.Bytes())) + + rspValidationInput := &openapi3filter.ResponseValidationInput{ + RequestValidationInput: reqValidationInput, + Status: status, + Header: c.Writer.Header(), + Body: body, + } + + if options != nil { + rspValidationInput.Options = &options.Options + } + + err = openapi3filter.ValidateResponse(requestContext, rspValidationInput) + if err != nil { + // restore the original response writer + c.Writer = bw.ResponseWriter + + me := openapi3.MultiError{} + if errors.As(err, &me) { + errFunc := getMultiErrorHandlerFromOptions(options) + return errFunc(me) + } + + switch e := err.(type) { + case *openapi3filter.ResponseError: + // We've got a bad request + // Split up the verbose error by lines and return the first one + // openapi errors seem to be multi-line with a decent message on the first + errorLines := strings.Split(e.Error(), "\n") + return fmt.Errorf("error in openapi3filter.ResponseError: %s", errorLines[0]) + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + return fmt.Errorf("error validating response: %w", err) + } + } + + // the response is valid, so write the captured response body to the original response writer + _, err = bw.ResponseWriter.Write(bw.body.Bytes()) + if err != nil { + return fmt.Errorf("error writing response body: %w", err) + } + + return nil +} diff --git a/oapi_validate_response_test.go b/oapi_validate_response_test.go new file mode 100644 index 0000000..e70888a --- /dev/null +++ b/oapi_validate_response_test.go @@ -0,0 +1,309 @@ +// Copyright 2021 DeepMap, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ginmiddleware + +import ( + _ "embed" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//go:embed test_response_spec.yaml +var testResponseSchema []byte + +func TestOapiResponseValidator(t *testing.T) { + gin.SetMode(gin.TestMode) + + swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema) + require.NoError(t, err, "Error initializing swagger") + + // Create a new gin router + g := gin.New() + + options := Options{ + ErrorHandler: func(c *gin.Context, message string, statusCode int) { + c.String(statusCode, "test: "+message) + }, + Options: openapi3filter.Options{ + AuthenticationFunc: openapi3filter.NoopAuthenticationFunc, + IncludeResponseStatus: true, + }, + UserData: "hi!", + } + + // Install our OpenApi based response validator + g.Use(OapiResponseValidatorWithOptions(swagger, &options)) + + // Test an incorrect route + { + rec := doGet(t, g, "http://deepmap.ai/incorrect") + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Contains(t, rec.Body.String(), "no matching operation was found") + } + + // Test wrong server + { + rec := doGet(t, g, "http://wrongserver.ai/resource") + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Contains(t, rec.Body.String(), "no matching operation was found") + } + + // getResource + testGetResource := func(t *testing.T, g *gin.Engine) { + var body string + var statusCode int + var contentType string + + // Install a request handler for /resource. + g.GET("/resource", func(c *gin.Context) { + c.Data(statusCode, contentType, []byte(body)) + }) + + tests := []struct { + name string + body string + status int + contentType string + wantRsp string + wantStatus int + }{ + // Let's send a good response, it should pass + { + name: "good response: good status: 200", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusOK, + contentType: "application/json", + wantRsp: `{"name":"Wilhelm Scream", "id":7}`, + wantStatus: http.StatusOK, + }, + // And for 404, it should pass + { + name: "good response: good status: 404", + body: `{"message": "couldn't find the resource"}`, + status: http.StatusNotFound, + contentType: "application/json", + wantRsp: `{"message": "couldn't find the resource"}`, + wantStatus: http.StatusNotFound, + }, + // And for 500, it should pass + { + name: "good response: good status: 500", + body: `{"message": "internal server error"}`, + status: http.StatusInternalServerError, + contentType: "application/json", + wantRsp: `{"message": "internal server error"}`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a bad response, it should fail + { + name: "bad response: good status", + body: `{"name": "Wilhelm Scream", "id": "not a number"}`, + status: http.StatusOK, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: response body doesn't match schema: Error at "/id": value must be an integer`, + wantStatus: http.StatusInternalServerError, + }, + // And for 404, it should fail + { + name: "bad response: missing required property: good status: 404", + body: `{}`, + status: http.StatusNotFound, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: response body doesn't match schema: Error at "/message": property "message" is missing`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a good response, but with a bad status, it should fail + { + name: "good response: bad status", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusCreated, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: status is not supported`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a good response, but with a bad content type, it should fail + { + name: "good response: bad content type", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusOK, + contentType: "text/plain", + wantRsp: `test: error in openapi3filter.ResponseError: response header Content-Type has unexpected value: "text/plain"`, + wantStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body = tt.body + statusCode = tt.status + contentType = tt.contentType + + rec := doGet(t, g, "http://deepmap.ai/resource") + assert.Equal(t, tt.wantStatus, rec.Code) + if tt.wantStatus == http.StatusOK { + switch tt.contentType { + case "application/json": + assert.JSONEq(t, tt.wantRsp, rec.Body.String()) + default: + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + } else { + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + }) + } + } + + // createResource + testCreateResource := func(t *testing.T, g *gin.Engine) { + var body string + var statusCode int + var contentType string + + // Install a request handler for /resource. + g.POST("/resource", func(c *gin.Context) { + c.Data(statusCode, contentType, []byte(body)) + }) + + tests := []struct { + name string + body string + status int + contentType string + wantRsp string + wantStatus int + }{ + // Let's send a good response, it should pass + { + name: "good response: good status: 201", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusCreated, + contentType: "application/json", + wantRsp: `{"name":"Wilhelm Scream", "id":7}`, + wantStatus: http.StatusCreated, + }, + // Let's send a good response, but with a bad status, it should fail + { + name: "good response: bad status: 200", + body: `{"name": "Wilhelm Scream", "id": 7}`, + status: http.StatusOK, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: status is not supported`, + wantStatus: http.StatusInternalServerError, + }, + // Let's send a good response, with different content type, it should pass + { + name: "good response: good status: 504", + body: "Gateway Timeout", + status: http.StatusGatewayTimeout, + contentType: "text/plain", + wantRsp: "Gateway Timeout", + wantStatus: http.StatusGatewayTimeout, + }, + // Let's send a good response, but with a bad content type, it should fail + { + name: "good response: bad content type", + body: `{"message":"timed out waiting for upstream server to respond"}`, + status: http.StatusGatewayTimeout, + contentType: "application/json", + wantRsp: `test: error in openapi3filter.ResponseError: response header Content-Type has unexpected value: "application/json"`, + wantStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body = tt.body + statusCode = tt.status + contentType = tt.contentType + + rec := doPost(t, g, "http://deepmap.ai/resource", gin.H{"name": "Wilhelm Scream"}) + assert.Equal(t, tt.wantStatus, rec.Code) + if tt.wantStatus == http.StatusCreated { + switch tt.contentType { + case "application/json": + assert.JSONEq(t, tt.wantRsp, rec.Body.String()) + default: + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + } else { + assert.Equal(t, tt.wantRsp, rec.Body.String()) + } + }) + } + } + + tests := []struct { + name string + operationID string + }{ + { + name: "GET /resource", + operationID: "getResource", + }, + { + name: "POST /resource", + operationID: "createResource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + switch tt.operationID { + case "getResource": + testGetResource(t, g) + case "createResource": + testCreateResource(t, g) + } + }) + } +} + +func TestOapiResponseValidatorNoOptions(t *testing.T) { + swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema) + require.NoError(t, err, "Error initializing swagger") + + mw := OapiResponseValidator(swagger) + assert.NotNil(t, mw, "Response validator is nil") +} + +func TestOapiResponseValidatorFromYamlFile(t *testing.T) { + // Test that we can load a response validator from a yaml file + { + mw, err := OapiResponseValidatorFromYamlFile("test_response_spec.yaml") + assert.NoError(t, err, "Error initializing response validator") + assert.NotNil(t, mw, "Response validator is nil") + } + + // Test that we get an error when the file does not exist + { + mw, err := OapiResponseValidatorFromYamlFile("nonexistent.yaml") + assert.Error(t, err, "Expected error initializing response validator") + assert.Nil(t, mw, "Response validator is not nil") + } + + // Test that we get an error when the file is not a valid yaml file + { + mw, err := OapiResponseValidatorFromYamlFile("README.md") + assert.Error(t, err, "Expected error initializing response validator") + assert.Nil(t, mw, "Response validator is not nil") + } +} diff --git a/request_doers_test.go b/request_doers_test.go new file mode 100644 index 0000000..b8b08e4 --- /dev/null +++ b/request_doers_test.go @@ -0,0 +1,56 @@ +package ginmiddleware + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func doGet(t *testing.T, handler http.Handler, rawURL string) *httptest.ResponseRecorder { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + r, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + t.Fatalf("Could not construct a request: %s", rawURL) + } + r.Header.Set("accept", "application/json") + r.Header.Set("host", u.Host) + + tt := httptest.NewRecorder() + + handler.ServeHTTP(tt, r) + + return tt +} + +func doPost(t *testing.T, handler http.Handler, rawURL string, jsonBody interface{}) *httptest.ResponseRecorder { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + body, err := json.Marshal(jsonBody) + if err != nil { + t.Fatalf("Could not marshal request body: %v", err) + } + + r, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + t.Fatalf("Could not construct a request for URL %s: %v", rawURL, err) + } + r.Header.Set("accept", "application/json") + r.Header.Set("content-type", "application/json") + r.Header.Set("host", u.Host) + + tt := httptest.NewRecorder() + + handler.ServeHTTP(tt, r) + + return tt +} diff --git a/response_interceptor.go b/response_interceptor.go new file mode 100644 index 0000000..558bea6 --- /dev/null +++ b/response_interceptor.go @@ -0,0 +1,25 @@ +package ginmiddleware + +import ( + "bytes" + + "github.com/gin-gonic/gin" +) + +type responseInterceptor struct { + gin.ResponseWriter + body *bytes.Buffer +} + +var _ gin.ResponseWriter = &responseInterceptor{} + +func newResponseInterceptor(w gin.ResponseWriter) *responseInterceptor { + return &responseInterceptor{ + ResponseWriter: w, + body: bytes.NewBufferString(""), + } +} + +func (w *responseInterceptor) Write(b []byte) (int, error) { + return w.body.Write(b) +} diff --git a/route.go b/route.go new file mode 100644 index 0000000..eae4f35 --- /dev/null +++ b/route.go @@ -0,0 +1,44 @@ +package ginmiddleware + +import ( + "fmt" + "net/http" + + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/getkin/kin-openapi/routers" +) + +func getRequestValidationInput( + req *http.Request, + router routers.Router, + options *Options, +) (*openapi3filter.RequestValidationInput, error) { + route, pathParams, err := router.FindRoute(req) + + // We failed to find a matching route for the request. + if err != nil { + switch e := err.(type) { + case *routers.RouteError: + // We've got a bad request, the path requested doesn't match + // either server, or path, or something. + return nil, fmt.Errorf("error validating route: %w", e) + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + return nil, fmt.Errorf("error validating route: %s", err.Error()) + } + } + + reqValidationInput := openapi3filter.RequestValidationInput{ + Request: req, + PathParams: pathParams, + Route: route, + } + + if options != nil { + reqValidationInput.Options = &options.Options + reqValidationInput.ParamDecoder = options.ParamDecoder + } + + return &reqValidationInput, nil +} diff --git a/test_spec.yaml b/test_request_spec.yaml similarity index 100% rename from test_spec.yaml rename to test_request_spec.yaml diff --git a/test_response_spec.yaml b/test_response_spec.yaml new file mode 100644 index 0000000..fb650f4 --- /dev/null +++ b/test_response_spec.yaml @@ -0,0 +1,77 @@ +openapi: '3.0.0' +info: + version: 1.0.0 + title: TestServer +servers: + - url: http://deepmap.ai/ +paths: + /resource: + get: + operationId: getResource + parameters: + - name: id + in: query + schema: + type: integer + responses: + '200': + description: success + content: + application/json: + schema: + required: + - name + - id + properties: + name: + type: string + id: + type: integer + '404': + description: not found + content: + application/json: + schema: + required: + - message + properties: + message: + type: string + '500': + description: internal server error + content: + application/json: + schema: + properties: + message: + type: string + post: + operationId: createResource + responses: + '201': + description: created + content: + application/json: + schema: + required: + - name + - id + properties: + name: + type: string + id: + type: integer + '504': + description: gateway timeout + content: + text/plain: + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + properties: + name: + type: string diff --git a/types.go b/types.go new file mode 100644 index 0000000..c13b706 --- /dev/null +++ b/types.go @@ -0,0 +1,33 @@ +package ginmiddleware + +import ( + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/gin-gonic/gin" +) + +type GinContextKeyType string +type UserDataKeyType string + +const ( + GinContextKey GinContextKeyType = "oapi-codegen/gin-context" + UserDataKey UserDataKeyType = "oapi-codegen/user-data" +) + +// ErrorHandler is called when there is an error in validation +type ErrorHandler func(c *gin.Context, message string, statusCode int) + +// MultiErrorHandler is called when oapi returns a MultiError type +type MultiErrorHandler func(openapi3.MultiError) error + +// Options to customize request validation. These are passed through to +// openapi3filter. +type Options struct { + ErrorHandler ErrorHandler + Options openapi3filter.Options + ParamDecoder openapi3filter.ContentParameterDecoder + UserData interface{} + MultiErrorHandler MultiErrorHandler + // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` + SilenceServersWarning bool +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..702efee --- /dev/null +++ b/utils.go @@ -0,0 +1,47 @@ +package ginmiddleware + +import ( + "context" + "fmt" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/gin-gonic/gin" +) + +// GetGinContext gets the echo context from within requests. It returns +// nil if not found or wrong type. +func GetGinContext(c context.Context) *gin.Context { + iface := c.Value(GinContextKey) + if iface == nil { + return nil + } + ginCtx, ok := iface.(*gin.Context) + if !ok { + return nil + } + return ginCtx +} + +func GetUserData(c context.Context) interface{} { + return c.Value(UserDataKey) +} + +// attempt to get the MultiErrorHandler from the options. If it is not set, +// return a default handler +func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler { + if options == nil { + return defaultMultiErrorHandler + } + + if options.MultiErrorHandler == nil { + return defaultMultiErrorHandler + } + + return options.MultiErrorHandler +} + +// defaultMultiErrorHandler returns a list of all of the errors. +// This method is called if there are no other methods defined on the options. +func defaultMultiErrorHandler(me openapi3.MultiError) error { + return fmt.Errorf("multiple errors encountered: %s", me) +}