Skip to content

Commit eaeb2e8

Browse files
committed
Added more tests
1 parent 24b6f59 commit eaeb2e8

File tree

1 file changed

+126
-17
lines changed

1 file changed

+126
-17
lines changed

src/chats.rs

Lines changed: 126 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -275,23 +275,9 @@ impl Client<crate::reqwest::ReqwestClient> {
275275
uid: impl AsRef<str>,
276276
body: &S,
277277
) -> Result<reqwest::Response, Error> {
278-
use reqwest::header::{HeaderValue, ACCEPT, CONTENT_TYPE};
278+
let request = self.build_stream_chat_request(uid.as_ref(), body)?;
279279

280-
let payload = to_vec(body).map_err(Error::ParseError)?;
281-
282-
let response = self
283-
.http_client
284-
.inner()
285-
.post(format!(
286-
"{}/chats/{}/chat/completions",
287-
self.host,
288-
uid.as_ref()
289-
))
290-
.header(ACCEPT, HeaderValue::from_static("text/event-stream"))
291-
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
292-
.body(payload)
293-
.send()
294-
.await?;
280+
let response = self.http_client.inner().execute(request).await?;
295281

296282
let status = response.status();
297283
if !status.is_success() {
@@ -310,14 +296,41 @@ impl Client<crate::reqwest::ReqwestClient> {
310296

311297
Ok(response)
312298
}
299+
300+
fn build_stream_chat_request<S: Serialize + ?Sized>(
301+
&self,
302+
uid: &str,
303+
body: &S,
304+
) -> Result<reqwest::Request, Error> {
305+
use reqwest::header::{HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE};
306+
307+
let payload = to_vec(body).map_err(Error::ParseError)?;
308+
309+
let mut request = self
310+
.http_client
311+
.inner()
312+
.post(format!("{}/chats/{}/chat/completions", self.host, uid))
313+
.header(ACCEPT, HeaderValue::from_static("text/event-stream"))
314+
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
315+
.body(payload)
316+
.build()?;
317+
318+
if let Some(key) = self.api_key.as_deref() {
319+
request.headers_mut().insert(
320+
AUTHORIZATION,
321+
HeaderValue::from_str(&format!("Bearer {key}")).unwrap(),
322+
);
323+
}
324+
325+
Ok(request)
326+
}
313327
}
314328

315329
#[cfg(test)]
316330
mod tests {
317331
use super::*;
318332
use meilisearch_test_macro::meilisearch_test;
319333
use serde_json::json;
320-
321334
#[meilisearch_test]
322335
async fn chat_workspace_lifecycle(client: Client, name: String) -> Result<(), Error> {
323336
let _: serde_json::Value = client
@@ -389,4 +402,100 @@ mod tests {
389402

390403
Ok(())
391404
}
405+
406+
#[test]
407+
fn chat_prompts_builder_helpers() {
408+
let mut prompts = ChatPrompts::new();
409+
prompts
410+
.set_system("system")
411+
.set_search_description("desc")
412+
.set_search_q_param("q")
413+
.set_search_index_uid_param("idx")
414+
.insert("custom", "value");
415+
416+
assert_eq!(prompts.system.as_deref(), Some("system"));
417+
assert_eq!(prompts.search_description.as_deref(), Some("desc"));
418+
assert_eq!(prompts.search_q_param.as_deref(), Some("q"));
419+
assert_eq!(prompts.search_index_uid_param.as_deref(), Some("idx"));
420+
assert_eq!(
421+
prompts.extra.get("custom").map(String::as_str),
422+
Some("value")
423+
);
424+
}
425+
426+
#[test]
427+
fn chat_workspace_settings_builder_helpers() {
428+
let mut settings = ChatWorkspaceSettings::new();
429+
settings
430+
.set_source("openAi")
431+
.set_org_id("org")
432+
.set_project_id("project")
433+
.set_api_version("2024-01-01")
434+
.set_deployment_id("deployment")
435+
.set_base_url("http://example.com")
436+
.set_api_key("secret")
437+
.set_prompts({
438+
let mut prompts = ChatPrompts::new();
439+
prompts.set_system("hi");
440+
prompts
441+
});
442+
443+
assert_eq!(settings.source.as_deref(), Some("openAi"));
444+
assert_eq!(settings.org_id.as_deref(), Some("org"));
445+
assert_eq!(settings.project_id.as_deref(), Some("project"));
446+
assert_eq!(settings.api_version.as_deref(), Some("2024-01-01"));
447+
assert_eq!(settings.deployment_id.as_deref(), Some("deployment"));
448+
assert_eq!(settings.base_url.as_deref(), Some("http://example.com"));
449+
assert_eq!(settings.api_key.as_deref(), Some("secret"));
450+
assert_eq!(
451+
settings.prompts.and_then(|p| p.system).as_deref(),
452+
Some("hi")
453+
);
454+
}
455+
456+
#[test]
457+
#[cfg(feature = "reqwest")]
458+
fn stream_chat_completion_request_includes_expected_headers() {
459+
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
460+
461+
let client = Client::new("http://localhost:7700", Some("secret")).unwrap();
462+
let body = json!({
463+
"model": "gpt-3.5-turbo",
464+
"messages": [{ "role": "user", "content": "Hello" }],
465+
"stream": true
466+
});
467+
468+
let request = client
469+
.build_stream_chat_request("workspace", &body)
470+
.expect("request should be built");
471+
472+
assert_eq!(request.method(), reqwest::Method::POST);
473+
assert_eq!(
474+
request.url().as_str(),
475+
"http://localhost:7700/chats/workspace/chat/completions"
476+
);
477+
478+
let headers = request.headers();
479+
assert_eq!(
480+
headers
481+
.get(reqwest::header::ACCEPT)
482+
.map(|h| h.to_str().unwrap()),
483+
Some("text/event-stream")
484+
);
485+
assert_eq!(
486+
headers.get(CONTENT_TYPE).map(|h| h.to_str().unwrap()),
487+
Some("application/json")
488+
);
489+
assert_eq!(
490+
headers.get(AUTHORIZATION).map(|h| h.to_str().unwrap()),
491+
Some("Bearer secret")
492+
);
493+
494+
let expected_body = body.to_string();
495+
let request_body = request
496+
.body()
497+
.and_then(|b| b.as_bytes())
498+
.expect("request has body");
499+
assert_eq!(request_body, expected_body.as_bytes());
500+
}
392501
}

0 commit comments

Comments
 (0)