From 54a84ca0c64ee13a27bc2d9348021fd51a7d345f Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Wed, 15 Oct 2025 14:21:31 +0900 Subject: [PATCH 1/2] Fix response buffer --- .../chat_completion/chat_completion_stream.rs | 147 ++++++++++++------ 1 file changed, 96 insertions(+), 51 deletions(-) diff --git a/src/v1/chat_completion/chat_completion_stream.rs b/src/v1/chat_completion/chat_completion_stream.rs index 32382ac..0a6fb95 100644 --- a/src/v1/chat_completion/chat_completion_stream.rs +++ b/src/v1/chat_completion/chat_completion_stream.rs @@ -113,71 +113,116 @@ pub struct ChatCompletionStream> + Unpin> Stream - for ChatCompletionStream +impl ChatCompletionStream +where + S: Stream> + Unpin, { - type Item = ChatCompletionStreamResponse; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match Pin::new(&mut self.as_mut().response).poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - let mut utf8_str = String::from_utf8_lossy(&chunk).to_string(); + fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> { + let carriage_idx = buffer.find("\r\n\r\n"); + let newline_idx = buffer.find("\n\n"); + + match (carriage_idx, newline_idx) { + (Some(r_idx), Some(n_idx)) => { + if r_idx <= n_idx { + Some((r_idx, 4)) + } else { + Some((n_idx, 2)) + } + } + (Some(r_idx), None) => Some((r_idx, 4)), + (None, Some(n_idx)) => Some((n_idx, 2)), + (None, None) => None, + } + } - if self.first_chunk { - let lines: Vec<&str> = utf8_str.lines().collect(); - utf8_str = if lines.len() >= 2 { - lines[lines.len() - 2].to_string() - } else { - utf8_str.clone() - }; - self.first_chunk = false; + fn next_response_from_buffer(&mut self) -> Option { + while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) { + let event = self.buffer[..idx].to_owned(); + self.buffer = self.buffer[idx + delimiter_len..].to_owned(); + + let mut data_payload = String::new(); + for line in event.lines() { + let trimmed_line = line.trim_end_matches('\r'); + if let Some(content) = trimmed_line + .strip_prefix("data: ") + .or_else(|| trimmed_line.strip_prefix("data:")) + { + if !content.is_empty() { + if !data_payload.is_empty() { + data_payload.push('\n'); + } + data_payload.push_str(content); } + } + } - let trimmed_str = utf8_str.trim_start_matches("data: "); - if trimmed_str.contains("[DONE]") { - return Poll::Ready(Some(ChatCompletionStreamResponse::Done)); - } + if data_payload.is_empty() { + continue; + } + + if data_payload == "[DONE]" { + return Some(ChatCompletionStreamResponse::Done); + } - self.buffer.push_str(trimmed_str); - let json_result: Result = serde_json::from_str(&self.buffer); - - if let Ok(json) = json_result { - self.buffer.clear(); - - if let Some(choices) = json.get("choices") { - if let Some(choice) = choices.get(0) { - if let Some(delta) = choice.get("delta") { - if let Some(tool_calls) = delta.get("tool_calls") { - if let Some(tool_calls_array) = tool_calls.as_array() { - let tool_calls_vec: Vec = tool_calls_array - .iter() - .filter_map(|v| { - serde_json::from_value(v.clone()).ok() - }) - .collect(); - - return Poll::Ready(Some( - ChatCompletionStreamResponse::ToolCall( - tool_calls_vec, - ), + match serde_json::from_str::(&data_payload) { + Ok(json) => { + if let Some(choices) = json.get("choices") { + if let Some(choice) = choices.get(0) { + if let Some(delta) = choice.get("delta") { + if let Some(tool_calls) = delta.get("tool_calls") { + if let Some(tool_calls_array) = tool_calls.as_array() { + let tool_calls_vec: Vec = tool_calls_array + .iter() + .filter_map(|v| serde_json::from_value(v.clone()).ok()) + .collect(); + + if !tool_calls_vec.is_empty() { + return Some(ChatCompletionStreamResponse::ToolCall( + tool_calls_vec, )); } } + } - if let Some(content) = - delta.get("content").and_then(|c| c.as_str()) - { - let output = content.replace("\\n", "\n"); - return Poll::Ready(Some( - ChatCompletionStreamResponse::Content(output), - )); - } + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) + { + let output = content.replace("\\n", "\n"); + return Some(ChatCompletionStreamResponse::Content(output)); } } } } } + Err(error) => { + eprintln!("Failed to parse SSE chunk as JSON: {}", error); + } + } + } + + None + } +} + +impl> + Unpin> Stream + for ChatCompletionStream +{ + type Item = ChatCompletionStreamResponse; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(response) = self.next_response_from_buffer() { + return Poll::Ready(Some(response)); + } + + match Pin::new(&mut self.as_mut().response).poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + let chunk_str = String::from_utf8_lossy(&chunk).to_string(); + + if self.first_chunk { + self.first_chunk = false; + } + self.buffer.push_str(&chunk_str); + } Poll::Ready(Some(Err(error))) => { eprintln!("Error in stream: {:?}", error); return Poll::Ready(None); From a6821db342e7a77aa0a077cfa911eb90bfe51632 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Wed, 15 Oct 2025 14:24:53 +0900 Subject: [PATCH 2/2] Refactoring chat completion stream --- .../chat_completion/chat_completion_stream.rs | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/v1/chat_completion/chat_completion_stream.rs b/src/v1/chat_completion/chat_completion_stream.rs index 0a6fb95..f5b3283 100644 --- a/src/v1/chat_completion/chat_completion_stream.rs +++ b/src/v1/chat_completion/chat_completion_stream.rs @@ -166,30 +166,29 @@ where match serde_json::from_str::(&data_payload) { Ok(json) => { - if let Some(choices) = json.get("choices") { - if let Some(choice) = choices.get(0) { - if let Some(delta) = choice.get("delta") { - if let Some(tool_calls) = delta.get("tool_calls") { - if let Some(tool_calls_array) = tool_calls.as_array() { - let tool_calls_vec: Vec = tool_calls_array - .iter() - .filter_map(|v| serde_json::from_value(v.clone()).ok()) - .collect(); - - if !tool_calls_vec.is_empty() { - return Some(ChatCompletionStreamResponse::ToolCall( - tool_calls_vec, - )); - } - } - } - - if let Some(content) = delta.get("content").and_then(|c| c.as_str()) - { - let output = content.replace("\\n", "\n"); - return Some(ChatCompletionStreamResponse::Content(output)); - } - } + if let Some(delta) = json + .get("choices") + .and_then(|choices| choices.get(0)) + .and_then(|choice| choice.get("delta")) + { + if let Some(tool_call_response) = delta + .get("tool_calls") + .and_then(|tool_calls| tool_calls.as_array()) + .map(|tool_calls_array| { + tool_calls_array + .iter() + .filter_map(|v| serde_json::from_value(v.clone()).ok()) + .collect::>() + }) + .filter(|tool_calls_vec| !tool_calls_vec.is_empty()) + .map(ChatCompletionStreamResponse::ToolCall) + { + return Some(tool_call_response); + } + + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + let output = content.replace("\\n", "\n"); + return Some(ChatCompletionStreamResponse::Content(output)); } } }