Skip to content

Commit 4baf6ef

Browse files
Graden Reastainless-app[bot]
authored andcommitted
Add streaming support
1 parent 36d47a3 commit 4baf6ef

File tree

2 files changed

+199
-3
lines changed

2 files changed

+199
-3
lines changed

src/groq/_streaming.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def __stream__(self) -> Iterator[_T]:
5353
iterator = self._iter_events()
5454

5555
for sse in iterator:
56+
if sse.data.startswith("[DONE]"):
57+
break
5658
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
5759

5860
# Ensure the entire stream is consumed
@@ -106,6 +108,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:
106108

107109
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
108110
async for sse in self._decoder.aiter(self.response.aiter_lines()):
111+
if sse.data.startswith("[DONE]"):
112+
break
109113
yield sse
110114

111115
async def __stream__(self) -> AsyncIterator[_T]:

src/groq/resources/chat/completions.py

Lines changed: 195 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from __future__ import annotations
44

5-
from typing import Dict, List, Union, Iterable, Optional
5+
from typing import Dict, List, Literal, Union, Iterable, Optional, overload
66

77
import httpx
88

9+
from ...lib.chat_completion_chunk import ChatCompletionChunk
910
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
1011
from ..._utils import maybe_transform
1112
from ..._compat import cached_property
@@ -16,6 +17,7 @@
1617
async_to_raw_response_wrapper,
1718
async_to_streamed_response_wrapper,
1819
)
20+
from ..._streaming import AsyncStream, Stream
1921
from ...types.chat import ChatCompletion, completion_create_params
2022
from ..._base_client import (
2123
make_request_options,
@@ -33,6 +35,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse:
3335
def with_streaming_response(self) -> CompletionsWithStreamingResponse:
3436
return CompletionsWithStreamingResponse(self)
3537

38+
@overload
3639
def create(
3740
self,
3841
*,
@@ -47,7 +50,7 @@ def create(
4750
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
4851
seed: int | NotGiven = NOT_GIVEN,
4952
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
50-
stream: bool | NotGiven = NOT_GIVEN,
53+
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
5154
temperature: float | NotGiven = NOT_GIVEN,
5255
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
5356
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -61,6 +64,98 @@ def create(
6164
extra_body: Body | None = None,
6265
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
6366
) -> ChatCompletion:
67+
...
68+
69+
@overload
70+
def create(
71+
self,
72+
*,
73+
frequency_penalty: float | NotGiven = NOT_GIVEN,
74+
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
75+
logprobs: bool | NotGiven = NOT_GIVEN,
76+
max_tokens: int | NotGiven = NOT_GIVEN,
77+
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
78+
model: str | NotGiven = NOT_GIVEN,
79+
n: int | NotGiven = NOT_GIVEN,
80+
presence_penalty: float | NotGiven = NOT_GIVEN,
81+
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
82+
seed: int | NotGiven = NOT_GIVEN,
83+
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
84+
stream: Literal[True],
85+
temperature: float | NotGiven = NOT_GIVEN,
86+
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
87+
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
88+
top_logprobs: int | NotGiven = NOT_GIVEN,
89+
top_p: float | NotGiven = NOT_GIVEN,
90+
user: str | NotGiven = NOT_GIVEN,
91+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
92+
# The extra values given here take precedence over values defined on the client or passed to this method.
93+
extra_headers: Headers | None = None,
94+
extra_query: Query | None = None,
95+
extra_body: Body | None = None,
96+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
97+
) -> Stream[ChatCompletionChunk]:
98+
...
99+
100+
@overload
101+
def create(
102+
self,
103+
*,
104+
frequency_penalty: float | NotGiven = NOT_GIVEN,
105+
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
106+
logprobs: bool | NotGiven = NOT_GIVEN,
107+
max_tokens: int | NotGiven = NOT_GIVEN,
108+
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
109+
model: str | NotGiven = NOT_GIVEN,
110+
n: int | NotGiven = NOT_GIVEN,
111+
presence_penalty: float | NotGiven = NOT_GIVEN,
112+
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
113+
seed: int | NotGiven = NOT_GIVEN,
114+
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
115+
stream: bool,
116+
temperature: float | NotGiven = NOT_GIVEN,
117+
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
118+
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
119+
top_logprobs: int | NotGiven = NOT_GIVEN,
120+
top_p: float | NotGiven = NOT_GIVEN,
121+
user: str | NotGiven = NOT_GIVEN,
122+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
123+
# The extra values given here take precedence over values defined on the client or passed to this method.
124+
extra_headers: Headers | None = None,
125+
extra_query: Query | None = None,
126+
extra_body: Body | None = None,
127+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
128+
) -> ChatCompletion | Stream[ChatCompletionChunk]:
129+
...
130+
131+
def create(
132+
self,
133+
*,
134+
frequency_penalty: float | NotGiven = NOT_GIVEN,
135+
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
136+
logprobs: bool | NotGiven = NOT_GIVEN,
137+
max_tokens: int | NotGiven = NOT_GIVEN,
138+
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
139+
model: str | NotGiven = NOT_GIVEN,
140+
n: int | NotGiven = NOT_GIVEN,
141+
presence_penalty: float | NotGiven = NOT_GIVEN,
142+
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
143+
seed: int | NotGiven = NOT_GIVEN,
144+
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
145+
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
146+
temperature: float | NotGiven = NOT_GIVEN,
147+
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
148+
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
149+
top_logprobs: int | NotGiven = NOT_GIVEN,
150+
top_p: float | NotGiven = NOT_GIVEN,
151+
user: str | NotGiven = NOT_GIVEN,
152+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
153+
# The extra values given here take precedence over values defined on the client or passed to this method.
154+
extra_headers: Headers | None = None,
155+
extra_query: Query | None = None,
156+
extra_body: Body | None = None,
157+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
158+
) -> ChatCompletion | Stream[ChatCompletionChunk]:
64159
"""
65160
Creates a completion for a chat prompt
66161
@@ -105,6 +200,8 @@ def create(
105200
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
106201
),
107202
cast_to=ChatCompletion,
203+
stream=stream or False,
204+
stream_cls=Stream[ChatCompletionChunk],
108205
)
109206

110207

@@ -117,6 +214,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
117214
def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
118215
return AsyncCompletionsWithStreamingResponse(self)
119216

217+
@overload
120218
async def create(
121219
self,
122220
*,
@@ -131,7 +229,7 @@ async def create(
131229
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
132230
seed: int | NotGiven = NOT_GIVEN,
133231
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
134-
stream: bool | NotGiven = NOT_GIVEN,
232+
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
135233
temperature: float | NotGiven = NOT_GIVEN,
136234
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
137235
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
@@ -145,6 +243,98 @@ async def create(
145243
extra_body: Body | None = None,
146244
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
147245
) -> ChatCompletion:
246+
...
247+
248+
@overload
249+
async def create(
250+
self,
251+
*,
252+
frequency_penalty: float | NotGiven = NOT_GIVEN,
253+
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
254+
logprobs: bool | NotGiven = NOT_GIVEN,
255+
max_tokens: int | NotGiven = NOT_GIVEN,
256+
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
257+
model: str | NotGiven = NOT_GIVEN,
258+
n: int | NotGiven = NOT_GIVEN,
259+
presence_penalty: float | NotGiven = NOT_GIVEN,
260+
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
261+
seed: int | NotGiven = NOT_GIVEN,
262+
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
263+
stream: Literal[True],
264+
temperature: float | NotGiven = NOT_GIVEN,
265+
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
266+
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
267+
top_logprobs: int | NotGiven = NOT_GIVEN,
268+
top_p: float | NotGiven = NOT_GIVEN,
269+
user: str | NotGiven = NOT_GIVEN,
270+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
271+
# The extra values given here take precedence over values defined on the client or passed to this method.
272+
extra_headers: Headers | None = None,
273+
extra_query: Query | None = None,
274+
extra_body: Body | None = None,
275+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
276+
) -> AsyncStream[ChatCompletionChunk]:
277+
...
278+
279+
@overload
280+
async def create(
281+
self,
282+
*,
283+
frequency_penalty: float | NotGiven = NOT_GIVEN,
284+
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
285+
logprobs: bool | NotGiven = NOT_GIVEN,
286+
max_tokens: int | NotGiven = NOT_GIVEN,
287+
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
288+
model: str | NotGiven = NOT_GIVEN,
289+
n: int | NotGiven = NOT_GIVEN,
290+
presence_penalty: float | NotGiven = NOT_GIVEN,
291+
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
292+
seed: int | NotGiven = NOT_GIVEN,
293+
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
294+
stream: bool,
295+
temperature: float | NotGiven = NOT_GIVEN,
296+
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
297+
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
298+
top_logprobs: int | NotGiven = NOT_GIVEN,
299+
top_p: float | NotGiven = NOT_GIVEN,
300+
user: str | NotGiven = NOT_GIVEN,
301+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
302+
# The extra values given here take precedence over values defined on the client or passed to this method.
303+
extra_headers: Headers | None = None,
304+
extra_query: Query | None = None,
305+
extra_body: Body | None = None,
306+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
307+
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
308+
...
309+
310+
async def create(
311+
self,
312+
*,
313+
frequency_penalty: float | NotGiven = NOT_GIVEN,
314+
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
315+
logprobs: bool | NotGiven = NOT_GIVEN,
316+
max_tokens: int | NotGiven = NOT_GIVEN,
317+
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
318+
model: str | NotGiven = NOT_GIVEN,
319+
n: int | NotGiven = NOT_GIVEN,
320+
presence_penalty: float | NotGiven = NOT_GIVEN,
321+
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
322+
seed: int | NotGiven = NOT_GIVEN,
323+
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
324+
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
325+
temperature: float | NotGiven = NOT_GIVEN,
326+
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
327+
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
328+
top_logprobs: int | NotGiven = NOT_GIVEN,
329+
top_p: float | NotGiven = NOT_GIVEN,
330+
user: str | NotGiven = NOT_GIVEN,
331+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
332+
# The extra values given here take precedence over values defined on the client or passed to this method.
333+
extra_headers: Headers | None = None,
334+
extra_query: Query | None = None,
335+
extra_body: Body | None = None,
336+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
337+
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
148338
"""
149339
Creates a completion for a chat prompt
150340
@@ -189,6 +379,8 @@ async def create(
189379
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
190380
),
191381
cast_to=ChatCompletion,
382+
stream=stream or False,
383+
stream_cls=AsyncStream[ChatCompletionChunk],
192384
)
193385

194386

0 commit comments

Comments
 (0)