Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RecoverInterceptor as alternative to WithRecover #824

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ func WithHandlerOptions(options ...HandlerOption) HandlerOption {

// WithRecover adds an interceptor that recovers from panics. The supplied
// function receives the context, [Spec], request headers, and the recovered
// value (which may be nil). It must return an error to send back to the
// client. It may also log the panic, emit metrics, or execute other
// error-handling logic. Handler functions must be safe to call concurrently.
// value. It must return an error to send back to the client. It may also log
// the panic, emit metrics, or execute other error-handling logic. Handler
// functions must be safe to call concurrently.
//
// To preserve compatibility with [net/http]'s semantics, this interceptor
// doesn't handle panics with [http.ErrAbortHandler].
Expand All @@ -150,8 +150,17 @@ func WithHandlerOptions(options ...HandlerOption) HandlerOption {
// usually necessary to prevent crashes. Instead, it helps servers collect
// RPC-specific data during panics and send a more detailed error to
// clients.
//
// Deprecated: Use RecoverInterceptor to create an interceptor and
// WithInterceptors to register it via HandlerOption.
func WithRecover(handle func(context.Context, Spec, http.Header, any) error) HandlerOption {
return WithInterceptors(&recoverHandlerInterceptor{handle: handle})
return WithInterceptors(RecoverInterceptor(func(ctx context.Context, req AnyRequest, panicValue any) error {
//nolint:errorlint,goerr113 // net/http checks for ErrAbortHandler with ==, so we should too
if panicValue == http.ErrAbortHandler {
panic(panicValue) //nolint:forbidigo
}
return handle(ctx, req.Spec(), req.Header(), panicValue)
}))
}

// WithRequireConnectProtocolHeader configures the Handler to require requests
Expand Down
288 changes: 266 additions & 22 deletions recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,292 @@ package connect
import (
"context"
"net/http"
"sync/atomic"
)

// recoverHandlerInterceptor lets handlers trap panics, perform side effects
// (like emitting logs or metrics), and present a friendlier error message to
// clients.
type recoverHandlerInterceptor struct {
Interceptor
// RecoverInterceptor is an interceptor that recovers from panics. The
// supplied function receives the context and request details.
//
// For streaming RPCs, req.Any() may return nil. It will always be nil
// for client-streaming or bidi-streaming RPCs, since there could be
// zero or even multiple request messages for such RPCs. For
// server-streaming RPCs, it will be nil if the panic occurred before
// the request message was received, which can happen if a panic occurs
// in an interceptor before the RPC handler method is invoked.
//
// Similarly, for streaming RPCs, req.Header() may return nil. This
// could happen in clients when the panic that is recovered occurs
// before the stream is actually created and before request headers are
// even allocated.
//
// Applications will generally want to add this interceptor first, which
// means it will actually be the last to handle any results from the
// RPC handler. This allows for recovering from the panics not only in
// the handler but also in any other interceptors.
//
// The recovered value will never be nil. If panic was called with a nil
// value, the recovered value will be a *[runtime.PanicNilError]. It must
// return an error to send back to the client. If it returns nil, an
// *Error with a code of CodeInternal will ne synthesized. The function
// may also log the panic, emit metrics, or execute other error-handling
// logic. The function must be safe to call concurrently.
//
// By default, handlers don't recover from panics. Because the standard
// library's [http.Server] recovers from panics by default, this option
// isn't usually necessary to prevent crashes. Instead, it helps servers
// collect RPC-specific data during panics and send a more detailed error
// to clients.
//
// Unlike [WithRecover], this interceptor does not do anything special with
// [http.ErrAbortHandler], so the handle function may be called with that as
// the panic value.
//
// Also unlike [WithRecover], which can only be used with handlers, this
// interceptor can be used with clients, to recover from any panics caused
// by bugs in the interceptor chain. For streaming RPCs, this will recover
// from panics that happen in calls to send or receive messages on the
// stream or to close the stream.
func RecoverInterceptor(handle func(ctx context.Context, req AnyRequest, panicValue any) error) Interceptor {
return &recoverHandlerInterceptor{handle: handle}
}

handle func(context.Context, Spec, http.Header, any) error
type recoverHandlerInterceptor struct {
handle func(context.Context, AnyRequest, any) error
}

func (i *recoverHandlerInterceptor) WrapUnary(next UnaryFunc) UnaryFunc {
return func(ctx context.Context, req AnyRequest) (_ AnyResponse, retErr error) {
if req.Spec().IsClient {
return next(ctx, req)
}
defer func() {
if r := recover(); r != nil {
// net/http checks for ErrAbortHandler with ==, so we should too.
if r == http.ErrAbortHandler { //nolint:errorlint,goerr113
panic(r) //nolint:forbidigo
retErr = i.handle(ctx, req, r)
if retErr == nil {
retErr = errorf(CodeInternal, "handler panicked; but recover handler returned non-nil error")
}
retErr = i.handle(ctx, req.Spec(), req.Header(), r)
}
}()
res, err := next(ctx, req)
return res, err
return next(ctx, req)
}
}

func (i *recoverHandlerInterceptor) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandlerFunc {
return func(ctx context.Context, conn StreamingHandlerConn) (retErr error) {
var streamConn *recoverStreamingHandlerConn
if conn.Spec().StreamType == StreamTypeServer {
// There will be exactly one request. So we try to capture it
// so we can provide it to the recover handle func.
streamConn = &recoverStreamingHandlerConn{StreamingHandlerConn: conn}
conn = streamConn
}

defer func() {
if r := recover(); r != nil {
// net/http checks for ErrAbortHandler with ==, so we should too.
if r == http.ErrAbortHandler { //nolint:errorlint,goerr113
panic(r) //nolint:forbidigo
if panicVal := recover(); panicVal != nil {
var msg any
if streamConn != nil {
if msgPtr := streamConn.req.Load(); msgPtr != nil {
msg = *msgPtr
}
}
retErr = i.handle(ctx, &recoverStreamRequest{conn, msg}, panicVal)
if retErr == nil {
retErr = errorf(CodeInternal, "handler panicked; but recover handler returned non-nil error")
}
}
}()
return next(ctx, conn)
}
}

func (i *recoverHandlerInterceptor) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc {
return func(ctx context.Context, spec Spec) (conn StreamingClientConn) {
defer func() {
if panicVal := recover(); panicVal != nil {
err := i.handle(ctx, emptyRequest(spec), panicVal)
if err == nil {
err = errorf(CodeInternal, "call panicked; but recover handler returned non-nil error")
}
retErr = i.handle(ctx, conn.Spec(), conn.RequestHeader(), r)
conn = &errStreamingClientConn{spec, err}
}
}()
err := next(ctx, conn)
return err
conn = next(ctx, spec)
return &recoverStreamingClientConn{
StreamingClientConn: conn,
ctx: ctx,
handle: i.handle,
}
}
}

type recoverStreamRequest struct {
StreamingHandlerConn
msg any
}

func (r *recoverStreamRequest) Any() any {
return r.msg
}

func (r *recoverStreamRequest) Header() http.Header {
return r.RequestHeader()
}

func (r *recoverStreamRequest) HTTPMethod() string {
return http.MethodPost // streams always use POST
}

func (r *recoverStreamRequest) internalOnly() {
}

func (r *recoverStreamRequest) setRequestMethod(_ string) {
// only invoked internally for unary RPCs; safe to ignore
}

type recoverStreamingHandlerConn struct {
StreamingHandlerConn
req atomic.Pointer[any]
}

func (r *recoverStreamingHandlerConn) Receive(msg any) error {
err := r.StreamingHandlerConn.Receive(msg)
if err == nil {
// Note: The framework instantiates msg, passes it to
// this method, and then returns it to the application.
// It is possible that the application could mutate the
// value, so what we provide to the recover handler would
// then differ from the message actually received. But
// this is no different than if the RPC handler mutated
// the request message for a unary RPC and interceptors
// later examined it via Request.Any. So we tolerate the
// possibility for server-stream requests, too.
r.req.Store(&msg)
}
return err
}

type emptyRequest Spec

func (e emptyRequest) Any() any {
return nil
}

func (e emptyRequest) Spec() Spec {
return Spec(e)
}

func (e emptyRequest) Peer() Peer {
return Peer{}
}

func (e emptyRequest) Header() http.Header {
return nil
}

func (e emptyRequest) HTTPMethod() string {
return http.MethodPost
}

func (e emptyRequest) internalOnly() {
}

func (e emptyRequest) setRequestMethod(_ string) {
// only invoked internally for unary RPCs; safe to ignore
}

type errStreamingClientConn struct {
spec Spec
err error
}

func (e *errStreamingClientConn) Spec() Spec {
return e.spec
}

func (e *errStreamingClientConn) Peer() Peer {
return Peer{}
}

func (e *errStreamingClientConn) Send(_ any) error {
return e.err
}

func (e *errStreamingClientConn) RequestHeader() http.Header {
// Clients can add headers before calling Send, so this must be mutable/non-nil.
return http.Header{} // TODO: memoize so we never allocate more than one?
}

func (e *errStreamingClientConn) CloseRequest() error {
return e.err
}

func (e *errStreamingClientConn) Receive(_ any) error {
return e.err
}

func (e *errStreamingClientConn) ResponseHeader() http.Header {
return nil
}

func (e *errStreamingClientConn) ResponseTrailer() http.Header {
return nil
}

func (e *errStreamingClientConn) CloseResponse() error {
return e.err
}

type recoverStreamingClientConn struct {
StreamingClientConn

//nolint:containedctx // must memoize the stream context to pass to recover handler
ctx context.Context
handle func(context.Context, AnyRequest, any) error
req atomic.Pointer[any]
}

func (r *recoverStreamingClientConn) Send(msg any) error {
if r.Spec().StreamType == StreamTypeServer {
// Capture the request message for server-streaming RPCs.
r.req.Store(&msg)
}
return r.invoke(func() error {
return r.StreamingClientConn.Send(msg)
})
}

func (r *recoverStreamingClientConn) RequestHeader() http.Header {
if header := r.StreamingClientConn.RequestHeader(); header != nil {
return header
}
// Clients can add headers before calling Send, so this must be mutable/non-nil.
// We do this not to recover from a panic but in the hopes of preventing panics in the caller.
return http.Header{} // TODO: memoize so we never allocate more than one?
}

func (r *recoverStreamingClientConn) CloseRequest() error {
return r.invoke(r.StreamingClientConn.CloseRequest)
}

func (r *recoverStreamingClientConn) Receive(msg any) error {
return r.invoke(func() error {
return r.StreamingClientConn.Receive(msg)
})
}

func (r *recoverStreamingClientConn) CloseResponse() error {
return r.invoke(r.StreamingClientConn.CloseResponse)
}

func (r *recoverStreamingClientConn) invoke(action func() error) (retErr error) {
defer func() {
if panicVal := recover(); panicVal != nil {
var msg any
if msgPtr := r.req.Load(); msgPtr != nil {
msg = *msgPtr
}
retErr = r.handle(r.ctx, &recoverStreamRequest{r, msg}, panicVal)
if retErr == nil {
retErr = errorf(CodeInternal, "call panicked; but recover handler returned non-nil error")
}
}
}()
return action()
}
Loading