Skip to content

feat: add error handler with more configuration #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 177 additions & 12 deletions oapi_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package nethttpmiddleware

import (
"context"
"errors"
"fmt"
"log"
Expand All @@ -21,8 +22,58 @@ import (
)

// ErrorHandler is called when there is an error in validation
//
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
//
// Deprecated: it's recommended you migrate to the ErrorHandlerWithOpts, as it provides more control over how to handle an error that occurs, including giving direct access to the `error` itself. There are no plans to remove this method.
type ErrorHandler func(w http.ResponseWriter, message string, statusCode int)

// ErrorHandlerWithOpts is called when there is an error in validation, with more information about the `error` that occurred and which request is currently being processed.
//
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
//
// NOTE that this should ideally be used instead of ErrorHandler
type ErrorHandlerWithOpts func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts ErrorHandlerOpts)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a context.Context here - I thought it'd make more sense as a function parameter (and the first) than in the opts - thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamietanna — As it has been a while since I was working on the project whose use-case caused me to write a PR the specifics faded from memory, and I commented because I was vaguely thinking there would two different contexts. However, in reviewing your reply and your code I seems I was just not thinking it through clearly enough and now see that your PR is probably the better and more idiomatic approach.

So, all good.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no worries! No, I very much appreciate your review - I think if used with i.e. Gin, you may want the gin.Context and the context.Context, but in this case for pure net/http, we won't have that exposed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jamietanna — Hey, though. Super excited to see you address this as I am sure I will need it in the future.


// ErrorHandlerOpts contains additional options that are passed to the `ErrorHandlerWithOpts` function in the case of an error being returned by the middleware
type ErrorHandlerOpts struct {
// Error is the underlying error that triggered this error handler to be executed.
//
// Known error types:
//
// - `*openapi3filter.SecurityRequirementsError` - if the `AuthenticationFunc` has failed to authenticate the request
// - `*openapi3filter.RequestError` - if a bad request has been made
//
// Additionally, if you have set `openapi3filter.Options#MultiError`:
//
// - `openapi3.MultiError` (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
Error error

// StatusCode indicates the HTTP Status Code that the OpenAPI validation middleware _suggests_ is returned to the user.
//
// NOTE that this is very much a suggestion, and can be overridden if you believe you have a better approach.
StatusCode int

// MatchedRoute is the underlying path that this request is being matched against.
//
// This is the route according to the OpenAPI validation middleware, and can be used in addition to/instead of the `http.Request`
//
// NOTE that this will be nil if there is no matched route (i.e. a request has been sent to an endpoint not in the OpenAPI spec)
MatchedRoute *ErrorHandlerOptsMatchedRoute
}

type ErrorHandlerOptsMatchedRoute struct {
// Route indicates the Route that this error is received by.
//
// This can be used in addition to/instead of the `http.Request`.
Route *routers.Route

// PathParams are any path parameters that are determined from the request.
//
// This can be used in addition to/instead of the `http.Request`.
PathParams map[string]string
}

// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
type MultiErrorHandler func(openapi3.MultiError) (int, error)

Expand All @@ -32,11 +83,21 @@ type Options struct {
Options openapi3filter.Options
// ErrorHandler is called when a validation error occurs.
//
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
//
// If not provided, `http.Error` will be called
ErrorHandler ErrorHandler

// ErrorHandlerWithOpts is called when there is an error in validation.
//
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
ErrorHandlerWithOpts ErrorHandlerWithOpts

// MultiErrorHandler is called when there is an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) returned by the `openapi3filter`.
//
// If not provided `defaultMultiErrorHandler` will be used.
//
// Does not get called when using `ErrorHandlerWithOpts`
MultiErrorHandler MultiErrorHandler
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
SilenceServersWarning bool
Expand All @@ -62,24 +123,96 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// validate request
if statusCode, err := validateRequest(r, router, options); err != nil {
if options != nil && options.ErrorHandler != nil {
options.ErrorHandler(w, err.Error(), statusCode)
} else {
http.Error(w, err.Error(), statusCode)
}
return
if options == nil {
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
} else if options.ErrorHandlerWithOpts != nil {
performRequestValidationForErrorHandlerWithOpts(next, w, r, router, options)
} else if options.ErrorHandler != nil {
performRequestValidationForErrorHandler(next, w, r, router, options, options.ErrorHandler)
} else {
// NOTE that this shouldn't happen, but let's be sure that we always end up calling the default error handler if no other handler is defined
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
}

// serve
next.ServeHTTP(w, r)
})
}

}

func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) {
// validate request
statusCode, err := validateRequest(r, router, options)
if err == nil {
// serve
next.ServeHTTP(w, r)
return
}

errorHandler(w, err.Error(), statusCode)
}

// Note that this is an inline-and-modified version of `validateRequest`, with a simplified control flow and providing full access to the `error` for the `ErrorHandlerWithOpts` function.
func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) {
// Find route
route, pathParams, err := router.FindRoute(r)
if err != nil {
errOpts := ErrorHandlerOpts{
// MatchedRoute will be nil, as we've not matched a route we know about
Error: err,
StatusCode: http.StatusNotFound,
}

options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
return
}

errOpts := ErrorHandlerOpts{
MatchedRoute: &ErrorHandlerOptsMatchedRoute{
Route: route,
PathParams: pathParams,
},
// other options will be added before executing
}

// Validate request
requestValidationInput := &openapi3filter.RequestValidationInput{
Request: r,
PathParams: pathParams,
Route: route,
}

if options != nil {
requestValidationInput.Options = &options.Options
}

err = openapi3filter.ValidateRequest(r.Context(), requestValidationInput)
if err == nil {
// it's a valid request, so serve it
next.ServeHTTP(w, r)
return
}

switch e := err.(type) {
case openapi3.MultiError:
errOpts.Error = e
errOpts.StatusCode = determineStatusCodeForMultiError(e)
case *openapi3filter.RequestError:
// We've got a bad request
errOpts.Error = e
errOpts.StatusCode = http.StatusBadRequest
case *openapi3filter.SecurityRequirementsError:
errOpts.Error = e
errOpts.StatusCode = http.StatusUnauthorized
default:
// This should never happen today, but if our upstream code changes,
// we don't want to crash the server, so handle the unexpected error.
// return http.StatusInternalServerError,
errOpts.Error = fmt.Errorf("error validating route: %w", e)
errOpts.StatusCode = http.StatusUnauthorized
}

options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
}

// validateRequest is called from the middleware above and actually does the work
// of validating a request.
func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) {
Expand Down Expand Up @@ -147,3 +280,35 @@ func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
func defaultMultiErrorHandler(me openapi3.MultiError) (int, error) {
return http.StatusBadRequest, me
}

func determineStatusCodeForMultiError(errs openapi3.MultiError) int {
numRequestErrors := 0
numSecurityRequirementsErrors := 0

for _, err := range errs {
switch err.(type) {
case *openapi3filter.RequestError:
numRequestErrors++
case *openapi3filter.SecurityRequirementsError:
numSecurityRequirementsErrors++
default:
// if we have /any/ unknown error types, we should suggest returning an HTTP 500 Internal Server Error
return http.StatusInternalServerError
}
}

if numRequestErrors > 0 && numSecurityRequirementsErrors > 0 {
return http.StatusInternalServerError
}

if numRequestErrors > 0 {
return http.StatusBadRequest
}

if numSecurityRequirementsErrors > 0 {
return http.StatusUnauthorized
}

// we shouldn't hit this, but to be safe, return an HTTP 500 Internal Server Error if we don't have any cases above
return http.StatusInternalServerError
}
Loading