diff --git a/CHANGELOG.md b/CHANGELOG.md index af4d00eb6c..1f0c1676db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- `opentelemetry-instrumentation-fastapi`: Fix middleware ordering to cover all exception handling use cases. + ([#3664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3664)) +- `opentelemetry-instrumentation-asgi`: Make all user hooks failsafe and record exceptions in hooks. + ([#3664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3664)) - `opentelemetry-instrumentation`: Avoid calls to `context.detach` with `None` token. ([#3673](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3673)) diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 7e72dbf11f..a6be59b86f 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -268,7 +268,7 @@ def client_response_hook(span: Span, scope: Scope, message: dict[str, Any]): HTTP_SERVER_REQUEST_DURATION, ) from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import set_span_in_context +from opentelemetry.trace import Span, set_span_in_context from opentelemetry.util.http import ( OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST, @@ -646,9 +646,23 @@ def __init__( self.default_span_details = ( default_span_details or get_default_span_details ) - self.server_request_hook = server_request_hook - self.client_request_hook = client_request_hook - self.client_response_hook = client_response_hook + + def failsafe(func): + if 
func is None: + return None + + @wraps(func) + def wrapper(span: Span, *args, **kwargs): + try: + func(span, *args, **kwargs) + except Exception as exc: # pylint: disable=broad-exception-caught + span.record_exception(exc) + + return wrapper + + self.server_request_hook = failsafe(server_request_hook) + self.client_request_hook = failsafe(client_request_hook) + self.client_response_hook = failsafe(client_response_hook) self.content_length_header = None self._sem_conv_opt_in_mode = sem_conv_opt_in_mode diff --git a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py index b8791cf730..23a830bcb6 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py @@ -277,6 +277,17 @@ async def error_asgi(scope, receive, send): await send({"type": "http.response.body", "body": b"*"}) +class UnhandledException(Exception): + pass + + +def failing_hook(msg): + def hook(*_): + raise UnhandledException(msg) + + return hook + + # pylint: disable=too-many-public-methods class TestAsgiApplication(AsyncAsgiTestBase): def setUp(self): @@ -481,6 +492,12 @@ def validate_outputs( span.instrumentation_scope.name, "opentelemetry.instrumentation.asgi", ) + if "events" in expected: + self.assertEqual(len(span.events), len(expected["events"])) + for event, expected in zip(span.events, expected["events"]): + self.assertEqual(event.name, expected["name"]) + for name, value in expected["attributes"].items(): + self.assertEqual(event.attributes[name], value) async def test_basic_asgi_call(self): """Test that spans are emitted as expected.""" @@ -1206,6 +1223,40 @@ def update_expected_hook_results(expected): outputs, modifiers=[update_expected_hook_results] ) + async def test_hook_exceptions(self): + def exception_event(msg): + return { + "name": "exception", + 
"attributes": { + "exception.type": f"{__name__}.UnhandledException", + "exception.message": msg, + }, + } + + def update_expected_hook_results(expected): + for entry in expected: + if entry["kind"] == trace_api.SpanKind.SERVER: + entry["events"] = [exception_event("server request")] + elif entry["name"] == "GET / http receive": + entry["events"] = [exception_event("client request")] + elif entry["name"] == "GET / http send": + entry["events"] = [exception_event("client response")] + + return expected + + app = otel_asgi.OpenTelemetryMiddleware( + simple_asgi, + server_request_hook=failing_hook("server request"), + client_request_hook=failing_hook("client request"), + client_response_hook=failing_hook("client response"), + ) + self.seed_app(app) + await self.send_default_request() + outputs = await self.get_all_output() + self.validate_outputs( + outputs, modifiers=[update_expected_hook_results] + ) + async def test_asgi_metrics(self): app = otel_asgi.OpenTelemetryMiddleware(simple_asgi) self.seed_app(app) diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py index 8ba83985c6..d697ac9b7e 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py @@ -191,7 +191,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from starlette.applications import Starlette from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import Match -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -210,7 +210,8 @@ def client_response_hook(span: Span, scope: 
dict[str, Any], message: dict[str, A from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.metrics import MeterProvider, get_meter from opentelemetry.semconv.attributes.http_attributes import HTTP_ROUTE -from opentelemetry.trace import TracerProvider, get_tracer +from opentelemetry.trace import TracerProvider, get_current_span, get_tracer +from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util.http import ( get_excluded_urls, parse_excluded_urls, @@ -242,7 +243,7 @@ def instrument_app( http_capture_headers_server_response: list[str] | None = None, http_capture_headers_sanitize_fields: list[str] | None = None, exclude_spans: list[Literal["receive", "send"]] | None = None, - ): + ): # pylint: disable=too-many-locals """Instrument an uninstrumented FastAPI application. Args: @@ -289,17 +290,77 @@ def instrument_app( schema_url=_get_schema_url(sem_conv_opt_in_mode), ) - # Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware - # as the outermost middleware. - # Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able - # to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do. - def build_middleware_stack(self: Starlette) -> ASGIApp: - inner_server_error_middleware: ASGIApp = ( # type: ignore + # Define an additional middleware for exception handling + # Normally, `opentelemetry.trace.use_span` covers the recording of + # exceptions into the active span, but `OpenTelemetryMiddleware` + # ends the span too early before the exception can be recorded. 
+ class ExceptionHandlerMiddleware: + def __init__(self, app): + self.app = app + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + try: + await self.app(scope, receive, send) + except Exception as exc: # pylint: disable=broad-exception-caught + span = get_current_span() + span.record_exception(exc) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + raise + + # For every possible use case of error handling, exception + # handling, trace availability in exception handlers and + # automatic exception recording to work, we need to make a + # series of wrapping and re-wrapping middlewares. + + # First, grab the original middleware stack from Starlette. It + # comprises a stack of + # `ServerErrorMiddleware` -> [user defined middlewares] -> `ExceptionMiddleware` + inner_server_error_middleware: ServerErrorMiddleware = ( # type: ignore self._original_build_middleware_stack() # type: ignore ) + + if not isinstance( + inner_server_error_middleware, ServerErrorMiddleware + ): + # Oops, something changed about how Starlette creates middleware stacks + _logger.error( + "Cannot instrument FastAPI as the expected middleware stack has changed" + ) + return inner_server_error_middleware + + # We take [user defined middlewares] -> `ExceptionMiddleware` + # out of the outermost `ServerErrorMiddleware` and instead pass + # it to our own `ExceptionHandlerMiddleware` + exception_middleware = ExceptionHandlerMiddleware( + inner_server_error_middleware.app + ) + + # Now, we create a new `ServerErrorMiddleware` that wraps + # `ExceptionHandlerMiddleware` but otherwise uses the same + # original `handler` and debug setting. The end result is a + # middleware stack that's identical to the original stack except + # all user middlewares are covered by our + # `ExceptionHandlerMiddleware`.
+ error_middleware = ServerErrorMiddleware( + app=exception_middleware, + handler=inner_server_error_middleware.handler, + debug=inner_server_error_middleware.debug, + ) + + # Finally, we wrap the stack above in our actual OTEL + # middleware. As a result, an active tracing context exists for + # every use case of user-defined error and exception handlers as + # well as automatic recording of exceptions in active spans. otel_middleware = OpenTelemetryMiddleware( - inner_server_error_middleware, + error_middleware, excluded_urls=excluded_urls, default_span_details=_get_default_span_details, server_request_hook=server_request_hook, @@ -313,23 +374,18 @@ def build_middleware_stack(self: Starlette) -> ASGIApp: http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, exclude_spans=exclude_spans, ) - # Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware - # are handled. - # This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that - # to impact the user's application just because we wrapped the middlewares in this order. - if isinstance( - inner_server_error_middleware, ServerErrorMiddleware - ): # usually true - outer_server_error_middleware = ServerErrorMiddleware( - app=otel_middleware, - ) - else: - # Something else seems to have patched things, or maybe Starlette changed. - # Just create a default ServerErrorMiddleware. - outer_server_error_middleware = ServerErrorMiddleware( - app=otel_middleware - ) - return outer_server_error_middleware + + # Ultimately, wrap everything in another default + # `ServerErrorMiddleware` (w/o user handlers) so that any + # exceptions raised in `OpenTelemetryMiddleware` are handled. + # + # This should not happen unless there is a bug in + # OpenTelemetryMiddleware, but if there is we don't want that to + # impact the user's application just because we wrapped the + # middlewares in this order. 
+ return ServerErrorMiddleware( + app=otel_middleware, + ) app._original_build_middleware_stack = app.build_middleware_stack app.build_middleware_stack = types.MethodType( diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py index 523c165f85..33a6bec0ac 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py @@ -21,7 +21,7 @@ import fastapi from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, PlainTextResponse from fastapi.testclient import TestClient import opentelemetry.instrumentation.fastapi as otel_fastapi @@ -38,9 +38,7 @@ from opentelemetry.instrumentation.auto_instrumentation._load import ( _load_instrumentors, ) -from opentelemetry.instrumentation.dependencies import ( - DependencyConflict, -) +from opentelemetry.instrumentation.dependencies import DependencyConflict from opentelemetry.sdk.metrics.export import ( HistogramDataPoint, NumberDataPoint, @@ -59,6 +57,9 @@ from opentelemetry.semconv._incubating.attributes.net_attributes import ( NET_HOST_PORT, ) +from opentelemetry.semconv.attributes.exception_attributes import ( + EXCEPTION_TYPE, +) from opentelemetry.semconv.attributes.http_attributes import ( HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, @@ -70,6 +71,7 @@ from opentelemetry.semconv.attributes.url_attributes import URL_SCHEME from opentelemetry.test.globals_test import reset_trace_globals from opentelemetry.test.test_base import TestBase +from opentelemetry.trace.status import StatusCode from opentelemetry.util._importlib_metadata import entry_points from opentelemetry.util.http import ( OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, @@ 
-1877,3 +1879,138 @@ def test_custom_header_not_present_in_non_recording_span(self): self.assertEqual(200, resp.status_code) span_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(span_list), 0) + + +class TestTraceableExceptionHandling(TestBase): + """Tests to ensure FastAPI exception handlers are only executed once and with a valid context""" + + def setUp(self): + super().setUp() + + self.app = fastapi.FastAPI() + + otel_fastapi.FastAPIInstrumentor().instrument_app( + self.app, exclude_spans=["receive", "send"] + ) + self.client = TestClient(self.app) + self.tracer = self.tracer_provider.get_tracer(__name__) + self.executed = 0 + self.request_trace_id = None + self.error_trace_id = None + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + def test_error_handler_context(self): + """OTEL tracing contexts must be available during error handler + execution, and handlers must only be executed once""" + + status_code = 501 + + @self.app.exception_handler(Exception) + async def _(*_): + self.error_trace_id = ( + trace.get_current_span().get_span_context().trace_id + ) + self.executed += 1 + return PlainTextResponse("", status_code) + + @self.app.get("/foobar") + async def _(): + self.request_trace_id = ( + trace.get_current_span().get_span_context().trace_id + ) + raise UnhandledException("Test Exception") + + try: + self.client.get( + "/foobar", + ) + except UnhandledException: + pass + + self.assertIsNotNone(self.request_trace_id) + self.assertEqual(self.request_trace_id, self.error_trace_id) + + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.name, "GET /foobar") + self.assertEqual(span.attributes.get(HTTP_STATUS_CODE), status_code) + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual(len(span.events), 1) + event = span.events[0] + 
self.assertEqual(event.name, "exception") + assert event.attributes is not None + self.assertEqual( + event.attributes.get(EXCEPTION_TYPE), + f"{__name__}.UnhandledException", + ) + self.assertEqual(self.executed, 1) + + def test_exception_span_recording(self): + """Exceptions are always recorded in the active span""" + + @self.app.get("/foobar") + async def _(): + raise UnhandledException("Test Exception") + + try: + self.client.get( + "/foobar", + ) + except UnhandledException: + pass + + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.name, "GET /foobar") + self.assertEqual(span.attributes.get(HTTP_STATUS_CODE), 500) + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual(len(span.events), 1) + event = span.events[0] + self.assertEqual(event.name, "exception") + assert event.attributes is not None + self.assertEqual( + event.attributes.get(EXCEPTION_TYPE), + f"{__name__}.UnhandledException", + ) + + def test_middleware_exceptions(self): + """Exceptions from user middlewares are recorded in the active span""" + + @self.app.get("/foobar") + async def _(): + return PlainTextResponse("Hello World") + + @self.app.middleware("http") + async def _(*_): + raise UnhandledException("Test Exception") + + try: + self.client.get( + "/foobar", + ) + except UnhandledException: + pass + + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.name, "GET /foobar") + self.assertEqual(span.attributes.get(HTTP_STATUS_CODE), 500) + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual(len(span.events), 1) + event = span.events[0] + self.assertEqual(event.name, "exception") + assert event.attributes is not None + self.assertEqual( + event.attributes.get(EXCEPTION_TYPE), + f"{__name__}.UnhandledException", + )