
Commit 04b22bb

fix(server): out-of-memory error across multiple chat rounds
1 parent b4e2df5 commit 04b22bb

File tree

9 files changed, +185 −92 lines


apps/server/api/src/config/logic.rs

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,7 @@ pub struct LogicConfig {
     /// unit: ms
     silence_voice_timeout: Option<i64>,
     system_prompt: Option<String>,
+    max_prompt_len: Option<u64>,
 }
 
 impl LogicConfig {
@@ -17,6 +18,7 @@ impl LogicConfig {
             system_prompt: Some(String::from(
                 "你是一个助手,所有回答必须使用纯文本自然语言,禁止使用任何Markdown符号如#、-、*等。",
             )),
+            max_prompt_len: Some(3000),
         }
     }
 
@@ -31,4 +33,8 @@ impl LogicConfig {
     pub fn system_prompt(&self) -> &str {
         self.system_prompt.as_deref().unwrap_or_default()
     }
+
+    pub fn max_prompt_len(&self) -> u64 {
+        self.max_prompt_len.unwrap()
+    }
 }
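
(For reference, the Chinese default system prompt above instructs the assistant to answer in plain natural-language text and to avoid Markdown markers such as #, -, and *.) The new accessor unwraps the Option, so it stays panic-free only while max_prompt_len is actually populated, either by the constructor above or by whatever loads LogicConfig from the server configuration. A minimal stand-alone sketch of that semantics (the struct below is a cut-down stand-in, not the real LogicConfig):

// Cut-down stand-in for LogicConfig, only to illustrate the accessor added above.
struct Config {
    max_prompt_len: Option<u64>,
}

impl Config {
    fn max_prompt_len(&self) -> u64 {
        // Same shape as the committed accessor: unwrap() panics if the field is None.
        self.max_prompt_len.unwrap()
    }
}

fn main() {
    let cfg = Config { max_prompt_len: Some(3000) }; // mirrors the 3000 default above
    assert_eq!(cfg.max_prompt_len(), 3000);
    // Config { max_prompt_len: None }.max_prompt_len() would panic, so a config
    // source that omits the field must fall back to the default before this is read.
}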

apps/server/api/src/llm/client.rs

Lines changed: 31 additions & 1 deletion
@@ -1,4 +1,4 @@
-use std::{sync::Arc, thread};
+use std::{collections::VecDeque, sync::Arc, thread};
 
 use crate::{
     common::ModelError,
@@ -26,6 +26,7 @@ pub struct Client {
     model: Arc<Box<dyn Model>>,
     temperature: Option<f64>,
     max_tokens: Option<u64>,
+    max_prompt_len: Option<u64>,
     history: Arc<Mutex<History>>,
     mcp_host: Option<Arc<Mutex<dyn McpHost>>>,
 }
@@ -59,6 +60,11 @@ impl Client {
         self
     }
 
+    pub fn with_max_prompt_len(mut self, max_prompt_len: Option<u64>) -> Self {
+        self.max_prompt_len = max_prompt_len;
+        self
+    }
+
     pub fn chat(
         &self,
         request: ChatRequest,
@@ -70,6 +76,7 @@ impl Client {
         let clone_history = self.history.clone();
         let temperature = self.temperature;
         let max_tokens = self.max_tokens;
+        let max_prompt_len = self.max_prompt_len;
         thread::spawn(move || {
             let output = block_on(async move {
                 let tools = {
@@ -84,6 +91,28 @@
                 while has_next_step {
                     let history = clone_history.clone();
                     let mut history = history.lock().await;
+                    if let Some(max_prompt_len) = max_prompt_len {
+                        // cut prompt
+                        let mut current_len: u64 = 0;
+                        if let Some(item) = &history.preamble {
+                            current_len += item.len() as u64;
+                        }
+                        current_len += model.calculate_tools_prompt_len(&tools);
+                        let mut target_message_list = VecDeque::new();
+                        // TODO: remove clone?
+                        let chat_history: Vec<_> =
+                            history.chat_history.clone().into_iter().rev().collect();
+                        for message in chat_history {
+                            let len = model.calculate_message_prompt_len(&message);
+                            current_len += len;
+                            if current_len <= max_prompt_len {
+                                target_message_list.push_front(message);
+                            } else {
+                                break;
+                            }
+                        }
+                        history.chat_history = target_message_list.into();
+                    }
                     let chat_history = {
                         if !history.chat_history.is_empty() {
                             let mut result = OneOrMany::many(history.chat_history.clone()).unwrap();
@@ -293,6 +322,7 @@ impl ClientBuilder {
             model: self.model,
             temperature: None,
            max_tokens: None,
+            max_prompt_len: Some(3000),
            history: Arc::new(Mutex::new(History {
                preamble: None,
                chat_history: vec![],
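
The added block is the core of the fix: before each round, the preamble and the tool definitions are counted as a fixed cost, then the history is walked from newest to oldest and only the turns that still fit under max_prompt_len are kept, so the oldest turns are the ones dropped. Client::with_max_prompt_len (or the builder default of Some(3000) above) supplies the budget, and passing None disables trimming entirely. A stand-alone distillation of that policy, using plain strings and a caller-supplied length function instead of rig messages and the Model::calculate_*_prompt_len hooks (illustrative only, not part of the commit):

use std::collections::VecDeque;

// Stand-alone distillation of the trimming policy in the hunk above. Messages are
// walked newest-to-oldest and kept only while the running total stays within the
// budget, so the oldest turns are the ones discarded.
fn trim_history(
    history: Vec<String>,
    fixed_overhead: u64, // preamble + tool definitions, already counted
    max_prompt_len: u64,
    estimate_len: impl Fn(&str) -> u64,
) -> Vec<String> {
    let mut current_len = fixed_overhead;
    let mut kept = VecDeque::new();
    for message in history.into_iter().rev() {
        current_len += estimate_len(message.as_str());
        if current_len <= max_prompt_len {
            kept.push_front(message); // push_front restores chronological order
        } else {
            break;
        }
    }
    kept.into()
}

fn main() {
    let history = vec![
        "old turn".repeat(50),
        "recent turn".to_string(),
        "latest turn".to_string(),
    ];
    let kept = trim_history(history, 100, 200, |m| m.len() as u64);
    // Only the newest turns survive the 200-unit budget; the padded old turn is dropped.
    assert_eq!(kept, vec!["recent turn".to_string(), "latest turn".to_string()]);
}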

apps/server/api/src/llm/mod.rs

Lines changed: 20 additions & 1 deletion
@@ -8,7 +8,8 @@ use crate::{
 };
 use async_trait::async_trait;
 use rig::{
-    completion::{CompletionError, CompletionRequest},
+    completion::{CompletionError, CompletionRequest, ToolDefinition},
+    message::Message,
     streaming::StreamingCompletionResponse,
 };
 use std::sync::{Arc, OnceLock};
@@ -22,6 +23,12 @@ pub trait Model: Send + Sync {
         StreamingCompletionResponse<rig::providers::openai::streaming::StreamingCompletionResponse>,
         CompletionError,
     >;
+
+    fn calculate_system_prompt_len(&self, system_prompt: &Option<String>) -> u64;
+
+    fn calculate_tools_prompt_len(&self, tools: &[ToolDefinition]) -> u64;
+
+    fn calculate_message_prompt_len(&self, message: &Message) -> u64;
 }
 
 #[derive(Default, Clone)]
@@ -38,6 +45,18 @@ impl Model for DummyModel {
     > {
         todo!()
     }
+
+    fn calculate_system_prompt_len(&self, _system_prompt: &Option<String>) -> u64 {
+        todo!()
+    }
+
+    fn calculate_tools_prompt_len(&self, _tools: &[ToolDefinition]) -> u64 {
+        todo!()
+    }
+
+    fn calculate_message_prompt_len(&self, _message: &Message) -> u64 {
+        todo!()
+    }
 }
 
 static INSTANCE: OnceLock<LlmFactory> = OnceLock::new();
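
Every implementation of the three new methods in this commit is still todo!(), and the unit is only implied: the loop in client.rs counts the preamble with item.len(), i.e. UTF-8 bytes, so the 3000 default reads as a byte budget rather than a token budget. A minimal estimator in that spirit is sketched below (illustrative only; a production version should use the model's own tokenizer and keep all three methods in the same unit):

// Hypothetical byte-count estimator, matching the unit the client.rs loop already
// uses for the preamble (str::len, i.e. UTF-8 bytes). Not part of the commit.
fn estimate_prompt_len(text: &str) -> u64 {
    text.len() as u64
}

// Per-message estimates also need some allowance for role markers and
// chat-template tokens; the constant here is a guess, not a measured value.
const PER_MESSAGE_OVERHEAD: u64 = 16;

fn estimate_message_len(text: &str) -> u64 {
    PER_MESSAGE_OVERHEAD + estimate_prompt_len(text)
}

fn main() {
    assert_eq!(estimate_prompt_len("hello"), 5);
    // CJK characters are 3 bytes each in UTF-8, so byte counts overstate them vs. chars.
    assert_eq!(estimate_prompt_len("你好"), 6);
    assert_eq!(estimate_message_len("hello"), 21);
}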

apps/server/api/src/llm/model/minicpm4/mod.rs

Lines changed: 13 additions & 1 deletion
@@ -10,7 +10,7 @@ use async_trait::async_trait;
 use futures::{SinkExt, StreamExt, executor::block_on};
 use futures_channel::mpsc::unbounded;
 use rig::{
-    completion::{CompletionError, CompletionRequest},
+    completion::{CompletionError, CompletionRequest, ToolDefinition},
     message::{Message, UserContent},
     streaming::{RawStreamingChoice, StreamingCompletionResponse},
 };
@@ -101,6 +101,18 @@ impl<'a> Model for Minicpm4<'a> {
         });
         Ok(StreamingCompletionResponse::stream(Box::pin(rx)))
     }
+
+    fn calculate_system_prompt_len(&self, _system_prompt: &Option<String>) -> u64 {
+        todo!()
+    }
+
+    fn calculate_tools_prompt_len(&self, _tools: &[ToolDefinition]) -> u64 {
+        todo!()
+    }
+
+    fn calculate_message_prompt_len(&self, _message: &Message) -> u64 {
+        todo!()
+    }
 }
 
 fn convert_response(
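
As written in this file, chatting through a Minicpm4-backed client with a Some budget (and the ClientBuilder now defaults to Some(3000)) would reach these todo!()s and panic; the five changed files not shown on this page may already supply real bodies elsewhere. A hedged sketch of how the stubs could be filled in with the same byte-count convention as above, to drop into the impl<'a> Model for Minicpm4<'a> block; the public name/description/parameters fields on rig's ToolDefinition and serde_json serialization of rig's Message are assumptions about rig's API, not something this diff shows:

    // Hedged sketch only, not part of the commit. Assumes ToolDefinition exposes
    // name/description/parameters and that Message implements serde::Serialize;
    // if either does not hold, the bodies would have to walk the rig types explicitly.
    fn calculate_system_prompt_len(&self, system_prompt: &Option<String>) -> u64 {
        system_prompt.as_deref().map_or(0, |p| p.len() as u64)
    }

    fn calculate_tools_prompt_len(&self, tools: &[ToolDefinition]) -> u64 {
        tools
            .iter()
            .map(|t| (t.name.len() + t.description.len() + t.parameters.to_string().len()) as u64)
            .sum()
    }

    fn calculate_message_prompt_len(&self, message: &Message) -> u64 {
        // Measuring the JSON rendering overestimates slightly, which errs on the
        // safe side when the goal is to stay under a memory cap.
        serde_json::to_string(message).map(|s| s.len() as u64).unwrap_or(0)
    }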
