Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 299 additions & 0 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,214 @@ use std::pin::Pin;
use std::sync::LazyLock;
use std::sync::Mutex;

/// Text produced by one step of a [`ThinkFilter`], split by destination.
///
/// Both fields hold only the text emitted by that step; callers that want the
/// full result append successive outputs together.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct FilterOut {
    /// Text found outside think blocks.
    pub content: String,
    /// Text found inside think blocks, with the tags themselves removed.
    pub thinking: String,
}

/// Incremental splitter that separates think-tag text from ordinary content.
///
/// Chunks are appended to an internal buffer, so a tag that is cut in two by a
/// chunk boundary can still be recognised once the rest of it arrives.
#[derive(Debug)]
pub struct ThinkFilter {
    /// Text received so far that has not been emitted yet.
    buffer: String,
    /// Whether the text at the front of `buffer` belongs to a think block.
    inside_think: bool,
    /// Number of think blocks currently open (nested opens increase it).
    think_depth: usize,
}

/// Kind of think tag recognised by `parse_think_tag`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ThinkTag {
/// An opening tag such as `<think>` or `<thinking>`.
Open,
/// A closing tag such as `</think>` or `</thinking>`.
Close,
}

/// Next point of interest found while scanning the pending buffer.
enum BufferEvent {
/// A complete think tag was found.
Tag {
// Byte offset of the tag's `<`.
pos: usize,
// Byte offset just past the tag's `>`.
end: usize,
// Whether the tag opens or closes a think block.
kind: ThinkTag,
},
/// Byte offset of a `<` that may be the start of a think tag whose
/// remainder has not arrived yet.
Partial(usize),
}

impl ThinkFilter {
pub fn new() -> Self {
Self {
buffer: String::new(),
inside_think: false,
think_depth: 0,
}
}

pub fn push(&mut self, chunk: &str) -> FilterOut {
self.buffer.push_str(chunk);
self.process_buffer()
}

pub fn finish(mut self) -> FilterOut {
let mut out = self.process_buffer();
if !self.buffer.is_empty() {
if self.inside_think {
out.thinking.push_str(&self.buffer);
} else {
out.content.push_str(&self.buffer);
}
self.buffer.clear();
}
out
}

fn process_buffer(&mut self) -> FilterOut {
let mut out = FilterOut::default();

loop {
match next_buffer_event(&self.buffer, self.inside_think) {
Some(BufferEvent::Tag { pos, end, kind }) => {
if pos > 0 {
let prefix = self.buffer.get(..pos).unwrap_or_default().to_string();
if self.inside_think {
out.thinking.push_str(&prefix);
} else {
out.content.push_str(&prefix);
}
}

self.buffer.drain(..end);

match kind {
ThinkTag::Open => {
self.think_depth += 1;
self.inside_think = true;
}
ThinkTag::Close => {
self.think_depth = self.think_depth.saturating_sub(1);
self.inside_think = self.think_depth > 0;
}
}
}
Some(BufferEvent::Partial(pos)) => {
if pos > 0 {
let prefix = self.buffer.get(..pos).unwrap_or_default().to_string();
if self.inside_think {
out.thinking.push_str(&prefix);
} else {
out.content.push_str(&prefix);
}
self.buffer.drain(..pos);
}
break;
}
None => {
if !self.buffer.is_empty() {
if self.inside_think {
out.thinking.push_str(&self.buffer);
} else {
out.content.push_str(&self.buffer);
}
self.buffer.clear();
}
break;
}
}
}

out
}
}

impl Default for ThinkFilter {
fn default() -> Self {
Self::new()
}
}

pub fn split_think_blocks(text: &str) -> (String, String) {
let mut filter = ThinkFilter::new();
let mut out = filter.push(text);
let final_out = filter.finish();
out.content.push_str(&final_out.content);
out.thinking.push_str(&final_out.thinking);
(out.content, out.thinking)
}

/// Finds the first `<` in `buffer` that matters to the filter.
///
/// Returns a `Tag` event for the first complete think tag that applies in the
/// current state (closing tags only count while `inside_think`), a `Partial`
/// event for the first `<` that could still grow into a think tag, or `None`
/// when the buffer holds neither.
fn next_buffer_event(buffer: &str, inside_think: bool) -> Option<BufferEvent> {
    buffer.match_indices('<').find_map(|(pos, _)| {
        match parse_think_tag(buffer, pos) {
            Some((kind, end)) => {
                // A closing tag outside a think block is ordinary text.
                let applies = inside_think || kind == ThinkTag::Open;
                applies.then_some(BufferEvent::Tag { pos, end, kind })
            }
            None => {
                let suffix = buffer.get(pos..).unwrap_or_default();
                // Only an unterminated tail can still be waiting for more input.
                let pending = !suffix.contains('>') && is_possible_partial_think_tag(suffix);
                pending.then_some(BufferEvent::Partial(pos))
            }
        }
    })
}

/// Tries to parse a complete think tag that begins at byte offset `start`.
///
/// Recognised forms, with the tag name compared ASCII case-insensitively:
/// - `<think>` / `<thinking>`, optionally followed by `/` or by whitespace and
///   arbitrary text before the `>`, parsed as [`ThinkTag::Open`];
/// - `</think>` / `</thinking>`, optionally with ASCII whitespace before the
///   `>`, parsed as [`ThinkTag::Close`].
///
/// Returns the tag kind together with the byte offset just past the `>`.
/// Returns `None` when `start` is not at a `<`, when the name is not a think
/// name, or when the tag has no terminating `>` yet.
fn parse_think_tag(buffer: &str, start: usize) -> Option<(ThinkTag, usize)> {
    let bytes = buffer.as_bytes();
    if bytes.get(start) != Some(&b'<') {
        return None;
    }

    let mut idx = start + 1;
    let is_close = bytes.get(idx) == Some(&b'/');
    if is_close {
        idx += 1;
    }

    // The tag name is the run of ASCII letters after `<` or `</`.
    let name_start = idx;
    while bytes.get(idx).is_some_and(u8::is_ascii_alphabetic) {
        idx += 1;
    }
    let name = buffer.get(name_start..idx)?;
    if !(name.eq_ignore_ascii_case("think") || name.eq_ignore_ascii_case("thinking")) {
        return None;
    }

    if is_close {
        while bytes.get(idx).is_some_and(u8::is_ascii_whitespace) {
            idx += 1;
        }
        return (bytes.get(idx) == Some(&b'>')).then_some((ThinkTag::Close, idx + 1));
    }

    // The name must end at a real delimiter. Without this check a tag such as
    // `<thinking_budget>` or `<think-tank>` would open a think block that no
    // closing tag can ever end.
    let delimited = bytes
        .get(idx)
        .is_some_and(|&b| matches!(b, b'>' | b'/') || b.is_ascii_whitespace());
    if !delimited {
        return None;
    }

    // Everything up to the first `>` is treated as attributes.
    bytes
        .get(idx..)?
        .iter()
        .position(|&b| b == b'>')
        .map(|offset| (ThinkTag::Open, idx + offset + 1))
}

/// Reports whether `suffix` could still grow into a think tag.
///
/// `suffix` is expected to start at a `<` and run to the end of the buffered
/// text. It counts as a possible partial tag when it contains no `>` and is
/// one of:
/// - `<` or `</` followed by a (possibly empty) ASCII case-insensitive prefix
///   of `thinking`, e.g. `<`, `</`, `<thi`, `</THINK`;
/// - `<think` / `<thinking` followed by a non-letter and any further text,
///   e.g. `<think class="x"` or the self-closing start `<think/`;
/// - `</think` / `</thinking` followed only by ASCII whitespace.
///
/// A chunk boundary can fall anywhere inside a tag, so the bare `<`, `</` and
/// `<think/` cases must be held back too; otherwise the first half of the tag
/// is emitted as visible text.
fn is_possible_partial_think_tag(suffix: &str) -> bool {
    const LONGEST_NAME: &str = "thinking";

    let Some(after_lt) = suffix.strip_prefix('<') else {
        return false;
    };
    // A `>` means the tag is already terminated, so it is not partial.
    if after_lt.contains('>') {
        return false;
    }

    let (is_close, after_slash) = match after_lt.strip_prefix('/') {
        Some(rest) => (true, rest),
        None => (false, after_lt),
    };

    // Split into the leading run of ASCII letters and whatever follows it.
    // The run is pure ASCII, so its length is always a char boundary.
    let name_len = after_slash
        .bytes()
        .take_while(u8::is_ascii_alphabetic)
        .count();
    let (name, tail) = after_slash.split_at(name_len);

    if tail.is_empty() {
        // The name itself may still be arriving: accept any prefix of the
        // longest think name (this also covers every prefix of `think`).
        return LONGEST_NAME
            .get(..name.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(name));
    }

    // Something follows the name, so the name is final and must be exact.
    let is_think_name =
        name.eq_ignore_ascii_case("think") || name.eq_ignore_ascii_case("thinking");
    if !is_think_name {
        return false;
    }

    if is_close {
        // A closing tag allows only whitespace between the name and `>`.
        tail.bytes().all(|b| b.is_ascii_whitespace())
    } else {
        // An opening tag may carry attributes or a self-closing `/`.
        true
    }
}

fn strip_xml_tags(text: &str) -> String {
static BLOCK_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"(?s)<([a-zA-Z][a-zA-Z0-9_]*)[^>]*>.*?</[a-zA-Z][a-zA-Z0-9_]*>").unwrap()
Expand Down Expand Up @@ -883,6 +1091,97 @@ mod tests {
);
}

#[test]
fn test_split_think_blocks_extracts_inline_reasoning() {
    // A single inline block is moved to the thinking half.
    let (content, thinking) = split_think_blocks("<think>x</think>y");
    assert_eq!(content, "y");
    assert_eq!(thinking, "x");
}

#[test]
fn test_split_think_blocks_is_case_insensitive() {
    // Opening and closing tags may differ in letter case.
    let (content, thinking) = split_think_blocks("<THINK>x</think>y");
    assert_eq!(content, "y");
    assert_eq!(thinking, "x");
}

#[test]
fn test_split_think_blocks_handles_multiple_blocks() {
    // Each half is concatenated in input order.
    let (content, thinking) = split_think_blocks("<think>a</think>b<think>c</think>d");
    assert_eq!(content, "bd");
    assert_eq!(thinking, "ac");
}

#[test]
fn test_split_think_blocks_without_tags() {
    // Input without tags passes through untouched.
    let (content, thinking) = split_think_blocks("plain content");
    assert_eq!(content, "plain content");
    assert_eq!(thinking, "");
}

#[test]
fn test_split_think_blocks_handles_attributes() {
    // Attributes on the opening tag are discarded with the tag.
    let (content, thinking) = split_think_blocks(r#"<think class="x">a</think>b"#);
    assert_eq!(content, "b");
    assert_eq!(thinking, "a");
}

#[test]
fn test_split_think_blocks_handles_thinking_variant() {
    // The longer `thinking` tag name is recognised as well.
    let (content, thinking) = split_think_blocks("<thinking>a</thinking>b");
    assert_eq!(content, "b");
    assert_eq!(thinking, "a");
}

#[test]
fn test_think_filter_streaming_across_partial_tags() {
    let mut filter = ThinkFilter::new();
    let mut content = String::new();
    let mut thinking = String::new();

    // Both the opening and the closing tag are cut in two by chunk boundaries.
    for piece in ["<thi", "nk>x</thi", "nk>y"] {
        let step = filter.push(piece);
        content += &step.content;
        thinking += &step.thinking;
    }

    let tail = filter.finish();
    content += &tail.content;
    thinking += &tail.thinking;

    assert_eq!(content, "y");
    assert_eq!(thinking, "x");
}

#[test]
fn test_think_filter_preserves_non_think_tags() {
    let mut filter = ThinkFilter::new();
    let pushed = filter.push("<table>");
    let flushed = filter.finish();

    // Join the push output with the final flush before checking.
    let content = pushed.content + &flushed.content;
    let thinking = pushed.thinking + &flushed.thinking;

    assert_eq!(content, "<table>");
    assert!(thinking.is_empty());
}

#[test]
fn test_think_filter_finish_treats_unterminated_think_as_thinking() {
    let mut filter = ThinkFilter::new();
    let pushed = filter.push("<think>unfinished");
    let flushed = filter.finish();

    // Join the push output with the final flush before checking.
    let content = pushed.content + &flushed.content;
    let thinking = pushed.thinking + &flushed.thinking;

    assert!(content.is_empty());
    assert_eq!(thinking, "unfinished");
}

#[test]
fn test_extract_short_title() {
assert_eq!(extract_short_title("List files"), "List files");
Expand Down
Loading
Loading