diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index 092dca569..9effff908 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; - use async_trait::async_trait; use eventsource_stream::Eventsource; use futures::StreamExt; @@ -30,7 +29,7 @@ struct AnthropicCacheControl { r#type: AnthropicCacheType, } -#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] +#[derive(Debug, Serialize, Clone)] #[serde(tag = "type")] enum AnthropicMessageContent { #[serde(rename = "text")] @@ -39,35 +38,41 @@ enum AnthropicMessageContent { cache_control: Option, }, #[serde(rename = "image")] - Image { source: AnthropicImageSource }, + Image { + source: AnthropicImageSource, + cache_control: Option, + }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, input: serde_json::Value, + cache_control: Option, }, #[serde(rename = "tool_result")] ToolReturn { tool_use_id: String, content: String, + cache_control: Option, }, } impl AnthropicMessageContent { - pub fn text(content: String, cache_control: Option) -> Self { + pub fn text(text: String, cache_control: Option) -> Self { Self::Text { - text: content, + text, cache_control, } } fn cache_control(mut self, cache_control_update: Option) -> Self { - if let Self::Text { - text: _, - ref mut cache_control, - } = self - { - *cache_control = cache_control_update; + match &mut self { + Self::Text { cache_control, .. } | + Self::Image { cache_control, .. } | + Self::ToolUse { cache_control, .. } | + Self::ToolReturn { cache_control, .. } => { + *cache_control = cache_control_update; + } } self } @@ -75,25 +80,28 @@ impl AnthropicMessageContent { pub fn image(llm_image: &LLMClientMessageImage) -> Self { Self::Image { source: AnthropicImageSource { - r#type: llm_image.r#type().to_owned(), + type_: "base64".to_owned(), media_type: llm_image.media().to_owned(), data: llm_image.data().to_owned(), }, + cache_control: None, } } - pub fn tool_use(llm_tool_use: &LLMClientToolUse) -> Self { + pub fn tool_use(llm_tool_use: &LLMClientMessageToolUse) -> Self { Self::ToolUse { id: llm_tool_use.id().to_owned(), name: llm_tool_use.name().to_owned(), input: llm_tool_use.input().clone(), + cache_control: None, } } - pub fn tool_return(llm_tool_return: &LLMClientToolReturn) -> Self { + pub fn tool_return(llm_tool_return: &LLMClientMessageToolReturn) -> Self { Self::ToolReturn { tool_use_id: llm_tool_return.tool_use_id().to_owned(), content: llm_tool_return.content().to_owned(), + cache_control: None, } } } @@ -209,22 +217,19 @@ enum ContentBlockDeltaType { }, } -#[derive(serde::Serialize, Debug, Clone)] +#[derive(Debug, Serialize)] struct AnthropicRequest { system: Vec, messages: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] - /// This is going to be such a fucking nightmare later on... - tools: Vec, temperature: f32, + max_tokens: Option, + tools: Vec, stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, model: String, } impl AnthropicRequest { - fn from_client_completion_request( + pub fn from_client_completion_request( completion_request: LLMClientCompletionRequest, model_str: String, ) -> Self { @@ -267,44 +272,27 @@ impl AnthropicRequest { .into_iter() .filter(|message| message.role().is_user() || message.role().is_assistant()) .map(|message| { - let mut anthropic_message_content = - AnthropicMessageContent::text(message.content().to_owned(), None); - if message.is_cache_point() { - anthropic_message_content = - anthropic_message_content.cache_control(Some(AnthropicCacheControl { - r#type: AnthropicCacheType::Ephemeral, - })); + let mut content = Vec::new(); + + // Add text content if we don't have tool returns + if message.tool_return_value().is_empty() && !message.content().is_empty() { + content.push(AnthropicMessageContent::text(message.content().to_owned(), None)); } - let images = message - .images() - .into_iter() - .map(|image| AnthropicMessageContent::image(image)) - .collect::>(); - let tools = message - .tool_use_value() - .into_iter() - .map(|tool_use| AnthropicMessageContent::tool_use(tool_use)) - .collect::>(); - let tool_return = message - .tool_return_value() - .into_iter() - .map(|tool_return| AnthropicMessageContent::tool_return(tool_return)) - .collect::>(); - // if we have a tool return then we should not add the content string at all - let final_content = if tool_return.is_empty() { - if message.content().is_empty() { - vec![] - } else { - vec![anthropic_message_content] - } - } else { - vec![] + + // Add images, tools and tool returns + content.extend(message.images().iter().map(AnthropicMessageContent::image)); + content.extend(message.tool_use_value().iter().map(AnthropicMessageContent::tool_use)); + content.extend(message.tool_return_value().iter().map(AnthropicMessageContent::tool_return)); + + // Apply cache control to last content if needed + if message.is_cache_point() && !content.is_empty() { + let last_idx = content.len() - 1; + content[last_idx] = content[last_idx].clone().cache_control(Some(AnthropicCacheControl { + r#type: AnthropicCacheType::Ephemeral, + })); } - .into_iter() - .chain(images) - .chain(tools) - .chain(tool_return) - .collect(); + + let final_content = content; AnthropicMessage { role: message.role().to_string(), content: final_content, @@ -316,30 +304,9 @@ impl AnthropicRequest { system: system_message, messages, temperature, - tools, - stream: true, max_tokens, - model: model_str, - } - } - - fn from_client_string_request( - completion_request: LLMClientCompletionStringRequest, - model_str: String, - ) -> Self { - let temperature = completion_request.temperature(); - let max_tokens = completion_request.get_max_tokens(); - let messages = vec![AnthropicMessage::new( - "user".to_owned(), - completion_request.prompt().to_owned(), - )]; - AnthropicRequest { - system: vec![], - messages, - temperature, - tools: vec![], + tools, stream: true, - max_tokens, model: model_str, } } @@ -393,14 +360,33 @@ impl AnthropicClient { } /// We try to get the completion along with the tool which we are planning on using + fn from_client_string_request( + completion_request: LLMClientCompletionStringRequest, + model_str: String, + ) -> Self { + let temperature = completion_request.temperature(); + let max_tokens = completion_request.get_max_tokens(); + let messages = vec![AnthropicMessage::new( + "user".to_owned(), + completion_request.prompt().to_owned(), + )]; + AnthropicRequest { + system: vec![], + messages, + temperature, + tools: vec![], + stream: true, + max_tokens, + model: model_str, + } + } + pub async fn stream_completion_with_tool( &self, api_key: LLMProviderAPIKeys, request: LLMClientCompletionRequest, metadata: HashMap, sender: UnboundedSender, - // The first parameter in the Vec<(String, (String, String))> is the tool_type and the - // second one is (tool_id + serialized_json value of the tool use) ) -> Result<(String, Vec<(String, (String, String))>), LLMClientError> { let endpoint = self.chat_endpoint(); let messages = request @@ -434,8 +420,6 @@ impl AnthropicClient { ) .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) .header("content-type".to_owned(), "application/json".to_owned()) - // anthropic-beta: prompt-caching-2024-07-31 - // enables prompt caching: https://arc.net/l/quote/qtlllqgf .header( "anthropic-beta".to_owned(), "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15,computer-use-2024-10-22".to_owned(), @@ -456,15 +440,8 @@ impl AnthropicClient { let mut event_source = response_stream.bytes_stream().eventsource(); - // let event_next = event_source.next().await; - // dbg!(&event_next); - let mut buffered_string = "".to_owned(); - // controls which tool we will be using if any let mut tool_use_indication: Vec<(String, (String, String))> = vec![]; - - // handle all the tool parameters that are coming - // we will keep a global tracker over here let mut current_tool_use = None; let current_tool_use_ref = &mut current_tool_use; let mut current_tool_use_id = None; @@ -472,9 +449,7 @@ impl AnthropicClient { let mut running_tool_input = "".to_owned(); let running_tool_input_ref = &mut running_tool_input; - // loop over the content we are getting while let Some(Ok(event)) = event_source.next().await { - // TODO: debugging this let event = serde_json::from_str::(&event.data); match event { Ok(AnthropicEvent::ContentBlockStart { content_block, .. }) => { @@ -522,12 +497,9 @@ impl AnthropicClient { } ContentBlockDeltaType::InputJsonDelta { partial_json } => { *running_tool_input_ref = running_tool_input_ref.to_owned() + &partial_json; - // println!("input_json_delta::{}", &partial_json); } }, Ok(AnthropicEvent::ContentBlockStop { _index }) => { - // if the code block has stopped we need to pack our bags and - // create an entry for the tool which we want to use if let (Some(current_tool_use), Some(current_tool_use_id)) = ( current_tool_use_ref.clone(), current_tool_use_id_ref.clone(), @@ -540,8 +512,6 @@ impl AnthropicClient { ), )); } - - // now empty the tool use tracker *current_tool_use_ref = None; *running_tool_input_ref = "".to_owned(); *current_tool_use_id_ref = None; @@ -554,11 +524,8 @@ impl AnthropicClient { } Err(e) => { error!("Error parsing event: {:?}", e); - // break; - } - _ => { - // dbg!(&event); } + _ => {} } } @@ -572,7 +539,6 @@ impl AnthropicClient { .into_iter() .map(|message| { PareaLogMessage::new(message.role().to_string(), { - // we generate the content in a special way so we can read it on parea let content = message.content(); let tool_use_value = message .tool_use_value() @@ -839,50 +805,63 @@ impl LLMClient for AnthropicClient { ) -> Result { let endpoint = self.chat_endpoint(); let model_str = self.get_model_string(request.model())?; - let anthropic_request = - AnthropicRequest::from_client_string_request(request, model_str.to_owned()); + let anthropic_request = AnthropicRequest::from_client_string_request(request, model_str.to_owned()); - let response = self + let response_stream = self .client .post(endpoint) .header( "x-api-key".to_owned(), self.generate_api_bearer_key(api_key)?, ) + .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) + .header("content-type".to_owned(), "application/json".to_owned()) .header( "anthropic-beta".to_owned(), - "max-tokens-3-5-sonnet-2024-07-15".to_owned(), + "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15,computer-use-2024-10-22".to_owned(), ) - .header("anthropic-version".to_owned(), "2023-06-01".to_owned()) - .header("content-type".to_owned(), "application/json".to_owned()) .json(&anthropic_request) .send() - .await?; + .await + .map_err(|e| { + error!("sidecar.anthropic.error: {:?}", &e); + e + })?; - // Check for 401 Unauthorized status - if response.status() == reqwest::StatusCode::UNAUTHORIZED { + if response_stream.status() == reqwest::StatusCode::UNAUTHORIZED { error!("Unauthorized access to Anthropic API"); return Err(LLMClientError::UnauthorizedAccess); } - let mut response_stream = response.bytes_stream().eventsource(); + let mut event_source = response_stream.bytes_stream().eventsource(); + let mut buffered_string = String::new(); + let mut input_tokens = 0; + let mut output_tokens = 0; + let mut input_cached_tokens = 0; - let mut buffered_string = "".to_owned(); - while let Some(Ok(event)) = response_stream.next().await { + while let Some(Ok(event)) = event_source.next().await { let event = serde_json::from_str::(&event.data); match event { Ok(AnthropicEvent::ContentBlockStart { content_block, .. }) => { match content_block { - ContentBlockStart::InputToolUse { name, id: _id } => { - println!("anthropic::tool_use::{}", &name); + ContentBlockStart::InputToolUse { name, .. } => { + info!("anthropic::tool_use::{}", &name); } ContentBlockStart::TextDelta { text } => { - buffered_string = buffered_string + &text; - if let Err(e) = sender.send(LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text), - model_str.to_owned(), - )) { + buffered_string.push_str(&text); + if let Err(e) = sender.send( + LLMClientCompletionResponse::new( + buffered_string.clone(), + Some(text), + model_str.to_owned(), + ) + .set_usage_statistics( + LLMClientUsageStatistics::new() + .set_input_tokens(input_tokens) + .set_output_tokens(output_tokens) + .set_cached_input_tokens(input_cached_tokens), + ), + ) { error!("Failed to send completion response: {}", e); return Err(LLMClientError::SendError(e)); } @@ -891,33 +870,45 @@ impl LLMClient for AnthropicClient { } Ok(AnthropicEvent::ContentBlockDelta { delta, .. }) => match delta { ContentBlockDeltaType::TextDelta { text } => { - buffered_string = buffered_string + &text; - let _ = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(); - if let Err(e) = sender.send(LLMClientCompletionResponse::new( - buffered_string.to_owned(), - Some(text), - model_str.to_owned(), - )) { + buffered_string.push_str(&text); + if let Err(e) = sender.send( + LLMClientCompletionResponse::new( + buffered_string.clone(), + Some(text), + model_str.to_owned(), + ) + .set_usage_statistics( + LLMClientUsageStatistics::new() + .set_input_tokens(input_tokens) + .set_output_tokens(output_tokens) + .set_cached_input_tokens(input_cached_tokens), + ), + ) { error!("Failed to send completion response: {}", e); return Err(LLMClientError::SendError(e)); } } ContentBlockDeltaType::InputJsonDelta { partial_json } => { - println!("input_json_delta::{}", &partial_json); + debug!("input_json_delta::{}", &partial_json); } }, - Err(_) => { - break; + Ok(AnthropicEvent::MessageStart { message }) => { + input_tokens += message.usage.input_tokens.unwrap_or_default(); + output_tokens += message.usage.output_tokens.unwrap_or_default(); + input_cached_tokens += message.usage.cache_read_input_tokens.unwrap_or_default(); } - _ => { - dbg!(&event); + Ok(AnthropicEvent::MessageDelta { usage, .. }) => { + input_tokens += usage.input_tokens.unwrap_or_default(); + output_tokens += usage.output_tokens.unwrap_or_default(); + input_cached_tokens += usage.cache_read_input_tokens.unwrap_or_default(); } + Err(e) => { + error!("Error parsing event: {:?}", e); + } + _ => {} } } Ok(buffered_string) } -} +} \ No newline at end of file