Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 85 additions & 44 deletions getstream/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import time
import uuid
from typing import Any, Dict, Optional, Type, get_origin

from getstream.models import APIError
Expand All @@ -16,7 +17,11 @@
span_request,
current_operation,
metric_attributes,
with_span,
get_current_call_cid,
get_current_channel_cid,
)
import ijson


def build_path(path: str, path_params: dict) -> str:
Expand All @@ -28,6 +33,7 @@ def build_path(path: str, path_params: dict) -> str:


class ResponseParserMixin:
@with_span("parse_response")
def _parse_response(
self, response: httpx.Response, data_type: Type[T]
) -> StreamResponse[T]:
Expand Down Expand Up @@ -89,23 +95,28 @@ def _normalize_endpoint_from_path(self, path: str) -> str:
return ".".join(norm_parts) if norm_parts else "root"

def _prepare_request(self, method: str, path: str, query_params, kwargs):
headers = kwargs.get("headers", {})
path_params = kwargs.get("path_params") if kwargs else None
url_path = (
build_path(path, path_params) if path_params else build_path(path, None)
)
url_full = f"{self.base_url}{url_path}"
endpoint = self._endpoint_name(path)
client_request_id = str(uuid.uuid4())
headers["x-client-request-id"] = client_request_id
kwargs["headers"] = headers
span_attrs = common_attributes(
api_key=self.api_key,
endpoint=endpoint,
method=method,
url=url_full,
client_request_id=client_request_id,
)
# Enrich with contextual IDs when available (set by decorators)
call_cid = getattr(self, "_call_cid", None)
call_cid = get_current_call_cid()
if call_cid:
span_attrs["stream.call_cid"] = call_cid
channel_cid = getattr(self, "_channel_cid", None)
channel_cid = get_current_channel_cid()
if channel_cid:
span_attrs["stream.channel_cid"] = channel_cid
return url_path, url_full, endpoint, span_attrs
Expand Down Expand Up @@ -145,7 +156,14 @@ def _endpoint_name(self, path: str) -> str:
return op or current_operation(self._normalize_endpoint_from_path(path))

def _request_sync(
self, method: str, path: str, *, query_params=None, args=(), kwargs=None
self,
method: str,
path: str,
*,
query_params=None,
args=(),
kwargs=None,
data_type: Optional[Type[T]] = None,
):
kwargs = kwargs or {}
url_path, url_full, endpoint, attrs = self._prepare_request(
Expand All @@ -161,22 +179,26 @@ def _request_sync(
response = getattr(self.client, method.lower())(
url_path, params=query_params, *args, **call_kwargs
)
duration = parse_duration_from_body(response.content)
if duration:
span.set_attribute("http.server.duration", duration)
try:
span and span.set_attribute(
"http.response.status_code", response.status_code
)
except Exception:
pass
duration_ms = (time.perf_counter() - start) * 1000.0
# Metrics should be low-cardinality: exclude url/call_cid/channel_cid
metric_attrs = metric_attributes(
api_key=self.api_key,
endpoint=endpoint,
method=method,
status_code=getattr(response, "status_code", None),
)
record_metrics(duration_ms, attributes=metric_attrs)
return response

duration_ms = (time.perf_counter() - start) * 1000.0
# Metrics should be low-cardinality: exclude url/call_cid/channel_cid
metric_attrs = metric_attributes(
api_key=self.api_key,
endpoint=endpoint,
method=method,
status_code=getattr(response, "status_code", None),
)
record_metrics(duration_ms, attributes=metric_attrs)
return self._parse_response(response, data_type or Dict[str, Any])

def patch(
self,
Expand All @@ -187,14 +209,14 @@ def patch(
*args,
**kwargs,
) -> StreamResponse[T]:
response = self._request_sync(
return self._request_sync(
"PATCH",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

def get(
self,
Expand All @@ -205,14 +227,14 @@ def get(
*args,
**kwargs,
) -> StreamResponse[T]:
response = self._request_sync(
return self._request_sync(
"GET",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

def post(
self,
Expand All @@ -223,14 +245,14 @@ def post(
*args,
**kwargs,
) -> StreamResponse[T]:
response = self._request_sync(
return self._request_sync(
"POST",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

def put(
self,
Expand All @@ -241,14 +263,14 @@ def put(
*args,
**kwargs,
) -> StreamResponse[T]:
response = self._request_sync(
return self._request_sync(
"PUT",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

def delete(
self,
Expand All @@ -259,14 +281,14 @@ def delete(
*args,
**kwargs,
) -> StreamResponse[T]:
response = self._request_sync(
return self._request_sync(
"DELETE",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

def close(self):
"""
Expand Down Expand Up @@ -313,9 +335,17 @@ def _endpoint_name(self, path: str) -> str:
return op or current_operation(self._normalize_endpoint_from_path(path))

async def _request_async(
self, method: str, path: str, *, query_params=None, args=(), kwargs=None
self,
method: str,
path: str,
*,
query_params=None,
args=(),
kwargs=None,
data_type: Optional[Type[T]] = None,
):
kwargs = kwargs or {}
query_params = query_params or {}
url_path, url_full, endpoint, attrs = self._prepare_request(
method, path, query_params, kwargs
)
Expand All @@ -328,22 +358,26 @@ async def _request_async(
response = await getattr(self.client, method.lower())(
url_path, params=query_params, *args, **call_kwargs
)
duration = parse_duration_from_body(response.content)
if duration:
span.set_attribute("http.server.duration", duration)
try:
span and span.set_attribute(
"http.response.status_code", response.status_code
)
except Exception:
pass
duration_ms = (time.perf_counter() - start) * 1000.0
# Metrics should be low-cardinality: exclude url/call_cid/channel_cid
metric_attrs = metric_attributes(
api_key=self.api_key,
endpoint=endpoint,
method=method,
status_code=getattr(response, "status_code", None),
)
record_metrics(duration_ms, attributes=metric_attrs)
return response

duration_ms = (time.perf_counter() - start) * 1000.0
# Metrics should be low-cardinality: exclude url/call_cid/channel_cid
metric_attrs = metric_attributes(
api_key=self.api_key,
endpoint=endpoint,
method=method,
status_code=getattr(response, "status_code", None),
)
record_metrics(duration_ms, attributes=metric_attrs)
return self._parse_response(response, data_type or Dict[str, Any])

async def patch(
self,
Expand All @@ -354,14 +388,14 @@ async def patch(
*args,
**kwargs,
) -> StreamResponse[T]:
response = await self._request_async(
return await self._request_async(
"PATCH",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

async def get(
self,
Expand All @@ -372,14 +406,14 @@ async def get(
*args,
**kwargs,
) -> StreamResponse[T]:
response = await self._request_async(
return await self._request_async(
"GET",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

async def post(
self,
Expand All @@ -390,14 +424,14 @@ async def post(
*args,
**kwargs,
) -> StreamResponse[T]:
response = await self._request_async(
return await self._request_async(
"POST",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

async def put(
self,
Expand All @@ -408,14 +442,14 @@ async def put(
*args,
**kwargs,
) -> StreamResponse[T]:
response = await self._request_async(
return await self._request_async(
"PUT",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])

async def delete(
self,
Expand All @@ -426,14 +460,14 @@ async def delete(
*args,
**kwargs,
) -> StreamResponse[T]:
response = await self._request_async(
return await self._request_async(
"DELETE",
path,
query_params=query_params,
args=args,
kwargs=kwargs | {"path_params": path_params},
data_type=data_type,
)
return self._parse_response(response, data_type or Dict[str, Any])


class StreamAPIException(Exception):
Expand Down Expand Up @@ -478,3 +512,10 @@ def __str__(self) -> str:
return f'Stream error code {self.api_error.code}: {self.api_error.message}"'
else:
return f"Stream error HTTP code: {self.status_code}"


def parse_duration_from_body(body: bytes) -> Optional[str]:
for prefix, event, value in ijson.parse(body):
if prefix == "duration" and event == "string":
return value
return None
Loading