From bc37f209bfef66d8ff84dc7d13dc066107eb3d9c Mon Sep 17 00:00:00 2001 From: RW Date: Wed, 8 Jan 2025 08:19:20 +0100 Subject: [PATCH] refactor(timeout): unify and enhance timeout middleware (#3275) * feat(timeout): unify and enhance timeout middleware - Combine classic context-based timeout with a Goroutine + channel approach - Support custom error list without additional parameters - Return fiber.ErrRequestTimeout for timeouts or listed errors * feat(timeout): unify and enhance timeout middleware - Combine classic context-based timeout with a Goroutine + channel approach - Support custom error list without additional parameters - Return fiber.ErrRequestTimeout for timeouts or listed errors * refactor(timeout): remove goroutine-based logic and improve documentation - Switch to a synchronous approach to avoid data races with fasthttp context - Enhance error handling for deadline and custom errors - Update comments for clarity and maintainability * refactor(timeout): add more test cases and handle zero duration case * refactor(timeout): add more test cases and handle zero duration case * refactor(timeout): add more test cases and handle zero duration case --------- Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- middleware/timeout/timeout.go | 57 +++++++--- middleware/timeout/timeout_test.go | 166 ++++++++++++++++++----------- 2 files changed, 147 insertions(+), 76 deletions(-) diff --git a/middleware/timeout/timeout.go b/middleware/timeout/timeout.go index a88f2e90b1..127fff8723 100644 --- a/middleware/timeout/timeout.go +++ b/middleware/timeout/timeout.go @@ -8,23 +8,52 @@ import ( "github.com/gofiber/fiber/v3" ) -// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response. -func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler { +// New enforces a timeout for each incoming request. If the timeout expires or +// any of the specified errors occur, fiber.ErrRequestTimeout is returned. +func New(h fiber.Handler, timeout time.Duration, tErrs ...error) fiber.Handler { return func(ctx fiber.Ctx) error { - timeoutContext, cancel := context.WithTimeout(ctx.Context(), t) + // If timeout <= 0, skip context.WithTimeout and run the handler as-is. + if timeout <= 0 { + return runHandler(ctx, h, tErrs) + } + + // Create a context with the specified timeout; any operation exceeding + // this deadline will be canceled automatically. + timeoutContext, cancel := context.WithTimeout(ctx.Context(), timeout) defer cancel() + + // Replace the default Fiber context with our timeout-bound context. ctx.SetContext(timeoutContext) - if err := h(ctx); err != nil { - if errors.Is(err, context.DeadlineExceeded) { - return fiber.ErrRequestTimeout - } - for i := range tErrs { - if errors.Is(err, tErrs[i]) { - return fiber.ErrRequestTimeout - } - } - return err + + // Run the handler and check for relevant errors. + err := runHandler(ctx, h, tErrs) + + // If the context actually timed out, return a timeout error. + if errors.Is(timeoutContext.Err(), context.DeadlineExceeded) { + return fiber.ErrRequestTimeout + } + return err + } +} + +// runHandler executes the handler and returns fiber.ErrRequestTimeout if it +// sees a deadline exceeded error or one of the custom "timeout-like" errors. +func runHandler(c fiber.Ctx, h fiber.Handler, tErrs []error) error { + // Execute the wrapped handler synchronously. + err := h(c) + // If the context has timed out, return a request timeout error. + if err != nil && (errors.Is(err, context.DeadlineExceeded) || isCustomError(err, tErrs)) { + return fiber.ErrRequestTimeout + } + return err +} + +// isCustomError checks whether err matches any error in errList using errors.Is. +func isCustomError(err error, errList []error) bool { + for _, e := range errList { + if errors.Is(err, e) { + return true } - return nil } + return false } diff --git a/middleware/timeout/timeout_test.go b/middleware/timeout/timeout_test.go index 2e1756184c..161296a71a 100644 --- a/middleware/timeout/timeout_test.go +++ b/middleware/timeout/timeout_test.go @@ -12,77 +12,119 @@ import ( "github.com/stretchr/testify/require" ) -// go test -run Test_WithContextTimeout -func Test_WithContextTimeout(t *testing.T) { - t.Parallel() - // fiber instance - app := fiber.New() - h := New(func(c fiber.Ctx) error { - sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms") - require.NoError(t, err) - if err := sleepWithContext(c.Context(), sleepTime, context.DeadlineExceeded); err != nil { - return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err)) - } +var ( + // Custom error that we treat like a timeout when returned by the handler. + errCustomTimeout = errors.New("custom timeout error") + + // Some unrelated error that should NOT trigger a request timeout. + errUnrelated = errors.New("unmatched error") +) + +// sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled. +func sleepWithContext(ctx context.Context, d time.Duration, te error) error { + timer := time.NewTimer(d) + defer timer.Stop() // Clean up the timer + + select { + case <-ctx.Done(): + return te + case <-timer.C: return nil - }, 100*time.Millisecond) - app.Get("/test/:sleepTime", h) - testTimeout := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") - } - testSucces := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") } - testTimeout("300") - testTimeout("500") - testSucces("50") - testSucces("30") } -var ErrFooTimeOut = errors.New("foo context canceled") +// TestTimeout_Success tests a handler that completes within the allotted timeout. +func TestTimeout_Success(t *testing.T) { + t.Parallel() + app := fiber.New() + + // Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit. + app.Get("/fast", New(func(c fiber.Ctx) error { + // Simulate some work + if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil { + return err + } + return c.SendString("OK") + }, 50*time.Millisecond)) + + req := httptest.NewRequest(fiber.MethodGet, "/fast", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests") +} -// go test -run Test_WithContextTimeoutWithCustomError -func Test_WithContextTimeoutWithCustomError(t *testing.T) { +// TestTimeout_Exceeded tests a handler that exceeds the provided timeout. +func TestTimeout_Exceeded(t *testing.T) { t.Parallel() - // fiber instance app := fiber.New() - h := New(func(c fiber.Ctx) error { - sleepTime, err := time.ParseDuration(c.Params("sleepTime") + "ms") - require.NoError(t, err) - if err := sleepWithContext(c.Context(), sleepTime, ErrFooTimeOut); err != nil { - return fmt.Errorf("%w: execution error", err) + + // This handler sleeps 200ms, exceeding the 100ms limit. + app.Get("/slow", New(func(c fiber.Ctx) error { + if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil { + return err } - return nil - }, 100*time.Millisecond, ErrFooTimeOut) - app.Get("/test/:sleepTime", h) - testTimeout := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") - } - testSucces := func(timeoutStr string) { - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/"+timeoutStr, nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") - } - testTimeout("300") - testTimeout("500") - testSucces("50") - testSucces("30") + return c.SendString("Should never get here") + }, 100*time.Millisecond)) + + req := httptest.NewRequest(fiber.MethodGet, "/slow", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout") } -func sleepWithContext(ctx context.Context, d time.Duration, te error) error { - timer := time.NewTimer(d) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C +// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout. +func TestTimeout_CustomError(t *testing.T) { + t.Parallel() + app := fiber.New() + + // This handler sleeps 50ms and returns errCustomTimeout if canceled. + app.Get("/custom", New(func(c fiber.Ctx) error { + // Sleep might time out, or might return early. If the context is canceled, + // we treat errCustomTimeout as a 'timeout-like' condition. + if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil { + return fmt.Errorf("wrapped: %w", err) } - return te - case <-timer.C: - } - return nil + return c.SendString("Should never get here") + }, 100*time.Millisecond, errCustomTimeout)) + + req := httptest.NewRequest(fiber.MethodGet, "/custom", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error") +} + +// TestTimeout_UnmatchedError checks that if the handler returns an error +// that is neither a deadline exceeded nor a custom 'timeout' error, it is +// propagated as a regular 500 (internal server error). +func TestTimeout_UnmatchedError(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Get("/unmatched", New(func(_ fiber.Ctx) error { + return errUnrelated // Not in the custom error list + }, 100*time.Millisecond, errCustomTimeout)) + + req := httptest.NewRequest(fiber.MethodGet, "/unmatched", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode, + "Expected 500 because the error is not recognized as a timeout error") +} + +// TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero. +// Usually this means the request can never exceed a 'deadline' – effectively no timeout. +func TestTimeout_ZeroDuration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Get("/zero", New(func(c fiber.Ctx) error { + // Sleep 50ms, but there's no real 'deadline' since zero-timeout. + time.Sleep(50 * time.Millisecond) + return c.SendString("No timeout used") + }, 0)) + + req := httptest.NewRequest(fiber.MethodGet, "/zero", nil) + resp, err := app.Test(req) + require.NoError(t, err, "app.Test(req) should not fail") + require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout") }