Skip to content

Commit 5c8349d

Browse files
committed
feat(client.go): introduce DefaultChunkTimeout and chunkTimeout in Client
- Add DefaultChunkTimeout constant for default chunk processing timeout - Include chunkTimeout field in Client struct to allow custom chunk timeout - Implement ChunkTimeout option function to configure chunkTimeout - Adjust streamReader function to accept timeout parameter - Utilize chunkTimeout for setting up timers in chunk reading operations This enhancement allows finer control over timeout settings specific to chunk processing, improving the client's flexibility and responsiveness in varying network conditions or when dealing with large data streams.
1 parent 1885404 commit 5c8349d

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

openai/client.go

+26-14
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ const (
2828
// requests to the OpenAI API. This value can be changed by using the Timeout
2929
// option when creating a new client.
3030
DefaultTimeout = 3 * time.Minute
31+
32+
DefaultChunkTimeout = 5 * time.Second
3133
)
3234

3335
var modelTokens = map[string]int{
@@ -49,6 +51,7 @@ type Client struct {
4951
temperature float32
5052
topP float32
5153
timeout time.Duration
54+
chunkTimeout time.Duration
5255
verbose bool
5356
stream io.Writer
5457
client *openai.Client
@@ -106,6 +109,12 @@ func TopP(topP float32) Option {
106109
}
107110
}
108111

112+
func ChunkTimeout(timeout time.Duration) Option {
113+
return func(m *Client) {
114+
m.chunkTimeout = timeout
115+
}
116+
}
117+
109118
// Timeout is a function that sets the timeout duration for the Client. It
110119
// returns an Option that, when provided to the New function, modifies the
111120
// timeout duration of the created Client instance. The timeout duration
@@ -142,10 +151,11 @@ func Stream(stream io.Writer) Option {
142151
// API requests.
143152
func New(apiToken string, opts ...Option) *Client {
144153
c := Client{
145-
temperature: DefaultTemperature,
146-
topP: DefaultTopP,
147-
timeout: DefaultTimeout,
148-
client: openai.NewClient(apiToken),
154+
temperature: DefaultTemperature,
155+
topP: DefaultTopP,
156+
timeout: DefaultTimeout,
157+
chunkTimeout: DefaultChunkTimeout,
158+
client: openai.NewClient(apiToken),
149159
}
150160
for _, opt := range opts {
151161
opt(&c)
@@ -229,7 +239,7 @@ func (c *Client) createCompletion(ctx context.Context, prompt string) (string, e
229239
if err != nil {
230240
return "", err
231241
}
232-
return streamReader(c, stream).read(ctx, func(stream *openai.ChatCompletionStream) (chunk, error) {
242+
return streamReader(c, stream, c.chunkTimeout).read(ctx, func(stream *openai.ChatCompletionStream) (chunk, error) {
233243
resp, err := stream.Recv()
234244
if err != nil {
235245
return chunk{}, err
@@ -261,7 +271,7 @@ func (c *Client) createCompletion(ctx context.Context, prompt string) (string, e
261271
if err != nil {
262272
return "", err
263273
}
264-
return streamReader(c, stream).read(ctx, func(stream *openai.CompletionStream) (chunk, error) {
274+
return streamReader(c, stream, c.chunkTimeout).read(ctx, func(stream *openai.CompletionStream) (chunk, error) {
265275
resp, err := stream.Recv()
266276
if err != nil {
267277
return chunk{}, err
@@ -289,26 +299,28 @@ func isChatModel(model string) bool {
289299
}
290300

291301
type chunkReader[Stream any] struct {
292-
client *Client
293-
stream Stream
302+
client *Client
303+
stream Stream
304+
timeout time.Duration
294305
}
295306

296-
func streamReader[Stream any](client *Client, stream Stream) chunkReader[Stream] {
297-
return chunkReader[Stream]{
298-
client: client,
299-
stream: stream,
307+
func streamReader[Stream any](client *Client, stream Stream, timeout time.Duration) *chunkReader[Stream] {
308+
return &chunkReader[Stream]{
309+
client: client,
310+
stream: stream,
311+
timeout: timeout,
300312
}
301313
}
302314

303-
func (r chunkReader[Stream]) read(ctx context.Context, getChunk func(Stream) (chunk, error)) (string, error) {
315+
func (r *chunkReader[Stream]) read(ctx context.Context, getChunk func(Stream) (chunk, error)) (string, error) {
304316
var text strings.Builder
305317

306318
if r.client.stream != nil {
307319
fmt.Fprint(r.client.stream, "\n")
308320
}
309321

310322
for {
311-
timeout := time.NewTimer(5 * time.Second)
323+
timeout := time.NewTimer(r.timeout)
312324

313325
chunkC := make(chan chunk)
314326
errC := make(chan error)

0 commit comments

Comments
 (0)