Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/responses.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::common::GPT4_1_MINI;
use openai_api_rs::v1::responses::CreateResponseRequest;
use openai_api_rs::v1::responses::responses::CreateResponseRequest;
use serde_json::json;
use std::env;

Expand Down
51 changes: 51 additions & 0 deletions examples/responses_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use futures_util::StreamExt;
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::common::GPT4_1_MINI;
use openai_api_rs::v1::responses::responses_stream::{
CreateResponseStreamRequest, ResponseStreamResponse,
};
use serde_json::{json, Value};
use std::env;
use std::io::{self, Write};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_key = env::var("OPENAI_API_KEY").unwrap();
let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;

let mut req = CreateResponseStreamRequest::new();
req.model = Some(GPT4_1_MINI.to_string());
req.input = Some(json!("What is bitcoin? Please answer in detail."));

let mut stream = client.create_response_stream(req).await?;
let mut full_text = String::new();

while let Some(event) = stream.next().await {
match event {
ResponseStreamResponse::Event(evt) => {
if let Some("response.output_text.delta") = evt.event.as_deref() {
if let Some(delta) = evt.data.get("delta").and_then(Value::as_str) {
print!("{delta}");
io::stdout().flush()?;
full_text.push_str(delta);
continue;
}
}

if let Some(name) = evt.event.as_deref() {
println!("\nEvent: {name} => {}", evt.data);
} else {
println!("Event data: {}", evt.data);
}
}
ResponseStreamResponse::Done => {
println!("\n\nDone streaming response.");
}
}
}

println!("\nCollected text: {full_text}");
Ok(())
}

// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example responses_stream
39 changes: 38 additions & 1 deletion src/v1/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ use crate::v1::message::{
};
use crate::v1::model::{ModelResponse, ModelsResponse};
use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
use crate::v1::responses::{
use crate::v1::responses::responses::{
CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject,
};
use crate::v1::responses::responses_stream::{
CreateResponseStreamRequest, ResponseStream, ResponseStreamResponse,
};
use crate::v1::run::{
CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
RunStepObject,
Expand Down Expand Up @@ -830,6 +833,40 @@ impl OpenAIClient {
self.post("responses", &req).await
}

pub async fn create_response_stream(
&mut self,
req: CreateResponseStreamRequest,
) -> Result<impl Stream<Item = ResponseStreamResponse>, APIError> {
let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
message: format!("Failed to serialize request: {}", err),
})?;

if let Some(obj) = payload.as_object_mut() {
obj.insert("stream".into(), Value::Bool(true));
}

let request = self.build_request(Method::POST, "responses").await;
let request = request.json(&payload);
let response = request.send().await?;

if response.status().is_success() {
Ok(ResponseStream {
response: Box::pin(response.bytes_stream()),
buffer: String::new(),
first_chunk: true,
})
} else {
let error_text = response
.text()
.await
.unwrap_or_else(|_| String::from("Unknown error"));

Err(APIError::CustomError {
message: error_text,
})
}
}

pub async fn retrieve_response(
&mut self,
response_id: String,
Expand Down
3 changes: 3 additions & 0 deletions src/v1/responses/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#[allow(clippy::module_inception)]
pub mod responses;
pub mod responses_stream;
2 changes: 2 additions & 0 deletions src/v1/responses.rs → src/v1/responses/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::BTreeMap;

// pub mod responses_stream;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CreateResponseRequest {
// background
Expand Down
132 changes: 132 additions & 0 deletions src/v1/responses/responses_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use super::responses::CreateResponseRequest;
use futures_util::Stream;
use serde_json::Value;
use std::pin::Pin;
use std::task::{Context, Poll};

pub type CreateResponseStreamRequest = CreateResponseRequest;

#[derive(Debug, Clone)]
pub struct ResponseStreamEvent {
pub event: Option<String>,
pub data: Value,
}

#[derive(Debug, Clone)]
pub enum ResponseStreamResponse {
Event(ResponseStreamEvent),
Done,
}

pub struct ResponseStream<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> {
pub response: S,
pub buffer: String,
pub first_chunk: bool,
}

impl<S> ResponseStream<S>
where
S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
{
fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
let carriage_idx = buffer.find("\r\n\r\n");
let newline_idx = buffer.find("\n\n");

match (carriage_idx, newline_idx) {
(Some(r_idx), Some(n_idx)) => {
if r_idx <= n_idx {
Some((r_idx, 4))
} else {
Some((n_idx, 2))
}
}
(Some(r_idx), None) => Some((r_idx, 4)),
(None, Some(n_idx)) => Some((n_idx, 2)),
(None, None) => None,
}
}

fn next_response_from_buffer(&mut self) -> Option<ResponseStreamResponse> {
while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
let event_block = self.buffer[..idx].to_owned();
self.buffer = self.buffer[idx + delimiter_len..].to_owned();

let mut event_name = None;
let mut data_payload = String::new();

for line in event_block.lines() {
let trimmed_line = line.trim_end_matches('\r');

if let Some(event) = trimmed_line
.strip_prefix("event: ")
.or_else(|| trimmed_line.strip_prefix("event:"))
{
let name = event.trim();
if !name.is_empty() {
event_name = Some(name.to_string());
}
} else if let Some(content) = trimmed_line
.strip_prefix("data: ")
.or_else(|| trimmed_line.strip_prefix("data:"))
{
if !content.is_empty() {
if !data_payload.is_empty() {
data_payload.push('\n');
}
data_payload.push_str(content);
}
}
}

if data_payload.is_empty() {
continue;
}

if data_payload.trim() == "[DONE]" {
return Some(ResponseStreamResponse::Done);
}

let parsed = serde_json::from_str::<Value>(&data_payload)
.unwrap_or_else(|_| Value::String(data_payload.clone()));

return Some(ResponseStreamResponse::Event(ResponseStreamEvent {
event: event_name,
data: parsed,
}));
}

None
}
}

impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream for ResponseStream<S> {
type Item = ResponseStreamResponse;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if let Some(response) = self.next_response_from_buffer() {
return Poll::Ready(Some(response));
}

match Pin::new(&mut self.as_mut().response).poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let chunk_str = String::from_utf8_lossy(&chunk).to_string();
if self.first_chunk {
self.first_chunk = false;
}
self.buffer.push_str(&chunk_str);
}
Poll::Ready(Some(Err(error))) => {
eprintln!("Error in stream: {:?}", error);
return Poll::Ready(None);
}
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Pending => {
return Poll::Pending;
}
}
}
}
}
Loading