diff --git a/hud/eval/instrument.py b/hud/eval/instrument.py index e950522c..94598f1f 100644 --- a/hud/eval/instrument.py +++ b/hud/eval/instrument.py @@ -1,6 +1,6 @@ -"""Auto-instrumentation for httpx to inject trace headers. +"""Auto-instrumentation for httpx and aiohttp to inject trace headers. -This module patches httpx clients to automatically add: +This module patches HTTP clients to automatically add: - Trace-Id headers when inside an eval context - Authorization headers for HUD API calls """ @@ -8,9 +8,12 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse +if TYPE_CHECKING: + from types import SimpleNamespace + from hud.settings import settings logger = logging.getLogger(__name__) @@ -70,7 +73,7 @@ async def _async_httpx_request_hook(request: Any) -> None: _httpx_request_hook(request) -def _instrument_client(client: Any) -> None: +def _instrument_httpx_client(client: Any) -> None: """Add trace hook to an httpx client instance.""" is_async = hasattr(client, "aclose") hook = _async_httpx_request_hook if is_async else _httpx_request_hook @@ -93,7 +96,7 @@ def _patch_httpx() -> None: def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None: _original_async_init(self, *args, **kwargs) - _instrument_client(self) + _instrument_httpx_client(self) httpx.AsyncClient.__init__ = _patched_async_init # type: ignore[method-assign] @@ -101,15 +104,65 @@ def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None: def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None: _original_sync_init(self, *args, **kwargs) - _instrument_client(self) + _instrument_httpx_client(self) httpx.Client.__init__ = _patched_sync_init # type: ignore[method-assign] logger.debug("httpx auto-instrumentation enabled") -# Auto-patch httpx on module import +def _patch_aiohttp() -> None: + """ + Monkey-patch aiohttp to auto-instrument all ClientSession instances. + This is important for the Gemini client in particular, which uses aiohttp by default. + """ + try: + import aiohttp + except ImportError: + logger.debug("aiohttp not installed, skipping auto-instrumentation") + return + + async def on_request_start( + _session: aiohttp.ClientSession, + _trace_config_ctx: SimpleNamespace, + params: aiohttp.TraceRequestStartParams, + ) -> None: + """aiohttp trace hook that adds trace headers and auth to HUD requests.""" + url_str = str(params.url) + if not _is_hud_url(url_str): + return + + trace_headers = _get_trace_headers() + if trace_headers is not None: + for key, value in trace_headers.items(): + params.headers[key] = value + logger.debug("Added trace headers to aiohttp request: %s", url_str) + + has_auth = "authorization" in {k.lower() for k in params.headers} + if not has_auth and settings.api_key: + params.headers["Authorization"] = f"Bearer {settings.api_key}" + logger.debug("Added API key auth to aiohttp request: %s", url_str) + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_start.append(on_request_start) + + _original_init = aiohttp.ClientSession.__init__ + + def _patched_init(self: aiohttp.ClientSession, *args: Any, **kwargs: Any) -> None: + existing_traces = kwargs.get("trace_configs") or [] + if trace_config not in existing_traces: + existing_traces = [*list(existing_traces), trace_config] + kwargs["trace_configs"] = existing_traces + _original_init(self, *args, **kwargs) + + aiohttp.ClientSession.__init__ = _patched_init # type: ignore[method-assign] + + logger.debug("aiohttp auto-instrumentation enabled") + + +# Auto-patch on module import _patch_httpx() +_patch_aiohttp() -__all__ = ["_patch_httpx"] +__all__ = ["_patch_aiohttp", "_patch_httpx"]