|
9 | 9 |
|
10 | 10 | import httpx |
11 | 11 |
|
12 | | -from ._utils import extract_type_var_from_base |
| 12 | +from ._utils import is_mapping, extract_type_var_from_base |
| 13 | +from ._exceptions import APIError |
13 | 14 |
|
14 | 15 | if TYPE_CHECKING: |
15 | 16 | from ._client import Groq, AsyncGroq |
@@ -57,7 +58,43 @@ def __stream__(self) -> Iterator[_T]: |
57 | 58 | for sse in iterator: |
58 | 59 | if sse.data.startswith("[DONE]"): |
59 | 60 | break |
60 | | - yield process_data(data=sse.json(), cast_to=cast_to, response=response) |
| 61 | + |
| 62 | + if sse.event is None: |
| 63 | + data = sse.json() |
| 64 | + if is_mapping(data) and data.get("error"): |
| 65 | + message = None |
| 66 | + error = data.get("error") |
| 67 | + if is_mapping(error): |
| 68 | + message = error.get("message") |
| 69 | + if not message or not isinstance(message, str): |
| 70 | + message = "An error occurred during streaming" |
| 71 | + |
| 72 | + raise APIError( |
| 73 | + message=message, |
| 74 | + request=self.response.request, |
| 75 | + body=data["error"], |
| 76 | + ) |
| 77 | + |
| 78 | + yield process_data(data=data, cast_to=cast_to, response=response) |
| 79 | + |
| 80 | + else: |
| 81 | + data = sse.json() |
| 82 | + |
| 83 | + if sse.event == "error" and is_mapping(data) and data.get("error"): |
| 84 | + message = None |
| 85 | + error = data.get("error") |
| 86 | + if is_mapping(error): |
| 87 | + message = error.get("message") |
| 88 | + if not message or not isinstance(message, str): |
| 89 | + message = "An error occurred during streaming" |
| 90 | + |
| 91 | + raise APIError( |
| 92 | + message=message, |
| 93 | + request=self.response.request, |
| 94 | + body=data["error"], |
| 95 | + ) |
| 96 | + |
| 97 | + yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) |
61 | 98 |
|
62 | 99 | # Ensure the entire stream is consumed |
63 | 100 | for _sse in iterator: |
@@ -123,7 +160,43 @@ async def __stream__(self) -> AsyncIterator[_T]: |
123 | 160 | async for sse in iterator: |
124 | 161 | if sse.data.startswith("[DONE]"): |
125 | 162 | break |
126 | | - yield process_data(data=sse.json(), cast_to=cast_to, response=response) |
| 163 | + |
| 164 | + if sse.event is None: |
| 165 | + data = sse.json() |
| 166 | + if is_mapping(data) and data.get("error"): |
| 167 | + message = None |
| 168 | + error = data.get("error") |
| 169 | + if is_mapping(error): |
| 170 | + message = error.get("message") |
| 171 | + if not message or not isinstance(message, str): |
| 172 | + message = "An error occurred during streaming" |
| 173 | + |
| 174 | + raise APIError( |
| 175 | + message=message, |
| 176 | + request=self.response.request, |
| 177 | + body=data["error"], |
| 178 | + ) |
| 179 | + |
| 180 | + yield process_data(data=data, cast_to=cast_to, response=response) |
| 181 | + |
| 182 | + else: |
| 183 | + data = sse.json() |
| 184 | + |
| 185 | + if sse.event == "error" and is_mapping(data) and data.get("error"): |
| 186 | + message = None |
| 187 | + error = data.get("error") |
| 188 | + if is_mapping(error): |
| 189 | + message = error.get("message") |
| 190 | + if not message or not isinstance(message, str): |
| 191 | + message = "An error occurred during streaming" |
| 192 | + |
| 193 | + raise APIError( |
| 194 | + message=message, |
| 195 | + request=self.response.request, |
| 196 | + body=data["error"], |
| 197 | + ) |
| 198 | + |
| 199 | + yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) |
127 | 200 |
|
128 | 201 | # Ensure the entire stream is consumed |
129 | 202 | async for _sse in iterator: |
|
0 commit comments