Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 61 additions & 8 deletions hud/eval/instrument.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Auto-instrumentation for httpx to inject trace headers.
"""Auto-instrumentation for httpx and aiohttp to inject trace headers.

This module patches httpx clients to automatically add:
This module patches HTTP clients to automatically add:
- Trace-Id headers when inside an eval context
- Authorization headers for HUD API calls
"""

from __future__ import annotations

import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

if TYPE_CHECKING:
from types import SimpleNamespace

from hud.settings import settings

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,7 +73,7 @@ async def _async_httpx_request_hook(request: Any) -> None:
_httpx_request_hook(request)


def _instrument_client(client: Any) -> None:
def _instrument_httpx_client(client: Any) -> None:
"""Add trace hook to an httpx client instance."""
is_async = hasattr(client, "aclose")
hook = _async_httpx_request_hook if is_async else _httpx_request_hook
Expand All @@ -93,23 +96,73 @@ def _patch_httpx() -> None:

def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None:
_original_async_init(self, *args, **kwargs)
_instrument_client(self)
_instrument_httpx_client(self)

httpx.AsyncClient.__init__ = _patched_async_init # type: ignore[method-assign]

_original_sync_init = httpx.Client.__init__

def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None:
_original_sync_init(self, *args, **kwargs)
_instrument_client(self)
_instrument_httpx_client(self)

httpx.Client.__init__ = _patched_sync_init # type: ignore[method-assign]

logger.debug("httpx auto-instrumentation enabled")


# Auto-patch httpx on module import
def _patch_aiohttp() -> None:
"""
Monkey-patch aiohttp to auto-instrument all ClientSession instances.
This is important for the Gemini client in particular, which uses aiohttp by default.
"""
try:
import aiohttp
except ImportError:
logger.debug("aiohttp not installed, skipping auto-instrumentation")
return

async def on_request_start(
_session: aiohttp.ClientSession,
_trace_config_ctx: SimpleNamespace,
params: aiohttp.TraceRequestStartParams,
) -> None:
"""aiohttp trace hook that adds trace headers and auth to HUD requests."""
url_str = str(params.url)
if not _is_hud_url(url_str):
return

trace_headers = _get_trace_headers()
if trace_headers is not None:
for key, value in trace_headers.items():
params.headers[key] = value
logger.debug("Added trace headers to aiohttp request: %s", url_str)

has_auth = "authorization" in {k.lower() for k in params.headers}
if not has_auth and settings.api_key:
params.headers["Authorization"] = f"Bearer {settings.api_key}"
logger.debug("Added API key auth to aiohttp request: %s", url_str)

trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_request_start)

_original_init = aiohttp.ClientSession.__init__

def _patched_init(self: aiohttp.ClientSession, *args: Any, **kwargs: Any) -> None:
existing_traces = kwargs.get("trace_configs") or []
if trace_config not in existing_traces:
existing_traces = [*list(existing_traces), trace_config]
kwargs["trace_configs"] = existing_traces
_original_init(self, *args, **kwargs)

aiohttp.ClientSession.__init__ = _patched_init # type: ignore[method-assign]

logger.debug("aiohttp auto-instrumentation enabled")


# Auto-patch on module import
_patch_httpx()
_patch_aiohttp()


__all__ = ["_patch_httpx"]
__all__ = ["_patch_aiohttp", "_patch_httpx"]
Loading