diff --git a/crates/q_chat/src/consts.rs b/crates/q_chat/src/consts.rs
index 6850f7efa..6279b8980 100644
--- a/crates/q_chat/src/consts.rs
+++ b/crates/q_chat/src/consts.rs
@@ -16,4 +16,6 @@ pub const MAX_USER_MESSAGE_SIZE: usize = 600_000;
 /// In tokens
 pub const CONTEXT_WINDOW_SIZE: usize = 200_000;
 
+pub const CONTEXT_FILES_MAX_SIZE: usize = 150_000;
+
 pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold
diff --git a/crates/q_chat/src/conversation_state.rs b/crates/q_chat/src/conversation_state.rs
index e3ab5f17a..4bfe1428d 100644
--- a/crates/q_chat/src/conversation_state.rs
+++ b/crates/q_chat/src/conversation_state.rs
@@ -4,6 +4,14 @@ use std::collections::{
 };
 use std::sync::Arc;
 
+use crossterm::style::Color::{
+    DarkGreen,
+    DarkYellow,
+};
+use crossterm::{
+    execute,
+    style,
+};
 use fig_api_client::model::{
     AssistantResponseMessage,
     ChatMessage,
@@ -26,6 +34,7 @@ use tracing::{
 };
 
 use super::consts::{
+    CONTEXT_FILES_MAX_SIZE,
     MAX_CHARS,
     MAX_CONVERSATION_STATE_HISTORY_LEN,
 };
@@ -58,6 +67,7 @@ use super::tools::{
 const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n";
 const CONTEXT_ENTRY_END_HEADER: &str = "--- CONTEXT ENTRY END ---\n\n";
 
+use super::util::drop_matched_context_files;
 /// Tracks state related to an ongoing conversation.
 #[derive(Debug, Clone)]
 pub struct ConversationState {
@@ -310,7 +320,7 @@ impl ConversationState {
         self.history.drain(self.valid_history_range.1..);
         self.history.drain(..self.valid_history_range.0);
 
-        self.backend_conversation_state(run_hooks, false)
+        self.backend_conversation_state(run_hooks, false, true)
             .await
             .into_fig_conversation_state()
             .expect("unable to construct conversation state")
@@ -318,7 +328,12 @@ impl ConversationState {
 
     /// Returns a conversation state representation which reflects the exact conversation to send
     /// back to the model.
-    pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> {
+    pub async fn backend_conversation_state(
+        &mut self,
+        run_hooks: bool,
+        quiet: bool,
+        show_dropped_context_files_warning: bool,
+    ) -> BackendConversationState<'_> {
         self.enforce_conversation_invariants();
 
         // Run hooks and add to conversation start and next user message.
@@ -330,7 +345,6 @@ impl ConversationState {
             } else {
                 Some(self.updates.as_mut().unwrap_or(&mut null_writer))
             };
-
             let hook_results = cm.run_hooks(updates).await;
             conversation_start_context =
                 Some(format_hook_context(hook_results.iter(), HookTrigger::ConversationStart));
@@ -340,7 +354,9 @@ impl ConversationState {
             }
         }
 
-        let context_messages = self.context_messages(conversation_start_context).await;
+        let context_messages = self
+            .context_messages(conversation_start_context, show_dropped_context_files_warning)
+            .await;
 
         BackendConversationState {
             conversation_id: self.conversation_id.as_str(),
@@ -399,7 +415,7 @@ impl ConversationState {
             },
         };
 
-        let conv_state = self.backend_conversation_state(false, true).await;
+        let conv_state = self.backend_conversation_state(false, true, false).await;
 
         // Include everything but the last message in the history.
         let history_len = conv_state.history.len();
@@ -482,6 +498,7 @@ impl ConversationState {
     async fn context_messages(
         &mut self,
         conversation_start_context: Option<String>,
+        show_dropped_context_files_warning: bool,
     ) -> Option> {
         let mut context_content = String::new();
 
@@ -497,7 +514,28 @@ impl ConversationState {
         // Add context files if available
         if let Some(context_manager) = self.context_manager.as_mut() {
             match context_manager.get_context_files(true).await {
-                Ok(files) => {
+                Ok(mut files) => {
+                    if let Ok(dropped_files) = drop_matched_context_files(&mut files, CONTEXT_FILES_MAX_SIZE) {
+                        if !dropped_files.is_empty() {
+                            if show_dropped_context_files_warning {
+                                let mut output = SharedWriter::stdout();
+                                execute!(
+                                    output,
+                                    style::SetForegroundColor(DarkYellow),
+                                    style::Print("\nSome context files were dropped due to the size limit. Please run "),
+                                    style::SetForegroundColor(DarkGreen),
+                                    style::Print("/context show "),
+                                    style::SetForegroundColor(DarkYellow),
+                                    style::Print("to learn more.\n"),
+                                    style::SetForegroundColor(style::Color::Reset)
+                                )
+                                .ok();
+                            }
+                            for (filename, _) in dropped_files.iter() {
+                                files.retain(|(f, _)| f != filename);
+                            }
+                        }
+                    }
                     if !files.is_empty() {
                         context_content.push_str(CONTEXT_ENTRY_START_HEADER);
                         for (filename, content) in files {
@@ -533,7 +571,7 @@ impl ConversationState {
 
     /// Calculate the total character count in the conversation
    pub async fn calculate_char_count(&mut self) -> CharCount {
-        self.backend_conversation_state(false, true).await.char_count()
+        self.backend_conversation_state(false, true, false).await.char_count()
     }
 
     /// Get the current token warning level
diff --git a/crates/q_chat/src/lib.rs b/crates/q_chat/src/lib.rs
index 773b6bb8f..ec08fb362 100644
--- a/crates/q_chat/src/lib.rs
+++ b/crates/q_chat/src/lib.rs
@@ -43,7 +43,10 @@ use command::{
     PromptsSubcommand,
     ToolsSubcommand,
 };
-use consts::CONTEXT_WINDOW_SIZE;
+use consts::{
+    CONTEXT_FILES_MAX_SIZE,
+    CONTEXT_WINDOW_SIZE,
+};
 use context::ContextManager;
 use conversation_state::{
     ConversationState,
@@ -176,6 +179,7 @@ use tracing::{
 use unicode_width::UnicodeWidthStr;
 use util::{
     animate_output,
+    drop_matched_context_files,
     play_notification_bell,
     region_check,
 };
@@ -982,7 +986,7 @@ impl ChatContext {
             fig_api_client::Error::ContextWindowOverflow => {
                 let history_too_small = self
                     .conversation_state
-                    .backend_conversation_state(false, true)
+                    .backend_conversation_state(false, true, false)
                     .await
                     .history
                     .len()
@@ -1744,8 +1748,8 @@ impl ChatContext {
                         style::SetAttribute(Attribute::Reset)
                     )?;
 
-                    for (filename, content) in global_context_files {
-                        let est_tokens = TokenCounter::count_tokens(&content);
+                    for (filename, content) in &global_context_files {
+                        let est_tokens = TokenCounter::count_tokens(content);
                         execute!(
                             self.output,
                             style::Print(format!("🌍 {} ", filename)),
@@ -1763,8 +1767,8 @@ impl ChatContext {
                         }
                     }
 
-                    for (filename, content) in profile_context_files {
-                        let est_tokens = TokenCounter::count_tokens(&content);
+                    for (filename, content) in &profile_context_files {
+                        let est_tokens = TokenCounter::count_tokens(content);
                         execute!(
                             self.output,
                             style::Print(format!("👤 {} ", filename)),
@@ -1786,11 +1790,55 @@ impl ChatContext {
                         execute!(self.output, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
                     }
 
+                    let mut combined_files: Vec<(String, String)> = global_context_files
+                        .iter()
+                        .chain(profile_context_files.iter())
+                        .cloned()
+                        .collect();
+
+                    let dropped_files =
+                        drop_matched_context_files(&mut combined_files, CONTEXT_FILES_MAX_SIZE).ok();
+
                     execute!(
                         self.output,
-                        style::Print(format!("\nTotal: ~{} tokens\n\n", total_tokens)),
+                        style::Print(format!("\nTotal: ~{} tokens\n\n", total_tokens))
                     )?;
 
+                    if let Some(dropped_files) = dropped_files {
+                        if !dropped_files.is_empty() {
+                            execute!(
+                                self.output,
+                                style::SetForegroundColor(Color::DarkYellow),
+                                style::Print(format!(
+                                    "Total token count exceeds the limit: {}. The following files will be automatically dropped when interacting with Q. Consider removing them.\n\n",
+                                    CONTEXT_FILES_MAX_SIZE
+                                )),
+                                style::SetForegroundColor(Color::Reset)
+                            )?;
+                            let total_files = dropped_files.len();
+
+                            let truncated_dropped_files = &dropped_files[..total_files.min(10)];
+
+                            for (filename, content) in truncated_dropped_files {
+                                let est_tokens = TokenCounter::count_tokens(content);
+                                execute!(
+                                    self.output,
+                                    style::Print(format!("{} ", filename)),
+                                    style::SetForegroundColor(Color::DarkGrey),
+                                    style::Print(format!("(~{} tkns)\n", est_tokens)),
+                                    style::SetForegroundColor(Color::Reset),
+                                )?;
+                            }
+
+                            if total_files > 10 {
+                                execute!(
+                                    self.output,
+                                    style::Print(format!("({} more files)\n", total_files - 10))
+                                )?;
+                            }
+                        }
+                    }
+
                     execute!(self.output, style::Print("\n"))?;
                 }
             },
@@ -2573,7 +2621,10 @@ impl ChatContext {
                 }
             },
             Command::Usage => {
-                let state = self.conversation_state.backend_conversation_state(true, true).await;
+                let state = self
+                    .conversation_state
+                    .backend_conversation_state(true, true, true)
+                    .await;
                 let data = state.calculate_conversation_size();
 
                 let context_token_count: TokenCount = data.context_messages.into();
diff --git a/crates/q_chat/src/util/mod.rs b/crates/q_chat/src/util/mod.rs
index b7561b936..abd25f5ce 100644
--- a/crates/q_chat/src/util/mod.rs
+++ b/crates/q_chat/src/util/mod.rs
@@ -3,9 +3,11 @@ pub mod issue;
 
 use std::io::Write;
 use std::time::Duration;
 
+use eyre::Result;
 use fig_util::system_info::in_cloudshell;
 
 use super::ChatError;
+use super::token_counter::TokenCounter;
 
 const GOV_REGIONS: &[&str] = &["us-gov-east-1", "us-gov-west-1"];
@@ -98,6 +100,30 @@ fn should_play_bell() -> bool {
     false
 }
 
+/// This is a simple greedy algorithm that drops the largest files first
+/// until the total size is below the limit.
+///
+/// # Arguments
+/// * `files` - A mutable reference to a slice of tuples: (filename, content). The slice will be
+///   sorted (largest first), but the contents will not be changed.
+///
+/// Returns the dropped files
+pub fn drop_matched_context_files(files: &mut [(String, String)], limit: usize) -> Result<Vec<(String, String)>> {
+    files.sort_by(|a, b| TokenCounter::count_tokens(&b.1).cmp(&TokenCounter::count_tokens(&a.1)));
+    let mut total_size = 0;
+    let mut dropped_files = Vec::new();
+
+    for (filename, content) in files.iter() {
+        let size = TokenCounter::count_tokens(content);
+        if total_size + size > limit {
+            dropped_files.push((filename.clone(), content.clone()));
+        } else {
+            total_size += size;
+        }
+    }
+    Ok(dropped_files)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -109,4 +135,26 @@ mod tests {
         assert_eq!(truncate_safe("Hello World", 5), "Hello");
         assert_eq!(truncate_safe("Hello World", 11), "Hello World");
         assert_eq!(truncate_safe("Hello World", 15), "Hello World");
     }
+
+    #[test]
+    fn test_drop_matched_context_files() {
+        let mut files = vec![
+            ("file1".to_string(), "This is a test file".to_string()),
+            (
+                "file3".to_string(),
+                "Yet another test file that has the largest context file".to_string(),
+            ),
+        ];
+        let limit = 10;
+
+        let dropped_files = drop_matched_context_files(&mut files, limit).unwrap();
+        assert_eq!(dropped_files.len(), 1);
+        assert_eq!(dropped_files[0].0, "file3");
+        assert_eq!(files.len(), 2);
+
+        for (filename, _) in dropped_files.iter() {
+            files.retain(|(f, _)| f != filename);
+        }
+        assert_eq!(files.len(), 1);
+    }
 }
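
Note (not part of the patch): a minimal sketch of how the new helper is meant to be called, mirroring the logic the diff adds to context_messages() in conversation_state.rs. The wrapper name apply_context_budget is hypothetical, and the super:: paths assume the caller sits in the same module tree as the files touched above.

    // Hypothetical caller (illustration only); drop_matched_context_files and
    // CONTEXT_FILES_MAX_SIZE are the items introduced by this patch.
    use super::consts::CONTEXT_FILES_MAX_SIZE;
    use super::util::drop_matched_context_files;

    fn apply_context_budget(mut files: Vec<(String, String)>) -> Vec<(String, String)> {
        // Greedily drop the largest files until the remaining ones fit the token budget,
        // then keep only the files that were not dropped.
        if let Ok(dropped) = drop_matched_context_files(&mut files, CONTEXT_FILES_MAX_SIZE) {
            for (name, _) in &dropped {
                files.retain(|(f, _)| f != name);
            }
        }
        files
    }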