Skip to content

Commit ab7a960

Browse files
authored
Unify StreamError and OpenAIError (#413)
* unify StreamError and OpenAIError * format * clippy
1 parent 3d0a137 commit ab7a960

File tree

3 files changed

+106
-51
lines changed

3 files changed

+106
-51
lines changed

async-openai/src/client.rs

Lines changed: 77 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use std::pin::Pin;
22

33
use bytes::Bytes;
44
use futures::{stream::StreamExt, Stream};
5-
use reqwest::multipart::Form;
6-
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
5+
use reqwest::{multipart::Form, Response};
6+
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
77
use serde::{de::DeserializeOwned, Serialize};
88

99
use crate::{
1010
config::{Config, OpenAIConfig},
11-
error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
11+
error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
1212
file::Files,
1313
image::Images,
1414
moderation::Moderations,
@@ -335,52 +335,34 @@ impl<C: Config> Client<C> {
335335
.map_err(backoff::Error::Permanent)?;
336336

337337
let status = response.status();
338-
let bytes = response
339-
.bytes()
340-
.await
341-
.map_err(OpenAIError::Reqwest)
342-
.map_err(backoff::Error::Permanent)?;
343338

344-
if status.is_server_error() {
345-
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
346-
let message: String = String::from_utf8_lossy(&bytes).into_owned();
347-
tracing::warn!("Server error: {status} - {message}");
348-
return Err(backoff::Error::Transient {
349-
err: OpenAIError::ApiError(ApiError {
350-
message,
351-
r#type: None,
352-
param: None,
353-
code: None,
354-
}),
355-
retry_after: None,
356-
});
357-
}
358-
359-
// Deserialize response body from either error object or actual response object
360-
if !status.is_success() {
361-
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
362-
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
363-
.map_err(backoff::Error::Permanent)?;
364-
365-
if status.as_u16() == 429
366-
// API returns 429 also when:
367-
// "You exceeded your current quota, please check your plan and billing details."
368-
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
369-
{
370-
// Rate limited retry...
371-
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
372-
return Err(backoff::Error::Transient {
373-
err: OpenAIError::ApiError(wrapped_error.error),
374-
retry_after: None,
375-
});
376-
} else {
377-
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
378-
wrapped_error.error,
379-
)));
339+
match read_response(response).await {
340+
Ok(bytes) => Ok(bytes),
341+
Err(e) => {
342+
match e {
343+
OpenAIError::ApiError(api_error) => {
344+
if status.is_server_error() {
345+
Err(backoff::Error::Transient {
346+
err: OpenAIError::ApiError(api_error),
347+
retry_after: None,
348+
})
349+
} else if status.as_u16() == 429
350+
&& api_error.r#type != Some("insufficient_quota".to_string())
351+
{
352+
// Rate limited retry...
353+
tracing::warn!("Rate limited: {}", api_error.message);
354+
Err(backoff::Error::Transient {
355+
err: OpenAIError::ApiError(api_error),
356+
retry_after: None,
357+
})
358+
} else {
359+
Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
360+
}
361+
}
362+
_ => Err(backoff::Error::Permanent(e)),
363+
}
380364
}
381365
}
382-
383-
Ok(bytes)
384366
})
385367
.await
386368
}
@@ -471,6 +453,53 @@ impl<C: Config> Client<C> {
471453
}
472454
}
473455

456+
async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
457+
let status = response.status();
458+
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;
459+
460+
if status.is_server_error() {
461+
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
462+
let message: String = String::from_utf8_lossy(&bytes).into_owned();
463+
tracing::warn!("Server error: {status} - {message}");
464+
return Err(OpenAIError::ApiError(ApiError {
465+
message,
466+
r#type: None,
467+
param: None,
468+
code: None,
469+
}));
470+
}
471+
472+
// Deserialize response body from either error object or actual response object
473+
if !status.is_success() {
474+
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
475+
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
476+
477+
return Err(OpenAIError::ApiError(wrapped_error.error));
478+
}
479+
480+
Ok(bytes)
481+
}
482+
483+
async fn map_stream_error(value: EventSourceError) -> OpenAIError {
484+
match value {
485+
EventSourceError::Parser(e) => OpenAIError::StreamError(StreamError::Parser(e.to_string())),
486+
EventSourceError::InvalidContentType(e, response) => {
487+
OpenAIError::StreamError(StreamError::InvalidContentType(e, response))
488+
}
489+
EventSourceError::InvalidLastEventId(e) => {
490+
OpenAIError::StreamError(StreamError::InvalidLastEventId(e))
491+
}
492+
EventSourceError::StreamEnded => OpenAIError::StreamError(StreamError::StreamEnded),
493+
EventSourceError::Utf8(e) => OpenAIError::StreamError(StreamError::Utf8(e)),
494+
EventSourceError::Transport(error) => OpenAIError::Reqwest(error),
495+
EventSourceError::InvalidStatusCode(_status_code, response) => {
496+
read_response(response).await.expect_err(
497+
"Unreachable because read_response returns err when status_code is invalid",
498+
)
499+
}
500+
}
501+
}
502+
474503
/// Request which responds with SSE.
475504
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
476505
pub(crate) async fn stream<O>(
@@ -485,7 +514,7 @@ where
485514
while let Some(ev) = event_source.next().await {
486515
match ev {
487516
Err(e) => {
488-
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
517+
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
489518
// rx dropped
490519
break;
491520
}
@@ -530,7 +559,7 @@ where
530559
while let Some(ev) = event_source.next().await {
531560
match ev {
532561
Err(e) => {
533-
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
562+
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
534563
// rx dropped
535564
break;
536565
}

async-openai/src/error.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
2+
use std::string::FromUtf8Error;
3+
4+
use reqwest::{header::HeaderValue, Response};
25
use serde::{Deserialize, Serialize};
36

7+
48
#[derive(Debug, thiserror::Error)]
59
pub enum OpenAIError {
610
/// Underlying error from reqwest library after an API call was made
@@ -20,13 +24,35 @@ pub enum OpenAIError {
2024
FileReadError(String),
2125
/// Error on SSE streaming
2226
#[error("stream failed: {0}")]
23-
StreamError(String),
27+
StreamError(StreamError),
2428
/// Error from client side validation
2529
/// or when builder fails to build request before making API call
2630
#[error("invalid args: {0}")]
2731
InvalidArgument(String),
2832
}
2933

34+
#[derive(Debug, thiserror::Error)]
35+
pub enum StreamError {
36+
/// Source stream is not valid UTF8
37+
#[error(transparent)]
38+
Utf8(FromUtf8Error),
39+
/// Source stream is not a valid EventStream
40+
#[error("Source stream is not a valid event stream: {0}")]
41+
Parser(String),
42+
/// The `Content-Type` returned by the server is invalid
43+
#[error("Invalid content type for event stream: {0:?}")]
44+
InvalidContentType(HeaderValue, Response),
45+
/// The `Last-Event-ID` cannot be formed into a Header to be submitted to the server
46+
#[error("Invalid `Last-Event-ID` for event stream: {0}")]
47+
InvalidLastEventId(String),
48+
/// The server sent an unrecognized event type
49+
#[error("Unrecognized event type: {0}")]
50+
UnrecognizedEventType(String),
51+
/// The stream ended
52+
#[error("Stream ended")]
53+
StreamEnded,
54+
}
55+
3056
/// OpenAI API returns error object on failure
3157
#[derive(Debug, Serialize, Deserialize, Clone)]
3258
pub struct ApiError {

async-openai/src/types/assistant_stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::pin::Pin;
33
use futures::Stream;
44
use serde::Deserialize;
55

6-
use crate::error::{map_deserialization_error, ApiError, OpenAIError};
6+
use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError};
77

88
use super::{
99
MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
@@ -208,7 +208,7 @@ impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
208208
"done" => Ok(AssistantStreamEvent::Done(value.data)),
209209

210210
_ => Err(OpenAIError::StreamError(
211-
"Unrecognized event: {value:?#}".into(),
211+
StreamError::UnrecognizedEventType(value.event),
212212
)),
213213
}
214214
}

0 commit comments

Comments
 (0)