diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 12b3a3c5104e..97f08a9ac36b 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -24,6 +24,214 @@ use std::pin::Pin; use std::sync::LazyLock; use std::sync::Mutex; +#[derive(Debug, Default, PartialEq, Eq)] +pub struct FilterOut { + pub content: String, + pub thinking: String, +} + +pub struct ThinkFilter { + buffer: String, + inside_think: bool, + think_depth: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ThinkTag { + Open, + Close, +} + +enum BufferEvent { + Tag { + pos: usize, + end: usize, + kind: ThinkTag, + }, + Partial(usize), +} + +impl ThinkFilter { + pub fn new() -> Self { + Self { + buffer: String::new(), + inside_think: false, + think_depth: 0, + } + } + + pub fn push(&mut self, chunk: &str) -> FilterOut { + self.buffer.push_str(chunk); + self.process_buffer() + } + + pub fn finish(mut self) -> FilterOut { + let mut out = self.process_buffer(); + if !self.buffer.is_empty() { + if self.inside_think { + out.thinking.push_str(&self.buffer); + } else { + out.content.push_str(&self.buffer); + } + self.buffer.clear(); + } + out + } + + fn process_buffer(&mut self) -> FilterOut { + let mut out = FilterOut::default(); + + loop { + match next_buffer_event(&self.buffer, self.inside_think) { + Some(BufferEvent::Tag { pos, end, kind }) => { + if pos > 0 { + let prefix = self.buffer.get(..pos).unwrap_or_default().to_string(); + if self.inside_think { + out.thinking.push_str(&prefix); + } else { + out.content.push_str(&prefix); + } + } + + self.buffer.drain(..end); + + match kind { + ThinkTag::Open => { + self.think_depth += 1; + self.inside_think = true; + } + ThinkTag::Close => { + self.think_depth = self.think_depth.saturating_sub(1); + self.inside_think = self.think_depth > 0; + } + } + } + Some(BufferEvent::Partial(pos)) => { + if pos > 0 { + let prefix = self.buffer.get(..pos).unwrap_or_default().to_string(); + if self.inside_think { + out.thinking.push_str(&prefix); + } else { + out.content.push_str(&prefix); + } + self.buffer.drain(..pos); + } + break; + } + None => { + if !self.buffer.is_empty() { + if self.inside_think { + out.thinking.push_str(&self.buffer); + } else { + out.content.push_str(&self.buffer); + } + self.buffer.clear(); + } + break; + } + } + } + + out + } +} + +impl Default for ThinkFilter { + fn default() -> Self { + Self::new() + } +} + +pub fn split_think_blocks(text: &str) -> (String, String) { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(text); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + (out.content, out.thinking) +} + +fn next_buffer_event(buffer: &str, inside_think: bool) -> Option { + let mut search_from = 0; + + while let Some(rel_pos) = buffer.get(search_from..).and_then(|rest| rest.find('<')) { + let pos = search_from + rel_pos; + let suffix = buffer.get(pos..).unwrap_or_default(); + + if let Some((kind, end)) = parse_think_tag(buffer, pos) { + if inside_think || kind == ThinkTag::Open { + return Some(BufferEvent::Tag { pos, end, kind }); + } + } else if !suffix.contains('>') && is_possible_partial_think_tag(suffix) { + return Some(BufferEvent::Partial(pos)); + } + + search_from = pos + 1; + } + + None +} + +fn parse_think_tag(buffer: &str, start: usize) -> Option<(ThinkTag, usize)> { + let bytes = buffer.as_bytes(); + if bytes.get(start) != Some(&b'<') { + return None; + } + + let mut idx = start + 1; + let is_close = if bytes.get(idx) == Some(&b'/') { + idx += 1; + true + } else { + false + }; + + let name_start = idx; + while bytes.get(idx).is_some_and(u8::is_ascii_alphabetic) { + idx += 1; + } + + if idx == name_start { + return None; + } + + let name = buffer.get(name_start..idx).unwrap_or_default(); + let is_think = name.eq_ignore_ascii_case("think") || name.eq_ignore_ascii_case("thinking"); + if !is_think { + return None; + } + + if is_close { + while bytes.get(idx).is_some_and(u8::is_ascii_whitespace) { + idx += 1; + } + if bytes.get(idx) == Some(&b'>') { + return Some((ThinkTag::Close, idx + 1)); + } + return None; + } + + while let Some(byte) = bytes.get(idx) { + if *byte == b'>' { + return Some((ThinkTag::Open, idx + 1)); + } + idx += 1; + } + + None +} + +fn is_possible_partial_think_tag(suffix: &str) -> bool { + static OPEN_RE: LazyLock = LazyLock::new(|| { + Regex::new(r"(?is)^<(?:t(?:h(?:i(?:n(?:k(?:i(?:n(?:g)?)?)?)?)?)?)?)(?:\s[^>]*)?$").unwrap() + }); + static CLOSE_RE: LazyLock = LazyLock::new(|| { + Regex::new(r"(?is)^ String { static BLOCK_RE: LazyLock = LazyLock::new(|| { Regex::new(r"(?s)<([a-zA-Z][a-zA-Z0-9_]*)[^>]*>.*?").unwrap() @@ -883,6 +1091,97 @@ mod tests { ); } + #[test] + fn test_split_think_blocks_extracts_inline_reasoning() { + assert_eq!( + split_think_blocks("xy"), + ("y".to_string(), "x".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_is_case_insensitive() { + assert_eq!( + split_think_blocks("xy"), + ("y".to_string(), "x".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_handles_multiple_blocks() { + assert_eq!( + split_think_blocks("abcd"), + ("bd".to_string(), "ac".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_without_tags() { + assert_eq!( + split_think_blocks("plain content"), + ("plain content".to_string(), String::new()) + ); + } + + #[test] + fn test_split_think_blocks_handles_attributes() { + assert_eq!( + split_think_blocks(r#"ab"#), + ("b".to_string(), "a".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_handles_thinking_variant() { + assert_eq!( + split_think_blocks("ab"), + ("b".to_string(), "a".to_string()) + ); + } + + #[test] + fn test_think_filter_streaming_across_partial_tags() { + let mut filter = ThinkFilter::new(); + let mut out = FilterOut::default(); + + for chunk in ["xy"] { + let partial = filter.push(chunk); + out.content.push_str(&partial.content); + out.thinking.push_str(&partial.thinking); + } + + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "y"); + assert_eq!(out.thinking, "x"); + } + + #[test] + fn test_think_filter_preserves_non_think_tags() { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(""); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "
"); + assert!(out.thinking.is_empty()); + } + + #[test] + fn test_think_filter_finish_treats_unterminated_think_as_thinking() { + let mut filter = ThinkFilter::new(); + let mut out = filter.push("unfinished"); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert!(out.content.is_empty()); + assert_eq!(out.thinking, "unfinished"); + } + #[test] fn test_extract_short_title() { assert_eq!(extract_short_title("List files"), "List files"); diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index e896ecb16432..86bd8d0351b8 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -1,7 +1,7 @@ use crate::conversation::message::{Message, MessageContent, ProviderMetadata}; use crate::mcp_utils::extract_text_from_resource; use crate::model::ModelConfig; -use crate::providers::base::{ProviderUsage, Usage}; +use crate::providers::base::{split_think_blocks, ProviderUsage, ThinkFilter, Usage}; use crate::providers::errors::ProviderError; use crate::providers::utils::{ convert_image, detect_image_path, extract_reasoning_effort, is_valid_function_name, @@ -498,9 +498,11 @@ pub fn response_to_message(response: &Value) -> anyhow::Result { let reasoning_value = original .get("reasoning_content") .or_else(|| original.get("reasoning")); + let mut has_structured_thinking = false; if let Some(reasoning_content) = reasoning_value { if let Some(reasoning_str) = reasoning_content.as_str() { if !reasoning_str.is_empty() { + has_structured_thinking = true; content.push(MessageContent::thinking(reasoning_str, "")); } } @@ -508,7 +510,15 @@ pub fn response_to_message(response: &Value) -> anyhow::Result { if let Some(text) = original.get("content") { if let Some(text_str) = text.as_str() { - content.push(MessageContent::text(text_str)); + let (cleaned, inline_thinking) = split_think_blocks(text_str); + + if !has_structured_thinking && !inline_thinking.is_empty() { + content.push(MessageContent::thinking(inline_thinking, "")); + } + + if !cleaned.is_empty() { + content.push(MessageContent::text(cleaned)); + } } } @@ -746,6 +756,8 @@ where let mut accumulated_reasoning: Vec = Vec::new(); let mut accumulated_reasoning_content = String::new(); + let mut think_filter = ThinkFilter::new(); + let mut saw_structured_reasoning = false; let mut last_signature: Option = None; 'outer: while let Some(response) = stream.next().await { @@ -770,6 +782,9 @@ where } if let Some(rc) = &chunk.choices[0].delta.reasoning_content { accumulated_reasoning_content.push_str(rc); + if !rc.is_empty() { + saw_structured_reasoning = true; + } } } @@ -812,6 +827,9 @@ where } if let Some(rc) = &tool_chunk.choices[0].delta.reasoning_content { accumulated_reasoning_content.push_str(rc); + if !rc.is_empty() { + saw_structured_reasoning = true; + } } if let Some(delta_tool_calls) = &tool_chunk.choices[0].delta.tool_calls { for delta_call in delta_tool_calls { @@ -852,6 +870,31 @@ where None }; + let filtered = think_filter.push(""); + if !filtered.content.is_empty() || (!filtered.thinking.is_empty() && !saw_structured_reasoning) { + let mut filtered_contents = Vec::new(); + if !filtered.content.is_empty() { + filtered_contents.push(MessageContent::text(filtered.content)); + } + if !saw_structured_reasoning && !filtered.thinking.is_empty() { + filtered_contents.push(MessageContent::thinking(filtered.thinking, "")); + } + + if !filtered_contents.is_empty() { + let mut msg = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + filtered_contents, + ); + + if let Some(id) = chunk.id.clone() { + msg = msg.with_id(id); + } + + yield (Some(msg), None); + } + } + let mut contents = Vec::new(); if !accumulated_reasoning_content.is_empty() { contents.push(MessageContent::thinking(&accumulated_reasoning_content, "")); @@ -935,8 +978,14 @@ where } if let Some(text) = text_content { - if !text.is_empty() { - content.push(MessageContent::text(&text)); + let filtered = think_filter.push(&text); + + if !saw_structured_reasoning && !filtered.thinking.is_empty() { + content.push(MessageContent::thinking(filtered.thinking, "")); + } + + if !filtered.content.is_empty() { + content.push(MessageContent::text(filtered.content)); } } @@ -966,6 +1015,28 @@ where yield (None, usage) } } + + let filtered = think_filter.finish(); + if !filtered.content.is_empty() || (!filtered.thinking.is_empty() && !saw_structured_reasoning) { + let mut content = Vec::new(); + + if !filtered.content.is_empty() { + content.push(MessageContent::text(filtered.content)); + } + + if !saw_structured_reasoning && !filtered.thinking.is_empty() { + content.push(MessageContent::thinking(filtered.thinking, "")); + } + + yield ( + Some(Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + content, + )), + None, + ) + } } } @@ -2111,6 +2182,43 @@ data: [DONE]"#; panic!("Expected tool call message with nested extra_content metadata"); } + #[tokio::test] + async fn test_streaming_response_extracts_inline_think_blocks() -> anyhow::Result<()> { + let response_lines = concat!( + "data: {\"id\":\"chunk-1\",\"choices\":[{\"delta\":{\"content\":\"xy\"},\"index\":0,\"finish_reason\":\"stop\"}]}\n", + "data: [DONE]\n" + ); + + let response_stream = + tokio_stream::iter(response_lines.lines().map(|line| Ok(line.to_string()))); + let mut messages = std::pin::pin!(response_to_streaming_message(response_stream)); + + let mut text = String::new(); + let mut thinking = String::new(); + + while let Some(result) = messages.next().await { + let (message, _) = result?; + if let Some(message) = message { + for item in message.content { + match item { + MessageContent::Text(text_content) => text.push_str(&text_content.text), + MessageContent::Thinking(thinking_content) => { + thinking.push_str(&thinking_content.thinking) + } + _ => {} + } + } + } + } + + assert_eq!(text, "y"); + assert_eq!(thinking, "x"); + + Ok(()) + } + #[test] fn test_response_to_message_with_reasoning_content() -> anyhow::Result<()> { // Test capturing reasoning_content from DeepSeek reasoning models @@ -2149,6 +2257,66 @@ data: [DONE]"#; Ok(()) } + #[test] + fn test_response_to_message_extracts_inline_think_blocks() -> anyhow::Result<()> { + let response = json!({ + "choices": [{ + "role": "assistant", + "message": { + "content": "internal reasoningVisible answer" + } + }] + }); + + let message = response_to_message(&response)?; + assert_eq!(message.content.len(), 2); + + if let MessageContent::Thinking(thinking) = &message.content[0] { + assert_eq!(thinking.thinking, "internal reasoning"); + } else { + panic!("Expected Thinking content, got {:?}", message.content[0]); + } + + if let MessageContent::Text(text) = &message.content[1] { + assert_eq!(text.text, "Visible answer"); + } else { + panic!("Expected Text content"); + } + + Ok(()) + } + + #[test] + fn test_response_to_message_prefers_structured_reasoning_over_inline_think( + ) -> anyhow::Result<()> { + let response = json!({ + "choices": [{ + "role": "assistant", + "message": { + "reasoning_content": "structured reasoning", + "content": "inline reasoningVisible answer" + } + }] + }); + + let message = response_to_message(&response)?; + assert_eq!(message.content.len(), 2); + + if let MessageContent::Thinking(thinking) = &message.content[0] { + assert_eq!(thinking.thinking, "structured reasoning"); + } else { + panic!("Expected Thinking content"); + } + + if let MessageContent::Text(text) = &message.content[1] { + assert_eq!(text.text, "Visible answer"); + } else { + panic!("Expected Text content"); + } + + Ok(()) + } + #[test] fn test_format_messages_with_reasoning_content() -> anyhow::Result<()> { // Test that reasoning_content is properly included in formatted messages