Skip to content

Commit c32b738

Browse files
outergodalexmojaki
andauthored
Rewrite FastAPI instrumentor middleware stack to be failsafe (#3664)
* rewrite FastAPIInstrumentor:build_middleware_stack to become failsafe * add test cases for FastAPI failsafe handling * add CHANGELOG entry * remove unused import * [lint] don't return from failsafe wrapper * [lint] allow broad exceptions * [lint] more allowing * record FastAPI hook exceptions in active span * remove comment * properly deal with hooks not being set * add custom FastAPI exception recording * move failsafe hook handling down to OpenTelemetryMiddleware * shut up pylint * optimize failsafe to check for `None` only once * remove confusing comment and simplify wrapper logic * add clarifying comment * test proper exception / status code recording * add HTTP status code check * test HTTP status on the exception recording span * improve test by removing TypeError * rectify comment/explanation on inner middleware for exception handling * minor typo * move ExceptionHandlingMiddleware as the outermost inner middleware Also improve code documentation and add another test. * use distinct status code in test * improve comemnt Co-authored-by: Alex Hall <[email protected]> * narrow down exception handling Co-authored-by: Alex Hall <[email protected]> * narrow down FastAPI exception tests to relevant spans * collapse tests, more narrow exceptions * move failsafe hook tests to ASGI test suite * update CHANGELOG * satisfy linter * don't record exception if span is not recording * add test for unhappy instrumentation codepath * make inject fixtures private * give up and shut up pylint * improve instrumentation failure error message and add test --------- Co-authored-by: Alex Hall <[email protected]>
1 parent 60b9035 commit c32b738

File tree

5 files changed

+342
-35
lines changed

5 files changed

+342
-35
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
### Fixed
1515

16+
- `opentelemetry-instrumentation-fastapi`: Fix middleware ordering to cover all exception handling use cases.
17+
([#3664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3664))
18+
- `opentelemetry-instrumentation-asgi`: Make all user hooks failsafe and record exceptions in hooks.
19+
([#3664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3664))
1620
- `opentelemetry-instrumentation-fastapi`: Fix memory leak in `uninstrument_app()` by properly removing apps from the tracking set
1721
([#3688](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3688)
1822
- `opentelemetry-instrumentation-tornado` Fix server (request) duration metric calculation

instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def client_response_hook(span: Span, scope: Scope, message: dict[str, Any]):
268268
HTTP_SERVER_REQUEST_DURATION,
269269
)
270270
from opentelemetry.semconv.trace import SpanAttributes
271-
from opentelemetry.trace import set_span_in_context
271+
from opentelemetry.trace import Span, set_span_in_context
272272
from opentelemetry.util.http import (
273273
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS,
274274
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST,
@@ -646,9 +646,23 @@ def __init__(
646646
self.default_span_details = (
647647
default_span_details or get_default_span_details
648648
)
649-
self.server_request_hook = server_request_hook
650-
self.client_request_hook = client_request_hook
651-
self.client_response_hook = client_response_hook
649+
650+
def failsafe(func):
651+
if func is None:
652+
return None
653+
654+
@wraps(func)
655+
def wrapper(span: Span, *args, **kwargs):
656+
try:
657+
func(span, *args, **kwargs)
658+
except Exception as exc: # pylint: disable=broad-exception-caught
659+
span.record_exception(exc)
660+
661+
return wrapper
662+
663+
self.server_request_hook = failsafe(server_request_hook)
664+
self.client_request_hook = failsafe(client_request_hook)
665+
self.client_response_hook = failsafe(client_response_hook)
652666
self.content_length_header = None
653667
self._sem_conv_opt_in_mode = sem_conv_opt_in_mode
654668

instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,17 @@ async def error_asgi(scope, receive, send):
277277
await send({"type": "http.response.body", "body": b"*"})
278278

279279

280+
class UnhandledException(Exception):
281+
pass
282+
283+
284+
def failing_hook(msg):
285+
def hook(*_):
286+
raise UnhandledException(msg)
287+
288+
return hook
289+
290+
280291
# pylint: disable=too-many-public-methods
281292
class TestAsgiApplication(AsyncAsgiTestBase):
282293
def setUp(self):
@@ -481,6 +492,12 @@ def validate_outputs(
481492
span.instrumentation_scope.name,
482493
"opentelemetry.instrumentation.asgi",
483494
)
495+
if "events" in expected:
496+
self.assertEqual(len(span.events), len(expected["events"]))
497+
for event, expected in zip(span.events, expected["events"]):
498+
self.assertEqual(event.name, expected["name"])
499+
for name, value in expected["attributes"].items():
500+
self.assertEqual(event.attributes[name], value)
484501

485502
async def test_basic_asgi_call(self):
486503
"""Test that spans are emitted as expected."""
@@ -1206,6 +1223,40 @@ def update_expected_hook_results(expected):
12061223
outputs, modifiers=[update_expected_hook_results]
12071224
)
12081225

1226+
async def test_hook_exceptions(self):
1227+
def exception_event(msg):
1228+
return {
1229+
"name": "exception",
1230+
"attributes": {
1231+
"exception.type": f"{__name__}.UnhandledException",
1232+
"exception.message": msg,
1233+
},
1234+
}
1235+
1236+
def update_expected_hook_results(expected):
1237+
for entry in expected:
1238+
if entry["kind"] == trace_api.SpanKind.SERVER:
1239+
entry["events"] = [exception_event("server request")]
1240+
elif entry["name"] == "GET / http receive":
1241+
entry["events"] = [exception_event("client request")]
1242+
elif entry["name"] == "GET / http send":
1243+
entry["events"] = [exception_event("client response")]
1244+
1245+
return expected
1246+
1247+
app = otel_asgi.OpenTelemetryMiddleware(
1248+
simple_asgi,
1249+
server_request_hook=failing_hook("server request"),
1250+
client_request_hook=failing_hook("client request"),
1251+
client_response_hook=failing_hook("client response"),
1252+
)
1253+
self.seed_app(app)
1254+
await self.send_default_request()
1255+
outputs = await self.get_all_output()
1256+
self.validate_outputs(
1257+
outputs, modifiers=[update_expected_hook_results]
1258+
)
1259+
12091260
async def test_asgi_metrics(self):
12101261
app = otel_asgi.OpenTelemetryMiddleware(simple_asgi)
12111262
self.seed_app(app)

instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
192192
from starlette.applications import Starlette
193193
from starlette.middleware.errors import ServerErrorMiddleware
194194
from starlette.routing import Match
195-
from starlette.types import ASGIApp
195+
from starlette.types import ASGIApp, Receive, Scope, Send
196196

197197
from opentelemetry.instrumentation._semconv import (
198198
_get_schema_url,
@@ -211,7 +211,8 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
211211
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
212212
from opentelemetry.metrics import MeterProvider, get_meter
213213
from opentelemetry.semconv.attributes.http_attributes import HTTP_ROUTE
214-
from opentelemetry.trace import TracerProvider, get_tracer
214+
from opentelemetry.trace import TracerProvider, get_current_span, get_tracer
215+
from opentelemetry.trace.status import Status, StatusCode
215216
from opentelemetry.util.http import (
216217
get_excluded_urls,
217218
parse_excluded_urls,
@@ -243,7 +244,7 @@ def instrument_app(
243244
http_capture_headers_server_response: list[str] | None = None,
244245
http_capture_headers_sanitize_fields: list[str] | None = None,
245246
exclude_spans: list[Literal["receive", "send"]] | None = None,
246-
):
247+
): # pylint: disable=too-many-locals
247248
"""Instrument an uninstrumented FastAPI application.
248249
249250
Args:
@@ -290,17 +291,80 @@ def instrument_app(
290291
schema_url=_get_schema_url(sem_conv_opt_in_mode),
291292
)
292293

293-
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware
294-
# as the outermost middleware.
295-
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able
296-
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do.
297-
298294
def build_middleware_stack(self: Starlette) -> ASGIApp:
299-
inner_server_error_middleware: ASGIApp = ( # type: ignore
295+
# Define an additional middleware for exception handling
296+
# Normally, `opentelemetry.trace.use_span` covers the recording of
297+
# exceptions into the active span, but `OpenTelemetryMiddleware`
298+
# ends the span too early before the exception can be recorded.
299+
class ExceptionHandlerMiddleware:
300+
def __init__(self, app):
301+
self.app = app
302+
303+
async def __call__(
304+
self, scope: Scope, receive: Receive, send: Send
305+
) -> None:
306+
try:
307+
await self.app(scope, receive, send)
308+
except Exception as exc: # pylint: disable=broad-exception-caught
309+
span = get_current_span()
310+
if span.is_recording():
311+
span.record_exception(exc)
312+
span.set_status(
313+
Status(
314+
status_code=StatusCode.ERROR,
315+
description=f"{type(exc).__name__}: {exc}",
316+
)
317+
)
318+
raise
319+
320+
# For every possible use case of error handling, exception
321+
# handling, trace availability in exception handlers and
322+
# automatic exception recording to work, we need to make a
323+
# series of wrapping and re-wrapping middlewares.
324+
325+
# First, grab the original middleware stack from Starlette. It
326+
# comprises a stack of
327+
# `ServerErrorMiddleware` -> [user defined middlewares] -> `ExceptionMiddleware`
328+
inner_server_error_middleware: ServerErrorMiddleware = ( # type: ignore
300329
self._original_build_middleware_stack() # type: ignore
301330
)
331+
332+
if not isinstance(
333+
inner_server_error_middleware, ServerErrorMiddleware
334+
):
335+
# Oops, something changed about how Starlette creates middleware stacks
336+
_logger.error(
337+
"Skipping FastAPI instrumentation due to unexpected middleware stack: expected %s, got %s",
338+
ServerErrorMiddleware.__name__,
339+
type(inner_server_error_middleware),
340+
)
341+
return inner_server_error_middleware
342+
343+
# We take [user defined middlewares] -> `ExceptionHandlerMiddleware`
344+
# out of the outermost `ServerErrorMiddleware` and instead pass
345+
# it to our own `ExceptionHandlerMiddleware`
346+
exception_middleware = ExceptionHandlerMiddleware(
347+
inner_server_error_middleware.app
348+
)
349+
350+
# Now, we create a new `ServerErrorMiddleware` that wraps
351+
# `ExceptionHandlerMiddleware` but otherwise uses the same
352+
# original `handler` and debug setting. The end result is a
353+
# middleware stack that's identical to the original stack except
354+
# all user middlewares are covered by our
355+
# `ExceptionHandlerMiddleware`.
356+
error_middleware = ServerErrorMiddleware(
357+
app=exception_middleware,
358+
handler=inner_server_error_middleware.handler,
359+
debug=inner_server_error_middleware.debug,
360+
)
361+
362+
# Finally, we wrap the stack above in our actual OTEL
363+
# middleware. As a result, an active tracing context exists for
364+
# every use case of user-defined error and exception handlers as
365+
# well as automatic recording of exceptions in active spans.
302366
otel_middleware = OpenTelemetryMiddleware(
303-
inner_server_error_middleware,
367+
error_middleware,
304368
excluded_urls=excluded_urls,
305369
default_span_details=_get_default_span_details,
306370
server_request_hook=server_request_hook,
@@ -314,23 +378,18 @@ def build_middleware_stack(self: Starlette) -> ASGIApp:
314378
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
315379
exclude_spans=exclude_spans,
316380
)
317-
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware
318-
# are handled.
319-
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that
320-
# to impact the user's application just because we wrapped the middlewares in this order.
321-
if isinstance(
322-
inner_server_error_middleware, ServerErrorMiddleware
323-
): # usually true
324-
outer_server_error_middleware = ServerErrorMiddleware(
325-
app=otel_middleware,
326-
)
327-
else:
328-
# Something else seems to have patched things, or maybe Starlette changed.
329-
# Just create a default ServerErrorMiddleware.
330-
outer_server_error_middleware = ServerErrorMiddleware(
331-
app=otel_middleware
332-
)
333-
return outer_server_error_middleware
381+
382+
# Ultimately, wrap everything in another default
383+
# `ServerErrorMiddleware` (w/o user handlers) so that any
384+
# exceptions raised in `OpenTelemetryMiddleware` are handled.
385+
#
386+
# This should not happen unless there is a bug in
387+
# OpenTelemetryMiddleware, but if there is we don't want that to
388+
# impact the user's application just because we wrapped the
389+
# middlewares in this order.
390+
return ServerErrorMiddleware(
391+
app=otel_middleware,
392+
)
334393

335394
app._original_build_middleware_stack = app.build_middleware_stack
336395
app.build_middleware_stack = types.MethodType(

0 commit comments

Comments
 (0)