diff --git a/oapi_validate.go b/oapi_validate.go index 5bbce40..db5f540 100644 --- a/oapi_validate.go +++ b/oapi_validate.go @@ -4,6 +4,7 @@ package nethttpmiddleware import ( + "context" "errors" "fmt" "log" @@ -19,14 +20,19 @@ import ( // ErrorHandler is called when there is an error in validation type ErrorHandler func(w http.ResponseWriter, message string, statusCode int) +// ErrorHandlerWithContext is called when there is an error in validation +// If both ErrorHandlerWithContext and ErrorHandler are set, ErrorHandlerWithContext will be used +type ErrorHandlerWithContext func(ctx context.Context, w http.ResponseWriter, message string, statusCode int) + // MultiErrorHandler is called when oapi returns a MultiError type type MultiErrorHandler func(openapi3.MultiError) (int, error) // Options to customize request validation, openapi3filter specified options will be passed through. type Options struct { - Options openapi3filter.Options - ErrorHandler ErrorHandler - MultiErrorHandler MultiErrorHandler + Options openapi3filter.Options + ErrorHandler ErrorHandler + ErrorHandlerWithContext ErrorHandlerWithContext + MultiErrorHandler MultiErrorHandler // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` SilenceServersWarning bool } @@ -52,8 +58,12 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) func // validate request if statusCode, err := validateRequest(r, router, options); err != nil { - if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(w, err.Error(), statusCode) + if options != nil && (options.ErrorHandlerWithContext != nil || options.ErrorHandler != nil) { + if options.ErrorHandlerWithContext != nil { + options.ErrorHandlerWithContext(r.Context(), w, err.Error(), statusCode) + } else { + options.ErrorHandler(w, err.Error(), statusCode) + } } else { http.Error(w, err.Error(), statusCode) }