@@ -28,6 +28,8 @@ const (
 	// requests to the OpenAI API. This value can be changed by using the Timeout
 	// option when creating a new client.
 	DefaultTimeout = 3 * time.Minute
+
+	DefaultChunkTimeout = 5 * time.Second
 )
 
 var modelTokens = map[string]int{
@@ -49,6 +51,7 @@ type Client struct {
 	temperature  float32
 	topP         float32
 	timeout      time.Duration
+	chunkTimeout time.Duration
 	verbose      bool
 	stream       io.Writer
 	client       *openai.Client
@@ -106,6 +109,12 @@ func TopP(topP float32) Option {
 	}
 }
 
+func ChunkTimeout(timeout time.Duration) Option {
+	return func(m *Client) {
+		m.chunkTimeout = timeout
+	}
+}
+
 // Timeout is a function that sets the timeout duration for the Client. It
 // returns an Option that, when provided to the New function, modifies the
 // timeout duration of the created Client instance. The timeout duration
@@ -142,10 +151,11 @@ func Stream(stream io.Writer) Option {
 // API requests.
 func New(apiToken string, opts ...Option) *Client {
 	c := Client{
-		temperature: DefaultTemperature,
-		topP:        DefaultTopP,
-		timeout:     DefaultTimeout,
-		client:      openai.NewClient(apiToken),
+		temperature:  DefaultTemperature,
+		topP:         DefaultTopP,
+		timeout:      DefaultTimeout,
+		chunkTimeout: DefaultChunkTimeout,
+		client:       openai.NewClient(apiToken),
 	}
 	for _, opt := range opts {
 		opt(&c)
@@ -229,7 +239,7 @@ func (c *Client) createCompletion(ctx context.Context, prompt string) (string, e
 	if err != nil {
 		return "", err
 	}
-	return streamReader(c, stream).read(ctx, func(stream *openai.ChatCompletionStream) (chunk, error) {
+	return streamReader(c, stream, c.chunkTimeout).read(ctx, func(stream *openai.ChatCompletionStream) (chunk, error) {
 		resp, err := stream.Recv()
 		if err != nil {
 			return chunk{}, err
@@ -261,7 +271,7 @@ func (c *Client) createCompletion(ctx context.Context, prompt string) (string, e
 	if err != nil {
 		return "", err
 	}
-	return streamReader(c, stream).read(ctx, func(stream *openai.CompletionStream) (chunk, error) {
+	return streamReader(c, stream, c.chunkTimeout).read(ctx, func(stream *openai.CompletionStream) (chunk, error) {
 		resp, err := stream.Recv()
 		if err != nil {
 			return chunk{}, err
@@ -289,26 +299,28 @@ func isChatModel(model string) bool {
 }
 
 type chunkReader[Stream any] struct {
-	client *Client
-	stream Stream
+	client  *Client
+	stream  Stream
+	timeout time.Duration
 }
 
-func streamReader[Stream any](client *Client, stream Stream) chunkReader[Stream] {
-	return chunkReader[Stream]{
-		client: client,
-		stream: stream,
+func streamReader[Stream any](client *Client, stream Stream, timeout time.Duration) *chunkReader[Stream] {
+	return &chunkReader[Stream]{
+		client:  client,
+		stream:  stream,
+		timeout: timeout,
 	}
 }
 
-func (r chunkReader[Stream]) read(ctx context.Context, getChunk func(Stream) (chunk, error)) (string, error) {
+func (r *chunkReader[Stream]) read(ctx context.Context, getChunk func(Stream) (chunk, error)) (string, error) {
 	var text strings.Builder
 
 	if r.client.stream != nil {
 		fmt.Fprint(r.client.stream, "\n")
 	}
 
 	for {
-		timeout := time.NewTimer(5 * time.Second)
+		timeout := time.NewTimer(r.timeout)
 
 		chunkC := make(chan chunk)
 		errC := make(chan error)
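
The timer in the last hunk is created fresh on every loop iteration, so the new `chunkTimeout` bounds the wait for each individual chunk rather than the whole completion. The remainder of the `read` loop is outside this diff; the following is only a self-contained sketch of that per-chunk-timeout pattern, not the repository's code (the `readChunks` helper, the `chunk.content` field, and the fake stream in `main` are invented for illustration; the per-iteration timer and the `chunkC`/`errC` channels mirror the hunk above):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"
)

// chunk stands in for the package's streamed piece of text.
type chunk struct{ content string }

// readChunks pulls chunks from getChunk until io.EOF, giving up if any single
// chunk takes longer than chunkTimeout to arrive.
func readChunks(ctx context.Context, chunkTimeout time.Duration, getChunk func() (chunk, error)) (string, error) {
	var text strings.Builder
	for {
		timer := time.NewTimer(chunkTimeout) // per-chunk deadline, reset every iteration

		// Buffered so the producer goroutine can always send and exit,
		// even if the reader stops waiting because of a timeout or cancellation.
		chunkC := make(chan chunk, 1)
		errC := make(chan error, 1)
		go func() {
			c, err := getChunk()
			if err != nil {
				errC <- err
				return
			}
			chunkC <- c
		}()

		select {
		case <-ctx.Done():
			timer.Stop()
			return text.String(), ctx.Err()
		case <-timer.C:
			return text.String(), errors.New("timed out waiting for the next chunk")
		case err := <-errC:
			timer.Stop()
			if errors.Is(err, io.EOF) {
				return text.String(), nil // stream finished normally
			}
			return text.String(), err
		case c := <-chunkC:
			timer.Stop()
			text.WriteString(c.content)
		}
	}
}

func main() {
	// A fake stream that emits three chunks and then io.EOF.
	parts := []string{"Hello", ", ", "world"}
	i := 0
	fakeStream := func() (chunk, error) {
		if i >= len(parts) {
			return chunk{}, io.EOF
		}
		c := chunk{content: parts[i]}
		i++
		return c, nil
	}

	text, err := readChunks(context.Background(), 5*time.Second, fakeStream)
	fmt.Println(text, err)
}
```

The buffered channels are a deliberate deviation from the unbuffered ones visible in the diff: they keep the producer goroutine from blocking forever when the reader has already returned.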
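Callers pick up the new knob through the existing functional options. A minimal sketch of how that could look as an example test inside the same package (the `ExampleChunkTimeout` function and the placeholder token are hypothetical and not part of this PR; `New`, `Timeout`, and `ChunkTimeout` are the identifiers from the diff):

```go
func ExampleChunkTimeout() {
	// The token would come from configuration in real use.
	c := New("your-api-token",
		Timeout(2*time.Minute),       // overall request deadline (existing option)
		ChunkTimeout(10*time.Second), // maximum wait for each streamed chunk (new option)
	)
	_ = c
}
```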