diff --git a/examples/responses.rs b/examples/responses.rs
index 8bc24fe..a308f3d 100644
--- a/examples/responses.rs
+++ b/examples/responses.rs
@@ -1,6 +1,6 @@
 use openai_api_rs::v1::api::OpenAIClient;
 use openai_api_rs::v1::common::GPT4_1_MINI;
-use openai_api_rs::v1::responses::CreateResponseRequest;
+use openai_api_rs::v1::responses::responses::CreateResponseRequest;
 use serde_json::json;
 use std::env;
 
diff --git a/examples/responses_stream.rs b/examples/responses_stream.rs
new file mode 100644
index 0000000..8247539
--- /dev/null
+++ b/examples/responses_stream.rs
@@ -0,0 +1,51 @@
+use futures_util::StreamExt;
+use openai_api_rs::v1::api::OpenAIClient;
+use openai_api_rs::v1::common::GPT4_1_MINI;
+use openai_api_rs::v1::responses::responses_stream::{
+    CreateResponseStreamRequest, ResponseStreamResponse,
+};
+use serde_json::{json, Value};
+use std::env;
+use std::io::{self, Write};
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let api_key = env::var("OPENAI_API_KEY").unwrap();
+    let mut client = OpenAIClient::builder().with_api_key(api_key).build()?;
+
+    let mut req = CreateResponseStreamRequest::new();
+    req.model = Some(GPT4_1_MINI.to_string());
+    req.input = Some(json!("What is bitcoin? Please answer in detail."));
+
+    let mut stream = client.create_response_stream(req).await?;
+    let mut full_text = String::new();
+
+    while let Some(event) = stream.next().await {
+        match event {
+            ResponseStreamResponse::Event(evt) => {
+                if let Some("response.output_text.delta") = evt.event.as_deref() {
+                    if let Some(delta) = evt.data.get("delta").and_then(Value::as_str) {
+                        print!("{delta}");
+                        io::stdout().flush()?;
+                        full_text.push_str(delta);
+                        continue;
+                    }
+                }
+
+                if let Some(name) = evt.event.as_deref() {
+                    println!("\nEvent: {name} => {}", evt.data);
+                } else {
+                    println!("Event data: {}", evt.data);
+                }
+            }
+            ResponseStreamResponse::Done => {
+                println!("\n\nDone streaming response.");
+            }
+        }
+    }
+
+    println!("\nCollected text: {full_text}");
+    Ok(())
+}
+
+// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example responses_stream
diff --git a/src/v1/api.rs b/src/v1/api.rs
index 5d8a13d..a780d1e 100644
--- a/src/v1/api.rs
+++ b/src/v1/api.rs
@@ -35,9 +35,12 @@ use crate::v1::message::{
 };
 use crate::v1::model::{ModelResponse, ModelsResponse};
 use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse};
-use crate::v1::responses::{
+use crate::v1::responses::responses::{
     CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject,
 };
+use crate::v1::responses::responses_stream::{
+    CreateResponseStreamRequest, ResponseStream, ResponseStreamResponse,
+};
 use crate::v1::run::{
     CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject,
     RunStepObject,
@@ -830,6 +833,40 @@ impl OpenAIClient {
         self.post("responses", &req).await
     }
 
+    pub async fn create_response_stream(
+        &mut self,
+        req: CreateResponseStreamRequest,
+    ) -> Result<impl Stream<Item = ResponseStreamResponse>, APIError> {
+        let mut payload = to_value(&req).map_err(|err| APIError::CustomError {
+            message: format!("Failed to serialize request: {}", err),
+        })?;
+
+        if let Some(obj) = payload.as_object_mut() {
+            obj.insert("stream".into(), Value::Bool(true));
+        }
+
+        let request = self.build_request(Method::POST, "responses").await;
+        let request = request.json(&payload);
+        let response = request.send().await?;
+
+        if response.status().is_success() {
+            Ok(ResponseStream {
+                response: Box::pin(response.bytes_stream()),
+                buffer: String::new(),
+                first_chunk: true,
+            })
+        } else {
+            let error_text = response
+                .text()
+                .await
+                .unwrap_or_else(|_| String::from("Unknown error"));
+
+            Err(APIError::CustomError {
+                message: error_text,
+            })
+        }
+    }
+
     pub async fn retrieve_response(
         &mut self,
         response_id: String,
diff --git a/src/v1/responses/mod.rs b/src/v1/responses/mod.rs
new file mode 100644
index 0000000..fde1cd3
--- /dev/null
+++ b/src/v1/responses/mod.rs
@@ -0,0 +1,3 @@
+#[allow(clippy::module_inception)]
+pub mod responses;
+pub mod responses_stream;
diff --git a/src/v1/responses.rs b/src/v1/responses/responses.rs
similarity index 99%
rename from src/v1/responses.rs
rename to src/v1/responses/responses.rs
index 348b1fe..d4f5dda 100644
--- a/src/v1/responses.rs
+++ b/src/v1/responses/responses.rs
@@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use std::collections::BTreeMap;
 
+// pub mod responses_stream;
+
 #[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct CreateResponseRequest {
     // background
diff --git a/src/v1/responses/responses_stream.rs b/src/v1/responses/responses_stream.rs
new file mode 100644
index 0000000..daf101a
--- /dev/null
+++ b/src/v1/responses/responses_stream.rs
@@ -0,0 +1,132 @@
+use super::responses::CreateResponseRequest;
+use futures_util::Stream;
+use serde_json::Value;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pub type CreateResponseStreamRequest = CreateResponseRequest;
+
+#[derive(Debug, Clone)]
+pub struct ResponseStreamEvent {
+    pub event: Option<String>,
+    pub data: Value,
+}
+
+#[derive(Debug, Clone)]
+pub enum ResponseStreamResponse {
+    Event(ResponseStreamEvent),
+    Done,
+}
+
+pub struct ResponseStream<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> {
+    pub response: S,
+    pub buffer: String,
+    pub first_chunk: bool,
+}
+
+impl<S> ResponseStream<S>
+where
+    S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
+{
+    fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> {
+        let carriage_idx = buffer.find("\r\n\r\n");
+        let newline_idx = buffer.find("\n\n");
+
+        match (carriage_idx, newline_idx) {
+            (Some(r_idx), Some(n_idx)) => {
+                if r_idx <= n_idx {
+                    Some((r_idx, 4))
+                } else {
+                    Some((n_idx, 2))
+                }
+            }
+            (Some(r_idx), None) => Some((r_idx, 4)),
+            (None, Some(n_idx)) => Some((n_idx, 2)),
+            (None, None) => None,
+        }
+    }
+
+    fn next_response_from_buffer(&mut self) -> Option<ResponseStreamResponse> {
+        while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) {
+            let event_block = self.buffer[..idx].to_owned();
+            self.buffer = self.buffer[idx + delimiter_len..].to_owned();
+
+            let mut event_name = None;
+            let mut data_payload = String::new();
+
+            for line in event_block.lines() {
+                let trimmed_line = line.trim_end_matches('\r');
+
+                if let Some(event) = trimmed_line
+                    .strip_prefix("event: ")
+                    .or_else(|| trimmed_line.strip_prefix("event:"))
+                {
+                    let name = event.trim();
+                    if !name.is_empty() {
+                        event_name = Some(name.to_string());
+                    }
+                } else if let Some(content) = trimmed_line
+                    .strip_prefix("data: ")
+                    .or_else(|| trimmed_line.strip_prefix("data:"))
+                {
+                    if !content.is_empty() {
+                        if !data_payload.is_empty() {
+                            data_payload.push('\n');
+                        }
+                        data_payload.push_str(content);
+                    }
+                }
+            }
+
+            if data_payload.is_empty() {
+                continue;
+            }
+
+            if data_payload.trim() == "[DONE]" {
+                return Some(ResponseStreamResponse::Done);
+            }
+
+            let parsed = serde_json::from_str::<Value>(&data_payload)
+                .unwrap_or_else(|_| Value::String(data_payload.clone()));
+
+            return Some(ResponseStreamResponse::Event(ResponseStreamEvent {
+                event: event_name,
+                data: parsed,
+            }));
+        }
+
+        None
+    }
+}
+
+impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream for ResponseStream<S> {
+    type Item = ResponseStreamResponse;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        loop {
+            if let Some(response) = self.next_response_from_buffer() {
+                return Poll::Ready(Some(response));
+            }
+
+            match Pin::new(&mut self.as_mut().response).poll_next(cx) {
+                Poll::Ready(Some(Ok(chunk))) => {
+                    let chunk_str = String::from_utf8_lossy(&chunk).to_string();
+                    if self.first_chunk {
+                        self.first_chunk = false;
+                    }
+                    self.buffer.push_str(&chunk_str);
+                }
+                Poll::Ready(Some(Err(error))) => {
+                    eprintln!("Error in stream: {:?}", error);
+                    return Poll::Ready(None);
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
+                }
+            }
+        }
+    }
+}