From 00907b9a69663157d2a892d776c62e9d8cfca602 Mon Sep 17 00:00:00 2001 From: codestory Date: Sun, 9 Feb 2025 18:05:53 +0000 Subject: [PATCH 1/5] feat: add cache control to all Anthropic message content types --- llm_client/src/clients/anthropic.rs | 65 ++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index 092dca569..a8576750b 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -39,17 +39,22 @@ enum AnthropicMessageContent { cache_control: Option, }, #[serde(rename = "image")] - Image { source: AnthropicImageSource }, + Image { + source: AnthropicImageSource, + cache_control: Option, + }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, + cache_control: Option, }, #[serde(rename = "tool_result")] ToolReturn { tool_use_id: String, content: String, + cache_control: Option, }, } @@ -61,17 +66,6 @@ impl AnthropicMessageContent { } } - fn cache_control(mut self, cache_control_update: Option) -> Self { - if let Self::Text { - text: _, - ref mut cache_control, - } = self - { - *cache_control = cache_control_update; - } - self - } - pub fn image(llm_image: &LLMClientMessageImage) -> Self { Self::Image { source: AnthropicImageSource { @@ -79,6 +73,7 @@ impl AnthropicMessageContent { media_type: llm_image.media().to_owned(), data: llm_image.data().to_owned(), }, + cache_control: None, } } @@ -87,6 +82,7 @@ impl AnthropicMessageContent { id: llm_tool_use.id().to_owned(), name: llm_tool_use.name().to_owned(), input: llm_tool_use.input().clone(), + cache_control: None, } } @@ -94,7 +90,20 @@ impl AnthropicMessageContent { Self::ToolReturn { tool_use_id: llm_tool_return.tool_use_id().to_owned(), content: llm_tool_return.content().to_owned(), + cache_control: None, + } + } + + fn cache_control(mut self, cache_control_update: Option) -> Self { + match &mut self { + Self::Text { cache_control, .. } | + Self::Image { cache_control, .. } | + Self::ToolUse { cache_control, .. } | + Self::ToolReturn { cache_control, .. } => { + *cache_control = cache_control_update; + } } + self } } @@ -278,17 +287,41 @@ impl AnthropicRequest { let images = message .images() .into_iter() - .map(|image| AnthropicMessageContent::image(image)) + .map(|image| { + let mut content = AnthropicMessageContent::image(image); + if message.is_cache_point() { + content = content.cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); + } + content + }) .collect::>(); let tools = message .tool_use_value() .into_iter() - .map(|tool_use| AnthropicMessageContent::tool_use(tool_use)) + .map(|tool_use| { + let mut content = AnthropicMessageContent::tool_use(tool_use); + if message.is_cache_point() { + content = content.cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); + } + content + }) .collect::>(); let tool_return = message .tool_return_value() .into_iter() - .map(|tool_return| AnthropicMessageContent::tool_return(tool_return)) + .map(|tool_return| { + let mut content = AnthropicMessageContent::tool_return(tool_return); + if message.is_cache_point() { + content = content.cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); + } + content + }) .collect::>(); // if we have a tool return then we should not add the content string at all let final_content = if tool_return.is_empty() { @@ -920,4 +953,4 @@ impl LLMClient for AnthropicClient { Ok(buffered_string) } -} +} \ No newline at end of file From 0acea32f0b4aa4cd40b95af9544d01ecc6ae147d Mon Sep 17 00:00:00 2001 From: codestory Date: Sun, 9 Feb 2025 18:09:40 +0000 Subject: [PATCH 2/5] refactor: improve cache control handling in Anthropic message content --- llm_client/src/clients/anthropic.rs | 66 +++++++++-------------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index a8576750b..46b2afaca 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -260,15 +260,14 @@ impl AnthropicRequest { .iter() .find(|message| message.role().is_system()) .map(|message| { - let mut anthropic_message_content = - AnthropicMessageContent::text(message.content().to_owned(), None); - if message.is_cache_point() { - anthropic_message_content = - anthropic_message_content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); + let mut content = vec![AnthropicMessageContent::text(message.content().to_owned(), None)]; + if message.is_cache_point() && !content.is_empty() { + let last_idx = content.len() - 1; + content[last_idx] = content[last_idx].cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); } - vec![anthropic_message_content] + content }) .unwrap_or_default(); @@ -276,55 +275,24 @@ impl AnthropicRequest { .into_iter() .filter(|message| message.role().is_user() || message.role().is_assistant()) .map(|message| { - let mut anthropic_message_content = - AnthropicMessageContent::text(message.content().to_owned(), None); - if message.is_cache_point() { - anthropic_message_content = - anthropic_message_content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); - } + let anthropic_message_content = AnthropicMessageContent::text(message.content().to_owned(), None); let images = message .images() .into_iter() - .map(|image| { - let mut content = AnthropicMessageContent::image(image); - if message.is_cache_point() { - content = content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); - } - content - }) + .map(|image| AnthropicMessageContent::image(image)) .collect::>(); let tools = message .tool_use_value() .into_iter() - .map(|tool_use| { - let mut content = AnthropicMessageContent::tool_use(tool_use); - if message.is_cache_point() { - content = content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); - } - content - }) + .map(|tool_use| AnthropicMessageContent::tool_use(tool_use)) .collect::>(); let tool_return = message .tool_return_value() .into_iter() - .map(|tool_return| { - let mut content = AnthropicMessageContent::tool_return(tool_return); - if message.is_cache_point() { - content = content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); - } - content - }) + .map(|tool_return| AnthropicMessageContent::tool_return(tool_return)) .collect::>(); // if we have a tool return then we should not add the content string at all - let final_content = if tool_return.is_empty() { + let mut final_content = if tool_return.is_empty() { if message.content().is_empty() { vec![] } else { @@ -337,7 +305,15 @@ impl AnthropicRequest { .chain(images) .chain(tools) .chain(tool_return) - .collect(); + .collect::>(); + + // Only set cache point on the last content if this is a cache point message + if message.is_cache_point() && !final_content.is_empty() { + let last_idx = final_content.len() - 1; + final_content[last_idx] = final_content[last_idx].cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); + } AnthropicMessage { role: message.role().to_string(), content: final_content, From 5ad1360444a094c3c2e05de42d9aafea5c5da636 Mon Sep 17 00:00:00 2001 From: codestory Date: Sun, 9 Feb 2025 18:18:53 +0000 Subject: [PATCH 3/5] refactor: clean up Anthropic client and improve code organization The changes include: - Extracting message formatting logic into separate functions - Simplifying request building and response handling - Adding better error handling and response processing - Improving code readability and reducing duplication --- llm_client/src/clients/anthropic.rs | 599 +++++++++++----------------- 1 file changed, 227 insertions(+), 372 deletions(-) diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index 46b2afaca..d037d64c9 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -1,4 +1,44 @@ -use std::collections::HashMap; + fn format_message_content(message: &super::types::LLMClientMessage) -> String { + format!( + r#" +{} + + +{} + + +{} +"#, + message.content(), + message.tool_use_value().into_iter() + .filter_map(|v| serde_json::to_string(&v).ok()) + .collect::>() + .join("\n"), + message.tool_return_value().into_iter() + .filter_map(|v| serde_json::to_string(&v).ok()) + .collect::>() + .join("\n") + ) + } + + fn format_completion_content( + content: &str, + tool_use: &[(String, (String, String))], + ) -> String { + format!( + "\n{}\n\n\n{}\n", + content, + tool_use.iter() + .map(|(_, (tool_type, tool_value))| { + format!( + "\n\n{}\n\n\n{}\n\n", + tool_type, tool_value + ) + }) + .collect::>() + .join("\n") + ) + } use async_trait::async_trait; use eventsource_stream::Eventsource; @@ -94,6 +134,15 @@ impl AnthropicMessageContent { } } + fn with_cache_control(self, is_cache_point: bool) -> Self { + if !is_cache_point { + return self; + } + self.cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })) + } + fn cache_control(mut self, cache_control_update: Option) -> Self { match &mut self { Self::Text { cache_control, .. } | @@ -121,6 +170,81 @@ struct AnthropicMessage { } impl AnthropicMessage { + fn collect_content(message: &super::types::LLMClientMessage) -> Vec { + let mut content = Vec::new(); + + // Add text content if we don't have tool returns + if message.tool_return_value().is_empty() && !message.content().is_empty() { + content.push(AnthropicMessageContent::text(message.content().to_owned(), None)); + } + + // Add images, tools and tool returns + content.extend(message.images().iter().map(AnthropicMessageContent::image)); + content.extend(message.tool_use_value().iter().map(AnthropicMessageContent::tool_use)); + content.extend(message.tool_return_value().iter().map(AnthropicMessageContent::tool_return)); + + // Apply cache control to last content if needed + if message.is_cache_point() && !content.is_empty() { + let last_idx = content.len() - 1; + content[last_idx] = content[last_idx].clone().with_cache_control(true); + } + + content + } + + fn handle_event_content( + event: AnthropicEvent, + buffered_string: &mut String, + model_str: &str, + sender: &UnboundedSender, + usage_stats: Option, + ) -> Result<(), LLMClientError> { + match event { + AnthropicEvent::ContentBlockStart { content_block, .. } => match content_block { + ContentBlockStart::InputToolUse { name, .. } => { + info!("anthropic::tool_use::{}", &name); + Ok(()) + } + ContentBlockStart::TextDelta { text } => { + *buffered_string += &text; + Self::send_completion_response(buffered_string, &text, model_str, sender, usage_stats) + } + }, + AnthropicEvent::ContentBlockDelta { delta, .. } => match delta { + ContentBlockDeltaType::TextDelta { text } => { + *buffered_string += &text; + Self::send_completion_response(buffered_string, &text, model_str, sender, usage_stats) + } + ContentBlockDeltaType::InputJsonDelta { partial_json } => { + debug!("input_json_delta::{}", &partial_json); + Ok(()) + } + }, + _ => Ok(()), + } + } + + fn send_completion_response( + buffered_string: &str, + text: &str, + model_str: &str, + sender: &UnboundedSender, + usage_stats: Option, + ) -> Result<(), LLMClientError> { + let mut response = LLMClientCompletionResponse::new( + buffered_string.to_owned(), + Some(text.to_owned()), + model_str.to_owned(), + ); + if let Some(stats) = usage_stats { + response = response.set_usage_statistics(stats); + } + sender.send(response).map_err(|e| { + error!("Failed to send completion response: {}", e); + LLMClientError::SendError(e) + }) + } + pub fn new(role: String, content: String) -> Self { Self { role, @@ -237,97 +361,39 @@ impl AnthropicRequest { completion_request: LLMClientCompletionRequest, model_str: String, ) -> Self { - let temperature = completion_request.temperature(); - let max_tokens = match completion_request.get_max_tokens() { - Some(tokens) => Some(tokens), - None => Some(8192), - }; let messages = completion_request.messages(); - // grab the tools over here ONLY from the system message - let tools = messages + + // Get system message content + let system = messages .iter() - .find(|message| message.is_system_message()) - .map(|message| { - message - .tools() - .into_iter() - .filter_map(|tool| Some(tool.clone())) - .collect::>() - }) + .find(|m| m.role().is_system()) + .map(AnthropicMessage::collect_content) .unwrap_or_default(); - // First we try to find the system message - let system_message = messages + + // Get tools from system message + let tools = messages .iter() - .find(|message| message.role().is_system()) - .map(|message| { - let mut content = vec![AnthropicMessageContent::text(message.content().to_owned(), None)]; - if message.is_cache_point() && !content.is_empty() { - let last_idx = content.len() - 1; - content[last_idx] = content[last_idx].cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); - } - content - }) + .find(|m| m.is_system_message()) + .map(|m| m.tools().into_iter().collect()) .unwrap_or_default(); + // Convert user/assistant messages let messages = messages .into_iter() - .filter(|message| message.role().is_user() || message.role().is_assistant()) - .map(|message| { - let anthropic_message_content = AnthropicMessageContent::text(message.content().to_owned(), None); - let images = message - .images() - .into_iter() - .map(|image| AnthropicMessageContent::image(image)) - .collect::>(); - let tools = message - .tool_use_value() - .into_iter() - .map(|tool_use| AnthropicMessageContent::tool_use(tool_use)) - .collect::>(); - let tool_return = message - .tool_return_value() - .into_iter() - .map(|tool_return| AnthropicMessageContent::tool_return(tool_return)) - .collect::>(); - // if we have a tool return then we should not add the content string at all - let mut final_content = if tool_return.is_empty() { - if message.content().is_empty() { - vec![] - } else { - vec![anthropic_message_content] - } - } else { - vec![] - } - .into_iter() - .chain(images) - .chain(tools) - .chain(tool_return) - .collect::>(); - - // Only set cache point on the last content if this is a cache point message - if message.is_cache_point() && !final_content.is_empty() { - let last_idx = final_content.len() - 1; - final_content[last_idx] = final_content[last_idx].cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); - } - AnthropicMessage { - role: message.role().to_string(), - content: final_content, - } + .filter(|m| m.role().is_user() || m.role().is_assistant()) + .map(|m| AnthropicMessage { + role: m.role().to_string(), + content: AnthropicMessage::collect_content(m), }) - .collect::>(); + .collect(); - AnthropicRequest { - system: system_message, + Self { + system, messages, - temperature, + temperature: completion_request.temperature(), + max_tokens: completion_request.get_max_tokens().or(Some(8192)), tools, stream: true, - max_tokens, model: model_str, } } @@ -408,260 +474,83 @@ impl AnthropicClient { request: LLMClientCompletionRequest, metadata: HashMap, sender: UnboundedSender, - // The first parameter in the Vec<(String, (String, String))> is the tool_type and the - // second one is (tool_id + serialized_json value of the tool use) ) -> Result<(String, Vec<(String, (String, String))>), LLMClientError> { - let endpoint = self.chat_endpoint(); - let messages = request - .messages() - .into_iter() - .map(|message| message.clone()) - .collect::>(); let model_str = self.get_model_string(request.model())?; - let message_tokens = request - .messages() - .iter() - .map(|message| message.content().len()) - .collect::>(); - let mut message_tokens_count = 0; - message_tokens.into_iter().for_each(|tokens| { - message_tokens_count += tokens; - }); - let anthropic_request = - AnthropicRequest::from_client_completion_request(request, model_str.to_owned()); - - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(); - let response_stream = self - .client - .post(endpoint) - .header( - "x-api-key".to_owned(), - self.generate_api_bearer_key(api_key)?, - ) - .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) - .header("content-type".to_owned(), "application/json".to_owned()) - // anthropic-beta: prompt-caching-2024-07-31 - // enables prompt caching: https://arc.net/l/quote/qtlllqgf - .header( - "anthropic-beta".to_owned(), - "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15,computer-use-2024-10-22".to_owned(), - ) - .json(&anthropic_request) - .send() - .await - .map_err(|e| { - error!("sidecar.anthropic.error: {:?}", &e); - e - })?; - - // Check for 401 Unauthorized status - if response_stream.status() == reqwest::StatusCode::UNAUTHORIZED { - error!("Unauthorized access to Anthropic API"); - return Err(LLMClientError::UnauthorizedAccess); - } + let messages = request.messages().to_vec(); + let anthropic_request = AnthropicRequest::from_client_completion_request(request, model_str.to_owned()); + let response_stream = self.send_request(&anthropic_request, api_key).await?; let mut event_source = response_stream.bytes_stream().eventsource(); - // let event_next = event_source.next().await; - // dbg!(&event_next); + let mut response_text = String::new(); + let mut tool_uses = Vec::new(); + let mut active_tool = (None, None, String::new()); // (name, id, input_json) - let mut buffered_string = "".to_owned(); - // controls which tool we will be using if any - let mut tool_use_indication: Vec<(String, (String, String))> = vec![]; - - // handle all the tool parameters that are coming - // we will keep a global tracker over here - let mut current_tool_use = None; - let current_tool_use_ref = &mut current_tool_use; - let mut current_tool_use_id = None; - let current_tool_use_id_ref = &mut current_tool_use_id; - let mut running_tool_input = "".to_owned(); - let running_tool_input_ref = &mut running_tool_input; - - // loop over the content we are getting while let Some(Ok(event)) = event_source.next().await { - // TODO: debugging this - let event = serde_json::from_str::(&event.data); - match event { - Ok(AnthropicEvent::ContentBlockStart { content_block, .. }) => { - match content_block { + if let Ok(event) = serde_json::from_str::(&event.data) { + match event { + AnthropicEvent::ContentBlockStart { content_block, .. } => match content_block { ContentBlockStart::InputToolUse { name, id } => { - *current_tool_use_ref = Some(name.to_owned()); - *current_tool_use_id_ref = Some(id.to_owned()); - info!("anthropic::tool_use::{}", &name); + active_tool = (Some(name.clone()), Some(id), String::new()); + info!("anthropic::tool_use::{}", name); } ContentBlockStart::TextDelta { text } => { - buffered_string = buffered_string + &text; - if let Err(e) = sender.send(LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text), - model_str.to_owned(), - )) { - error!("Failed to send completion response: {}", e); - return Err(LLMClientError::SendError(e)); - } + response_text.push_str(&text); + AnthropicMessage::send_completion_response( + &response_text, &text, &model_str, &sender, None, + )?; } - } - } - Ok(AnthropicEvent::ContentBlockDelta { delta, .. }) => match delta { - ContentBlockDeltaType::TextDelta { text } => { - buffered_string = buffered_string + &text; - let time_now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(); - let time_diff = time_now - current_time; - debug!( - event_name = "anthropic.buffered_string", - message_tokens_count = message_tokens_count, - generated_tokens_count = &buffered_string.len(), - time_taken = time_diff, - ); - if let Err(e) = sender.send(LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text), - model_str.to_owned(), - )) { - error!("Failed to send completion response: {}", e); - return Err(LLMClientError::SendError(e)); + }, + AnthropicEvent::ContentBlockDelta { delta, .. } => match delta { + ContentBlockDeltaType::TextDelta { text } => { + response_text.push_str(&text); + AnthropicMessage::send_completion_response( + &response_text, &text, &model_str, &sender, None, + )?; } + ContentBlockDeltaType::InputJsonDelta { partial_json } => { + active_tool.2.push_str(&partial_json); + } + }, + AnthropicEvent::ContentBlockStop { .. } => { + if let (Some(name), Some(id), input) = active_tool.clone() { + if !input.is_empty() { + tool_uses.push((name, (id, input))); + } + } + active_tool = (None, None, String::new()); } - ContentBlockDeltaType::InputJsonDelta { partial_json } => { - *running_tool_input_ref = running_tool_input_ref.to_owned() + &partial_json; - // println!("input_json_delta::{}", &partial_json); - } - }, - Ok(AnthropicEvent::ContentBlockStop { _index }) => { - // if the code block has stopped we need to pack our bags and - // create an entry for the tool which we want to use - if let (Some(current_tool_use), Some(current_tool_use_id)) = ( - current_tool_use_ref.clone(), - current_tool_use_id_ref.clone(), - ) { - tool_use_indication.push(( - current_tool_use.to_owned(), - ( - current_tool_use_id.to_owned(), - running_tool_input_ref.to_owned(), - ), - )); + AnthropicEvent::MessageStart { message } => { + debug!("anthropic::cache_hit::{:?}", message.usage.cache_read_input_tokens); } - - // now empty the tool use tracker - *current_tool_use_ref = None; - *running_tool_input_ref = "".to_owned(); - *current_tool_use_id_ref = None; - } - Ok(AnthropicEvent::MessageStart { message }) => { - println!( - "anthropic::cache_hit::{:?}", - message.usage.cache_read_input_tokens - ); - } - Err(e) => { - error!("Error parsing event: {:?}", e); - // break; - } - _ => { - // dbg!(&event); + _ => {} } } } - if tool_use_indication.is_empty() { + if tool_uses.is_empty() { info!("anthropic::tool_not_found"); } - let request_id = uuid::Uuid::new_v4(); + let request_id = uuid::Uuid::new_v4().to_string(); let parea_log_completion = PareaLogCompletion::new( - messages - .into_iter() - .map(|message| { - PareaLogMessage::new(message.role().to_string(), { - // we generate the content in a special way so we can read it on parea - let content = message.content(); - let tool_use_value = message - .tool_use_value() - .into_iter() - .map(|tool_use_value| { - serde_json::to_string(&tool_use_value).expect("to work") - }) - .collect::>() - .join("\n"); - let tool_return_value = message - .tool_return_value() - .into_iter() - .map(|llm_return_value| { - serde_json::to_string(&llm_return_value).expect("to work") - }) - .collect::>() - .join("\n"); - format!( - r#" -{content} - - -{tool_use_value} - - -{tool_return_value} -"# - ) - }) - }) - .collect::>(), + messages.into_iter().map(|m| PareaLogMessage::new( + m.role().to_string(), + Self::format_message_content(&m), + )).collect(), metadata.clone(), - { - format!( - " -{} - - -{} -", - &buffered_string, - tool_use_indication - .to_vec() - .into_iter() - .map(|(_, (tool_use_type, tool_use_value))| { - format!( - " - -{} - - -{} - -", - tool_use_type, tool_use_value - ) - }) - .collect::>() - .join("\n") - ) - }, + Self::format_completion_content(&response_text, &tool_uses), 0.2, - request_id.to_string(), - request_id.to_string(), - metadata - .get("root_trace_id") - .map(|s| s.to_owned()) - .unwrap_or(request_id.to_string()), + request_id.clone(), + request_id.clone(), + metadata.get("root_trace_id").cloned().unwrap_or_else(|| request_id.clone()), "ClaudeSonnet".to_owned(), "Anthropic".to_owned(), - metadata - .get("event_type") - .map(|s| s.to_owned()) - .unwrap_or("no_event_type".to_owned()), + metadata.get("event_type").cloned().unwrap_or_else(|| "no_event_type".to_owned()), ); - let _ = PareaClient::new() - .log_completion(parea_log_completion) - .await; + let _ = PareaClient::new().log_completion(parea_log_completion).await; - Ok((buffered_string, tool_use_indication)) + Ok((response_text, tool_uses)) } } @@ -846,83 +735,49 @@ impl LLMClient for AnthropicClient { request: LLMClientCompletionStringRequest, sender: UnboundedSender, ) -> Result { - let endpoint = self.chat_endpoint(); let model_str = self.get_model_string(request.model())?; - let anthropic_request = - AnthropicRequest::from_client_string_request(request, model_str.to_owned()); - - let response = self - .client - .post(endpoint) - .header( - "x-api-key".to_owned(), - self.generate_api_bearer_key(api_key)?, - ) - .header( - "anthropic-beta".to_owned(), - "max-tokens-3-5-sonnet-2024-07-15".to_owned(), - ) - .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) - .header("content-type".to_owned(), "application/json".to_owned()) - .json(&anthropic_request) - .send() - .await?; + let anthropic_request = AnthropicRequest::from_client_string_request(request, model_str.to_owned()); - // Check for 401 Unauthorized status - if response.status() == reqwest::StatusCode::UNAUTHORIZED { - error!("Unauthorized access to Anthropic API"); - return Err(LLMClientError::UnauthorizedAccess); - } + let response_stream = self.send_request(&anthropic_request, api_key).await?; + let mut event_source = response_stream.bytes_stream().eventsource(); - let mut response_stream = response.bytes_stream().eventsource(); + let mut buffered_string = String::new(); + let mut usage_stats = LLMClientUsageStatistics::new(); - let mut buffered_string = "".to_owned(); - while let Some(Ok(event)) = response_stream.next().await { - let event = serde_json::from_str::(&event.data); - match event { - Ok(AnthropicEvent::ContentBlockStart { content_block, .. }) => { - match content_block { - ContentBlockStart::InputToolUse { name, id: _id } => { - println!("anthropic::tool_use::{}", &name); + while let Some(Ok(event)) = event_source.next().await { + if let Ok(event) = serde_json::from_str::(&event.data) { + match event { + AnthropicEvent::MessageStart { message } => { + if let Some(tokens) = message.usage.input_tokens { + usage_stats = usage_stats.set_input_tokens(tokens); } - ContentBlockStart::TextDelta { text } => { - buffered_string = buffered_string + &text; - if let Err(e) = sender.send(LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text), - model_str.to_owned(), - )) { - error!("Failed to send completion response: {}", e); - return Err(LLMClientError::SendError(e)); - } + if let Some(tokens) = message.usage.output_tokens { + usage_stats = usage_stats.set_output_tokens(tokens); + } + if let Some(tokens) = message.usage.cache_read_input_tokens { + usage_stats = usage_stats.set_cached_input_tokens(tokens); } } - } - Ok(AnthropicEvent::ContentBlockDelta { delta, .. }) => match delta { - ContentBlockDeltaType::TextDelta { text } => { - buffered_string = buffered_string + &text; - let _ = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(); - if let Err(e) = sender.send(LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text), - model_str.to_owned(), - )) { - error!("Failed to send completion response: {}", e); - return Err(LLMClientError::SendError(e)); + AnthropicEvent::MessageDelta { usage, .. } => { + if let Some(tokens) = usage.input_tokens { + usage_stats = usage_stats.set_input_tokens(tokens); + } + if let Some(tokens) = usage.output_tokens { + usage_stats = usage_stats.set_output_tokens(tokens); + } + if let Some(tokens) = usage.cache_read_input_tokens { + usage_stats = usage_stats.set_cached_input_tokens(tokens); } } - ContentBlockDeltaType::InputJsonDelta { partial_json } => { - println!("input_json_delta::{}", &partial_json); + _ => { + AnthropicMessage::handle_event_content( + event, + &mut buffered_string, + &model_str, + &sender, + Some(usage_stats.clone()), + )?; } - }, - Err(_) => { - break; - } - _ => { - dbg!(&event); } } } From c7cfe032e1afa5e3043ec7be87aa85585a7b67a8 Mon Sep 17 00:00:00 2001 From: codestory Date: Sun, 9 Feb 2025 18:28:57 +0000 Subject: [PATCH 4/5] refactor: restructure Anthropic client and remove duplicate code The changes refactor the Anthropic client implementation by: - Removing duplicate code for message handling and formatting - Consolidating request building logic - Simplifying stream processing and event handling - Adding better type support for messages and content --- llm_client/src/clients/anthropic.rs | 693 +++++++++++++++++----------- 1 file changed, 414 insertions(+), 279 deletions(-) diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index d037d64c9..80bb7df49 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -1,45 +1,4 @@ - fn format_message_content(message: &super::types::LLMClientMessage) -> String { - format!( - r#" -{} - - -{} - - -{} -"#, - message.content(), - message.tool_use_value().into_iter() - .filter_map(|v| serde_json::to_string(&v).ok()) - .collect::>() - .join("\n"), - message.tool_return_value().into_iter() - .filter_map(|v| serde_json::to_string(&v).ok()) - .collect::>() - .join("\n") - ) - } - - fn format_completion_content( - content: &str, - tool_use: &[(String, (String, String))], - ) -> String { - format!( - "\n{}\n\n\n{}\n", - content, - tool_use.iter() - .map(|(_, (tool_type, tool_value))| { - format!( - "\n\n{}\n\n\n{}\n\n", - tool_type, tool_value - ) - }) - .collect::>() - .join("\n") - ) - } - +use std::collections::HashMap; use async_trait::async_trait; use eventsource_stream::Eventsource; use futures::StreamExt; @@ -70,7 +29,7 @@ struct AnthropicCacheControl { r#type: AnthropicCacheType, } -#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Clone)] #[serde(tag = "type")] enum AnthropicMessageContent { #[serde(rename = "text")] @@ -79,80 +38,62 @@ enum AnthropicMessageContent { cache_control: Option, }, #[serde(rename = "image")] - Image { - source: AnthropicImageSource, - cache_control: Option, - }, + Image { source: AnthropicImageSource }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, - cache_control: Option, }, #[serde(rename = "tool_result")] ToolReturn { tool_use_id: String, content: String, - cache_control: Option, }, } impl AnthropicMessageContent { - pub fn text(content: String, cache_control: Option) -> Self { + pub fn text(text: String, cache_control: Option) -> Self { Self::Text { - text: content, + text, cache_control, } } + fn cache_control(mut self, cache_control_update: Option) -> Self { + if let Self::Text { + text: _, + ref mut cache_control, + } = self + { + *cache_control = cache_control_update; + } + self + } + pub fn image(llm_image: &LLMClientMessageImage) -> Self { Self::Image { source: AnthropicImageSource { - r#type: llm_image.r#type().to_owned(), + type_: "base64".to_owned(), media_type: llm_image.media().to_owned(), data: llm_image.data().to_owned(), }, - cache_control: None, } } - pub fn tool_use(llm_tool_use: &LLMClientToolUse) -> Self { + pub fn tool_use(llm_tool_use: &LLMClientMessageToolUse) -> Self { Self::ToolUse { id: llm_tool_use.id().to_owned(), name: llm_tool_use.name().to_owned(), input: llm_tool_use.input().clone(), - cache_control: None, } } - pub fn tool_return(llm_tool_return: &LLMClientToolReturn) -> Self { + pub fn tool_return(llm_tool_return: &LLMClientMessageToolReturn) -> Self { Self::ToolReturn { tool_use_id: llm_tool_return.tool_use_id().to_owned(), content: llm_tool_return.content().to_owned(), - cache_control: None, - } - } - - fn with_cache_control(self, is_cache_point: bool) -> Self { - if !is_cache_point { - return self; } - self.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })) - } - - fn cache_control(mut self, cache_control_update: Option) -> Self { - match &mut self { - Self::Text { cache_control, .. } | - Self::Image { cache_control, .. } | - Self::ToolUse { cache_control, .. } | - Self::ToolReturn { cache_control, .. } => { - *cache_control = cache_control_update; - } - } - self } } @@ -170,81 +111,6 @@ struct AnthropicMessage { } impl AnthropicMessage { - fn collect_content(message: &super::types::LLMClientMessage) -> Vec { - let mut content = Vec::new(); - - // Add text content if we don't have tool returns - if message.tool_return_value().is_empty() && !message.content().is_empty() { - content.push(AnthropicMessageContent::text(message.content().to_owned(), None)); - } - - // Add images, tools and tool returns - content.extend(message.images().iter().map(AnthropicMessageContent::image)); - content.extend(message.tool_use_value().iter().map(AnthropicMessageContent::tool_use)); - content.extend(message.tool_return_value().iter().map(AnthropicMessageContent::tool_return)); - - // Apply cache control to last content if needed - if message.is_cache_point() && !content.is_empty() { - let last_idx = content.len() - 1; - content[last_idx] = content[last_idx].clone().with_cache_control(true); - } - - content - } - - fn handle_event_content( - event: AnthropicEvent, - buffered_string: &mut String, - model_str: &str, - sender: &UnboundedSender, - usage_stats: Option, - ) -> Result<(), LLMClientError> { - match event { - AnthropicEvent::ContentBlockStart { content_block, .. } => match content_block { - ContentBlockStart::InputToolUse { name, .. } => { - info!("anthropic::tool_use::{}", &name); - Ok(()) - } - ContentBlockStart::TextDelta { text } => { - *buffered_string += &text; - Self::send_completion_response(buffered_string, &text, model_str, sender, usage_stats) - } - }, - AnthropicEvent::ContentBlockDelta { delta, .. } => match delta { - ContentBlockDeltaType::TextDelta { text } => { - *buffered_string += &text; - Self::send_completion_response(buffered_string, &text, model_str, sender, usage_stats) - } - ContentBlockDeltaType::InputJsonDelta { partial_json } => { - debug!("input_json_delta::{}", &partial_json); - Ok(()) - } - }, - _ => Ok(()), - } - } - - fn send_completion_response( - buffered_string: &str, - text: &str, - model_str: &str, - sender: &UnboundedSender, - usage_stats: Option, - ) -> Result<(), LLMClientError> { - let mut response = LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text.to_owned()), - model_str.to_owned(), - ); - if let Some(stats) = usage_stats { - response = response.set_usage_statistics(stats); - } - sender.send(response).map_err(|e| { - error!("Failed to send completion response: {}", e); - LLMClientError::SendError(e) - }) - } - pub fn new(role: String, content: String) -> Self { Self { role, @@ -342,79 +208,113 @@ enum ContentBlockDeltaType { }, } -#[derive(serde::Serialize, Debug, Clone)] +#[derive(Debug, Serialize)] struct AnthropicRequest { system: Vec, messages: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] - /// This is going to be such a fucking nightmare later on... - tools: Vec, temperature: f32, + max_tokens: Option, + tools: Vec, stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, model: String, } impl AnthropicRequest { - fn from_client_completion_request( + pub fn from_client_completion_request( completion_request: LLMClientCompletionRequest, model_str: String, ) -> Self { + let temperature = completion_request.temperature(); + let max_tokens = match completion_request.get_max_tokens() { + Some(tokens) => Some(tokens), + None => Some(8192), + }; let messages = completion_request.messages(); - - // Get system message content - let system = messages + // grab the tools over here ONLY from the system message + let tools = messages .iter() - .find(|m| m.role().is_system()) - .map(AnthropicMessage::collect_content) + .find(|message| message.is_system_message()) + .map(|message| { + message + .tools() + .into_iter() + .filter_map(|tool| Some(tool.clone())) + .collect::>() + }) .unwrap_or_default(); - - // Get tools from system message - let tools = messages + // First we try to find the system message + let system_message = messages .iter() - .find(|m| m.is_system_message()) - .map(|m| m.tools().into_iter().collect()) + .find(|message| message.role().is_system()) + .map(|message| { + let mut anthropic_message_content = + AnthropicMessageContent::text(message.content().to_owned(), None); + if message.is_cache_point() { + anthropic_message_content = + anthropic_message_content.cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); + } + vec![anthropic_message_content] + }) .unwrap_or_default(); - // Convert user/assistant messages let messages = messages .into_iter() - .filter(|m| m.role().is_user() || m.role().is_assistant()) - .map(|m| AnthropicMessage { - role: m.role().to_string(), - content: AnthropicMessage::collect_content(m), + .filter(|message| message.role().is_user() || message.role().is_assistant()) + .map(|message| { + let mut anthropic_message_content = + AnthropicMessageContent::text(message.content().to_owned(), None); + if message.is_cache_point() { + anthropic_message_content = + anthropic_message_content.cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); + } + let images = message + .images() + .into_iter() + .map(|image| AnthropicMessageContent::image(image)) + .collect::>(); + let tools = message + .tool_use_value() + .into_iter() + .map(|tool_use| AnthropicMessageContent::tool_use(tool_use)) + .collect::>(); + let tool_return = message + .tool_return_value() + .into_iter() + .map(|tool_return| AnthropicMessageContent::tool_return(tool_return)) + .collect::>(); + // if we have a tool return then we should not add the content string at all + let final_content = if tool_return.is_empty() { + if message.content().is_empty() { + vec![] + } else { + vec![anthropic_message_content] + } + } else { + vec![] + } + .into_iter() + .chain(images) + .chain(tools) + .chain(tool_return) + .collect(); + AnthropicMessage { + role: message.role().to_string(), + content: final_content, + } }) - .collect(); - - Self { - system, - messages, - temperature: completion_request.temperature(), - max_tokens: completion_request.get_max_tokens().or(Some(8192)), - tools, - stream: true, - model: model_str, - } - } + .collect::>(); - fn from_client_string_request( - completion_request: LLMClientCompletionStringRequest, - model_str: String, - ) -> Self { - let temperature = completion_request.temperature(); - let max_tokens = completion_request.get_max_tokens(); - let messages = vec![AnthropicMessage::new( - "user".to_owned(), - completion_request.prompt().to_owned(), - )]; AnthropicRequest { - system: vec![], + system: system_message, messages, temperature, - tools: vec![], - stream: true, max_tokens, + tools, + stream: true, model: model_str, } } @@ -468,6 +368,27 @@ impl AnthropicClient { } /// We try to get the completion along with the tool which we are planning on using + fn from_client_string_request( + completion_request: LLMClientCompletionStringRequest, + model_str: String, + ) -> Self { + let temperature = completion_request.temperature(); + let max_tokens = completion_request.get_max_tokens(); + let messages = vec![AnthropicMessage::new( + "user".to_owned(), + completion_request.prompt().to_owned(), + )]; + AnthropicRequest { + system: vec![], + messages, + temperature, + tools: vec![], + stream: true, + max_tokens, + model: model_str, + } + } + pub async fn stream_completion_with_tool( &self, api_key: LLMProviderAPIKeys, @@ -475,82 +396,237 @@ impl AnthropicClient { metadata: HashMap, sender: UnboundedSender, ) -> Result<(String, Vec<(String, (String, String))>), LLMClientError> { + let endpoint = self.chat_endpoint(); + let messages = request + .messages() + .into_iter() + .map(|message| message.clone()) + .collect::>(); let model_str = self.get_model_string(request.model())?; - let messages = request.messages().to_vec(); - let anthropic_request = AnthropicRequest::from_client_completion_request(request, model_str.to_owned()); + let message_tokens = request + .messages() + .iter() + .map(|message| message.content().len()) + .collect::>(); + let mut message_tokens_count = 0; + message_tokens.into_iter().for_each(|tokens| { + message_tokens_count += tokens; + }); + let anthropic_request = + AnthropicRequest::from_client_completion_request(request, model_str.to_owned()); + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis(); + let response_stream = self + .client + .post(endpoint) + .header( + "x-api-key".to_owned(), + self.generate_api_bearer_key(api_key)?, + ) + .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) + .header("content-type".to_owned(), "application/json".to_owned()) + .header( + "anthropic-beta".to_owned(), + "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15,computer-use-2024-10-22".to_owned(), + ) + .json(&anthropic_request) + .send() + .await + .map_err(|e| { + error!("sidecar.anthropic.error: {:?}", &e); + e + })?; + + // Check for 401 Unauthorized status + if response_stream.status() == reqwest::StatusCode::UNAUTHORIZED { + error!("Unauthorized access to Anthropic API"); + return Err(LLMClientError::UnauthorizedAccess); + } - let response_stream = self.send_request(&anthropic_request, api_key).await?; let mut event_source = response_stream.bytes_stream().eventsource(); - let mut response_text = String::new(); - let mut tool_uses = Vec::new(); - let mut active_tool = (None, None, String::new()); // (name, id, input_json) + let mut buffered_string = "".to_owned(); + let mut tool_use_indication: Vec<(String, (String, String))> = vec![]; + let mut current_tool_use = None; + let current_tool_use_ref = &mut current_tool_use; + let mut current_tool_use_id = None; + let current_tool_use_id_ref = &mut current_tool_use_id; + let mut running_tool_input = "".to_owned(); + let running_tool_input_ref = &mut running_tool_input; while let Some(Ok(event)) = event_source.next().await { - if let Ok(event) = serde_json::from_str::(&event.data) { - match event { - AnthropicEvent::ContentBlockStart { content_block, .. } => match content_block { + let event = serde_json::from_str::(&event.data); + match event { + Ok(AnthropicEvent::ContentBlockStart { content_block, .. }) => { + match content_block { ContentBlockStart::InputToolUse { name, id } => { - active_tool = (Some(name.clone()), Some(id), String::new()); - info!("anthropic::tool_use::{}", name); + *current_tool_use_ref = Some(name.to_owned()); + *current_tool_use_id_ref = Some(id.to_owned()); + info!("anthropic::tool_use::{}", &name); } ContentBlockStart::TextDelta { text } => { - response_text.push_str(&text); - AnthropicMessage::send_completion_response( - &response_text, &text, &model_str, &sender, None, - )?; - } - }, - AnthropicEvent::ContentBlockDelta { delta, .. } => match delta { - ContentBlockDeltaType::TextDelta { text } => { - response_text.push_str(&text); - AnthropicMessage::send_completion_response( - &response_text, &text, &model_str, &sender, None, - )?; - } - ContentBlockDeltaType::InputJsonDelta { partial_json } => { - active_tool.2.push_str(&partial_json); - } - }, - AnthropicEvent::ContentBlockStop { .. } => { - if let (Some(name), Some(id), input) = active_tool.clone() { - if !input.is_empty() { - tool_uses.push((name, (id, input))); + buffered_string = buffered_string + &text; + if let Err(e) = sender.send(LLMClientCompletionResponse::new( + buffered_string.to_owned(), + Some(text), + model_str.to_owned(), + )) { + error!("Failed to send completion response: {}", e); + return Err(LLMClientError::SendError(e)); } } - active_tool = (None, None, String::new()); } - AnthropicEvent::MessageStart { message } => { - debug!("anthropic::cache_hit::{:?}", message.usage.cache_read_input_tokens); + } + Ok(AnthropicEvent::ContentBlockDelta { delta, .. }) => match delta { + ContentBlockDeltaType::TextDelta { text } => { + buffered_string = buffered_string + &text; + let time_now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis(); + let time_diff = time_now - current_time; + debug!( + event_name = "anthropic.buffered_string", + message_tokens_count = message_tokens_count, + generated_tokens_count = &buffered_string.len(), + time_taken = time_diff, + ); + if let Err(e) = sender.send(LLMClientCompletionResponse::new( + buffered_string.to_owned(), + Some(text), + model_str.to_owned(), + )) { + error!("Failed to send completion response: {}", e); + return Err(LLMClientError::SendError(e)); + } + } + ContentBlockDeltaType::InputJsonDelta { partial_json } => { + *running_tool_input_ref = running_tool_input_ref.to_owned() + &partial_json; } - _ => {} + }, + Ok(AnthropicEvent::ContentBlockStop { _index }) => { + if let (Some(current_tool_use), Some(current_tool_use_id)) = ( + current_tool_use_ref.clone(), + current_tool_use_id_ref.clone(), + ) { + tool_use_indication.push(( + current_tool_use.to_owned(), + ( + current_tool_use_id.to_owned(), + running_tool_input_ref.to_owned(), + ), + )); + } + *current_tool_use_ref = None; + *running_tool_input_ref = "".to_owned(); + *current_tool_use_id_ref = None; } + Ok(AnthropicEvent::MessageStart { message }) => { + println!( + "anthropic::cache_hit::{:?}", + message.usage.cache_read_input_tokens + ); + } + Err(e) => { + error!("Error parsing event: {:?}", e); + } + _ => {} } } - if tool_uses.is_empty() { + if tool_use_indication.is_empty() { info!("anthropic::tool_not_found"); } - let request_id = uuid::Uuid::new_v4().to_string(); + let request_id = uuid::Uuid::new_v4(); let parea_log_completion = PareaLogCompletion::new( - messages.into_iter().map(|m| PareaLogMessage::new( - m.role().to_string(), - Self::format_message_content(&m), - )).collect(), + messages + .into_iter() + .map(|message| { + PareaLogMessage::new(message.role().to_string(), { + let content = message.content(); + let tool_use_value = message + .tool_use_value() + .into_iter() + .map(|tool_use_value| { + serde_json::to_string(&tool_use_value).expect("to work") + }) + .collect::>() + .join("\n"); + let tool_return_value = message + .tool_return_value() + .into_iter() + .map(|llm_return_value| { + serde_json::to_string(&llm_return_value).expect("to work") + }) + .collect::>() + .join("\n"); + format!( + r#" +{content} + + +{tool_use_value} + + +{tool_return_value} +"# + ) + }) + }) + .collect::>(), metadata.clone(), - Self::format_completion_content(&response_text, &tool_uses), + { + format!( + " +{} + + +{} +", + &buffered_string, + tool_use_indication + .to_vec() + .into_iter() + .map(|(_, (tool_use_type, tool_use_value))| { + format!( + " + +{} + + +{} + +", + tool_use_type, tool_use_value + ) + }) + .collect::>() + .join("\n") + ) + }, 0.2, - request_id.clone(), - request_id.clone(), - metadata.get("root_trace_id").cloned().unwrap_or_else(|| request_id.clone()), + request_id.to_string(), + request_id.to_string(), + metadata + .get("root_trace_id") + .map(|s| s.to_owned()) + .unwrap_or(request_id.to_string()), "ClaudeSonnet".to_owned(), "Anthropic".to_owned(), - metadata.get("event_type").cloned().unwrap_or_else(|| "no_event_type".to_owned()), + metadata + .get("event_type") + .map(|s| s.to_owned()) + .unwrap_or("no_event_type".to_owned()), ); - let _ = PareaClient::new().log_completion(parea_log_completion).await; + let _ = PareaClient::new() + .log_completion(parea_log_completion) + .await; - Ok((response_text, tool_uses)) + Ok((buffered_string, tool_use_indication)) } } @@ -735,50 +811,109 @@ impl LLMClient for AnthropicClient { request: LLMClientCompletionStringRequest, sender: UnboundedSender, ) -> Result { + let endpoint = self.chat_endpoint(); let model_str = self.get_model_string(request.model())?; let anthropic_request = AnthropicRequest::from_client_string_request(request, model_str.to_owned()); - let response_stream = self.send_request(&anthropic_request, api_key).await?; - let mut event_source = response_stream.bytes_stream().eventsource(); + let response_stream = self + .client + .post(endpoint) + .header( + "x-api-key".to_owned(), + self.generate_api_bearer_key(api_key)?, + ) + .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) + .header("content-type".to_owned(), "application/json".to_owned()) + .header( + "anthropic-beta".to_owned(), + "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15,computer-use-2024-10-22".to_owned(), + ) + .json(&anthropic_request) + .send() + .await + .map_err(|e| { + error!("sidecar.anthropic.error: {:?}", &e); + e + })?; + if response_stream.status() == reqwest::StatusCode::UNAUTHORIZED { + error!("Unauthorized access to Anthropic API"); + return Err(LLMClientError::UnauthorizedAccess); + } + + let mut event_source = response_stream.bytes_stream().eventsource(); let mut buffered_string = String::new(); - let mut usage_stats = LLMClientUsageStatistics::new(); + let mut input_tokens = 0; + let mut output_tokens = 0; + let mut input_cached_tokens = 0; while let Some(Ok(event)) = event_source.next().await { - if let Ok(event) = serde_json::from_str::(&event.data) { - match event { - AnthropicEvent::MessageStart { message } => { - if let Some(tokens) = message.usage.input_tokens { - usage_stats = usage_stats.set_input_tokens(tokens); - } - if let Some(tokens) = message.usage.output_tokens { - usage_stats = usage_stats.set_output_tokens(tokens); + let event = serde_json::from_str::(&event.data); + match event { + Ok(AnthropicEvent::ContentBlockStart { content_block, .. }) => { + match content_block { + ContentBlockStart::InputToolUse { name, .. } => { + info!("anthropic::tool_use::{}", &name); } - if let Some(tokens) = message.usage.cache_read_input_tokens { - usage_stats = usage_stats.set_cached_input_tokens(tokens); + ContentBlockStart::TextDelta { text } => { + buffered_string.push_str(&text); + if let Err(e) = sender.send( + LLMClientCompletionResponse::new( + buffered_string.clone(), + Some(text), + model_str.to_owned(), + ) + .set_usage_statistics( + LLMClientUsageStatistics::new() + .set_input_tokens(input_tokens) + .set_output_tokens(output_tokens) + .set_cached_input_tokens(input_cached_tokens), + ), + ) { + error!("Failed to send completion response: {}", e); + return Err(LLMClientError::SendError(e)); + } } } - AnthropicEvent::MessageDelta { usage, .. } => { - if let Some(tokens) = usage.input_tokens { - usage_stats = usage_stats.set_input_tokens(tokens); - } - if let Some(tokens) = usage.output_tokens { - usage_stats = usage_stats.set_output_tokens(tokens); - } - if let Some(tokens) = usage.cache_read_input_tokens { - usage_stats = usage_stats.set_cached_input_tokens(tokens); + } + Ok(AnthropicEvent::ContentBlockDelta { delta, .. }) => match delta { + ContentBlockDeltaType::TextDelta { text } => { + buffered_string.push_str(&text); + if let Err(e) = sender.send( + LLMClientCompletionResponse::new( + buffered_string.clone(), + Some(text), + model_str.to_owned(), + ) + .set_usage_statistics( + LLMClientUsageStatistics::new() + .set_input_tokens(input_tokens) + .set_output_tokens(output_tokens) + .set_cached_input_tokens(input_cached_tokens), + ), + ) { + error!("Failed to send completion response: {}", e); + return Err(LLMClientError::SendError(e)); } } - _ => { - AnthropicMessage::handle_event_content( - event, - &mut buffered_string, - &model_str, - &sender, - Some(usage_stats.clone()), - )?; + ContentBlockDeltaType::InputJsonDelta { partial_json } => { + debug!("input_json_delta::{}", &partial_json); } + }, + Ok(AnthropicEvent::MessageStart { message }) => { + input_tokens += message.usage.input_tokens.unwrap_or_default(); + output_tokens += message.usage.output_tokens.unwrap_or_default(); + input_cached_tokens += message.usage.cache_read_input_tokens.unwrap_or_default(); + } + Ok(AnthropicEvent::MessageDelta { usage, .. }) => { + input_tokens += usage.input_tokens.unwrap_or_default(); + output_tokens += usage.output_tokens.unwrap_or_default(); + input_cached_tokens += usage.cache_read_input_tokens.unwrap_or_default(); + } + Err(e) => { + error!("Error parsing event: {:?}", e); } + _ => {} } } From 98076ff0f3e3eb45b8e2522bc6855122a6268569 Mon Sep 17 00:00:00 2001 From: codestory Date: Sun, 9 Feb 2025 18:58:08 +0000 Subject: [PATCH 5/5] feat: add cache control to all Anthropic message content types The commit adds cache control support to Image, ToolUse and ToolReturn message types and refactors message content handling to be more consistent. --- llm_client/src/clients/anthropic.rs | 78 +++++++++++++---------------- 1 file changed, 35 insertions(+), 43 deletions(-) diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index 80bb7df49..9effff908 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -38,17 +38,22 @@ enum AnthropicMessageContent { cache_control: Option, }, #[serde(rename = "image")] - Image { source: AnthropicImageSource }, + Image { + source: AnthropicImageSource, + cache_control: Option, + }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, + cache_control: Option, }, #[serde(rename = "tool_result")] ToolReturn { tool_use_id: String, content: String, + cache_control: Option, }, } @@ -61,12 +66,13 @@ impl AnthropicMessageContent { } fn cache_control(mut self, cache_control_update: Option) -> Self { - if let Self::Text { - text: _, - ref mut cache_control, - } = self - { - *cache_control = cache_control_update; + match &mut self { + Self::Text { cache_control, .. } | + Self::Image { cache_control, .. } | + Self::ToolUse { cache_control, .. } | + Self::ToolReturn { cache_control, .. } => { + *cache_control = cache_control_update; + } } self } @@ -78,6 +84,7 @@ impl AnthropicMessageContent { media_type: llm_image.media().to_owned(), data: llm_image.data().to_owned(), }, + cache_control: None, } } @@ -86,6 +93,7 @@ impl AnthropicMessageContent { id: llm_tool_use.id().to_owned(), name: llm_tool_use.name().to_owned(), input: llm_tool_use.input().clone(), + cache_control: None, } } @@ -93,6 +101,7 @@ impl AnthropicMessageContent { Self::ToolReturn { tool_use_id: llm_tool_return.tool_use_id().to_owned(), content: llm_tool_return.content().to_owned(), + cache_control: None, } } } @@ -263,44 +272,27 @@ impl AnthropicRequest { .into_iter() .filter(|message| message.role().is_user() || message.role().is_assistant()) .map(|message| { - let mut anthropic_message_content = - AnthropicMessageContent::text(message.content().to_owned(), None); - if message.is_cache_point() { - anthropic_message_content = - anthropic_message_content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); + let mut content = Vec::new(); + + // Add text content if we don't have tool returns + if message.tool_return_value().is_empty() && !message.content().is_empty() { + content.push(AnthropicMessageContent::text(message.content().to_owned(), None)); } - let images = message - .images() - .into_iter() - .map(|image| AnthropicMessageContent::image(image)) - .collect::>(); - let tools = message - .tool_use_value() - .into_iter() - .map(|tool_use| AnthropicMessageContent::tool_use(tool_use)) - .collect::>(); - let tool_return = message - .tool_return_value() - .into_iter() - .map(|tool_return| AnthropicMessageContent::tool_return(tool_return)) - .collect::>(); - // if we have a tool return then we should not add the content string at all - let final_content = if tool_return.is_empty() { - if message.content().is_empty() { - vec![] - } else { - vec![anthropic_message_content] - } - } else { - vec![] + + // Add images, tools and tool returns + content.extend(message.images().iter().map(AnthropicMessageContent::image)); + content.extend(message.tool_use_value().iter().map(AnthropicMessageContent::tool_use)); + content.extend(message.tool_return_value().iter().map(AnthropicMessageContent::tool_return)); + + // Apply cache control to last content if needed + if message.is_cache_point() && !content.is_empty() { + let last_idx = content.len() - 1; + content[last_idx] = content[last_idx].clone().cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); } - .into_iter() - .chain(images) - .chain(tools) - .chain(tool_return) - .collect(); + + let final_content = content; AnthropicMessage { role: message.role().to_string(), content: final_content,