@@ -2,13 +2,13 @@ use std::pin::Pin;
22
33use bytes:: Bytes ;
44use futures:: { stream:: StreamExt , Stream } ;
5- use reqwest:: multipart:: Form ;
6- use reqwest_eventsource:: { Event , EventSource , RequestBuilderExt } ;
5+ use reqwest:: { multipart:: Form , Response } ;
6+ use reqwest_eventsource:: { Error as EventSourceError , Event , EventSource , RequestBuilderExt } ;
77use serde:: { de:: DeserializeOwned , Serialize } ;
88
99use crate :: {
1010 config:: { Config , OpenAIConfig } ,
11- error:: { map_deserialization_error, ApiError , OpenAIError , WrappedError } ,
11+ error:: { map_deserialization_error, ApiError , OpenAIError , StreamError , WrappedError } ,
1212 file:: Files ,
1313 image:: Images ,
1414 moderation:: Moderations ,
@@ -335,52 +335,34 @@ impl<C: Config> Client<C> {
335335 . map_err ( backoff:: Error :: Permanent ) ?;
336336
337337 let status = response. status ( ) ;
338- let bytes = response
339- . bytes ( )
340- . await
341- . map_err ( OpenAIError :: Reqwest )
342- . map_err ( backoff:: Error :: Permanent ) ?;
343338
344- if status. is_server_error ( ) {
345- // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
346- let message: String = String :: from_utf8_lossy ( & bytes) . into_owned ( ) ;
347- tracing:: warn!( "Server error: {status} - {message}" ) ;
348- return Err ( backoff:: Error :: Transient {
349- err : OpenAIError :: ApiError ( ApiError {
350- message,
351- r#type : None ,
352- param : None ,
353- code : None ,
354- } ) ,
355- retry_after : None ,
356- } ) ;
357- }
358-
359- // Deserialize response body from either error object or actual response object
360- if !status. is_success ( ) {
361- let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
362- . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) )
363- . map_err ( backoff:: Error :: Permanent ) ?;
364-
365- if status. as_u16 ( ) == 429
366- // API returns 429 also when:
367- // "You exceeded your current quota, please check your plan and billing details."
368- && wrapped_error. error . r#type != Some ( "insufficient_quota" . to_string ( ) )
369- {
370- // Rate limited retry...
371- tracing:: warn!( "Rate limited: {}" , wrapped_error. error. message) ;
372- return Err ( backoff:: Error :: Transient {
373- err : OpenAIError :: ApiError ( wrapped_error. error ) ,
374- retry_after : None ,
375- } ) ;
376- } else {
377- return Err ( backoff:: Error :: Permanent ( OpenAIError :: ApiError (
378- wrapped_error. error ,
379- ) ) ) ;
339+ match read_response ( response) . await {
340+ Ok ( bytes) => Ok ( bytes) ,
341+ Err ( e) => {
342+ match e {
343+ OpenAIError :: ApiError ( api_error) => {
344+ if status. is_server_error ( ) {
345+ Err ( backoff:: Error :: Transient {
346+ err : OpenAIError :: ApiError ( api_error) ,
347+ retry_after : None ,
348+ } )
349+ } else if status. as_u16 ( ) == 429
350+ && api_error. r#type != Some ( "insufficient_quota" . to_string ( ) )
351+ {
352+ // Rate limited retry...
353+ tracing:: warn!( "Rate limited: {}" , api_error. message) ;
354+ Err ( backoff:: Error :: Transient {
355+ err : OpenAIError :: ApiError ( api_error) ,
356+ retry_after : None ,
357+ } )
358+ } else {
359+ Err ( backoff:: Error :: Permanent ( OpenAIError :: ApiError ( api_error) ) )
360+ }
361+ }
362+ _ => Err ( backoff:: Error :: Permanent ( e) ) ,
363+ }
380364 }
381365 }
382-
383- Ok ( bytes)
384366 } )
385367 . await
386368 }
@@ -471,6 +453,53 @@ impl<C: Config> Client<C> {
471453 }
472454}
473455
456+ async fn read_response ( response : Response ) -> Result < Bytes , OpenAIError > {
457+ let status = response. status ( ) ;
458+ let bytes = response. bytes ( ) . await . map_err ( OpenAIError :: Reqwest ) ?;
459+
460+ if status. is_server_error ( ) {
461+ // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
462+ let message: String = String :: from_utf8_lossy ( & bytes) . into_owned ( ) ;
463+ tracing:: warn!( "Server error: {status} - {message}" ) ;
464+ return Err ( OpenAIError :: ApiError ( ApiError {
465+ message,
466+ r#type : None ,
467+ param : None ,
468+ code : None ,
469+ } ) ) ;
470+ }
471+
472+ // Deserialize response body from either error object or actual response object
473+ if !status. is_success ( ) {
474+ let wrapped_error: WrappedError = serde_json:: from_slice ( bytes. as_ref ( ) )
475+ . map_err ( |e| map_deserialization_error ( e, bytes. as_ref ( ) ) ) ?;
476+
477+ return Err ( OpenAIError :: ApiError ( wrapped_error. error ) ) ;
478+ }
479+
480+ Ok ( bytes)
481+ }
482+
483+ async fn map_stream_error ( value : EventSourceError ) -> OpenAIError {
484+ match value {
485+ EventSourceError :: Parser ( e) => OpenAIError :: StreamError ( StreamError :: Parser ( e. to_string ( ) ) ) ,
486+ EventSourceError :: InvalidContentType ( e, response) => {
487+ OpenAIError :: StreamError ( StreamError :: InvalidContentType ( e, response) )
488+ }
489+ EventSourceError :: InvalidLastEventId ( e) => {
490+ OpenAIError :: StreamError ( StreamError :: InvalidLastEventId ( e) )
491+ }
492+ EventSourceError :: StreamEnded => OpenAIError :: StreamError ( StreamError :: StreamEnded ) ,
493+ EventSourceError :: Utf8 ( e) => OpenAIError :: StreamError ( StreamError :: Utf8 ( e) ) ,
494+ EventSourceError :: Transport ( error) => OpenAIError :: Reqwest ( error) ,
495+ EventSourceError :: InvalidStatusCode ( _status_code, response) => {
496+ read_response ( response) . await . expect_err (
497+ "Unreachable because read_response returns err when status_code is invalid" ,
498+ )
499+ }
500+ }
501+ }
502+
474503/// Request which responds with SSE.
475504/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
476505pub ( crate ) async fn stream < O > (
@@ -485,7 +514,7 @@ where
485514 while let Some ( ev) = event_source. next ( ) . await {
486515 match ev {
487516 Err ( e) => {
488- if let Err ( _e) = tx. send ( Err ( OpenAIError :: StreamError ( e . to_string ( ) ) ) ) {
517+ if let Err ( _e) = tx. send ( Err ( map_stream_error ( e ) . await ) ) {
489518 // rx dropped
490519 break ;
491520 }
@@ -530,7 +559,7 @@ where
530559 while let Some ( ev) = event_source. next ( ) . await {
531560 match ev {
532561 Err ( e) => {
533- if let Err ( _e) = tx. send ( Err ( OpenAIError :: StreamError ( e . to_string ( ) ) ) ) {
562+ if let Err ( _e) = tx. send ( Err ( map_stream_error ( e ) . await ) ) {
534563 // rx dropped
535564 break ;
536565 }
0 commit comments