@@ -275,23 +275,9 @@ impl Client<crate::reqwest::ReqwestClient> {
275275 uid : impl AsRef < str > ,
276276 body : & S ,
277277 ) -> Result < reqwest:: Response , Error > {
278- use reqwest :: header :: { HeaderValue , ACCEPT , CONTENT_TYPE } ;
278+ let request = self . build_stream_chat_request ( uid . as_ref ( ) , body ) ? ;
279279
280- let payload = to_vec ( body) . map_err ( Error :: ParseError ) ?;
281-
282- let response = self
283- . http_client
284- . inner ( )
285- . post ( format ! (
286- "{}/chats/{}/chat/completions" ,
287- self . host,
288- uid. as_ref( )
289- ) )
290- . header ( ACCEPT , HeaderValue :: from_static ( "text/event-stream" ) )
291- . header ( CONTENT_TYPE , HeaderValue :: from_static ( "application/json" ) )
292- . body ( payload)
293- . send ( )
294- . await ?;
280+ let response = self . http_client . inner ( ) . execute ( request) . await ?;
295281
296282 let status = response. status ( ) ;
297283 if !status. is_success ( ) {
@@ -310,14 +296,41 @@ impl Client<crate::reqwest::ReqwestClient> {
310296
311297 Ok ( response)
312298 }
299+
300+ fn build_stream_chat_request < S : Serialize + ?Sized > (
301+ & self ,
302+ uid : & str ,
303+ body : & S ,
304+ ) -> Result < reqwest:: Request , Error > {
305+ use reqwest:: header:: { HeaderValue , ACCEPT , AUTHORIZATION , CONTENT_TYPE } ;
306+
307+ let payload = to_vec ( body) . map_err ( Error :: ParseError ) ?;
308+
309+ let mut request = self
310+ . http_client
311+ . inner ( )
312+ . post ( format ! ( "{}/chats/{}/chat/completions" , self . host, uid) )
313+ . header ( ACCEPT , HeaderValue :: from_static ( "text/event-stream" ) )
314+ . header ( CONTENT_TYPE , HeaderValue :: from_static ( "application/json" ) )
315+ . body ( payload)
316+ . build ( ) ?;
317+
318+ if let Some ( key) = self . api_key . as_deref ( ) {
319+ request. headers_mut ( ) . insert (
320+ AUTHORIZATION ,
321+ HeaderValue :: from_str ( & format ! ( "Bearer {key}" ) ) . unwrap ( ) ,
322+ ) ;
323+ }
324+
325+ Ok ( request)
326+ }
313327}
314328
315329#[ cfg( test) ]
316330mod tests {
317331 use super :: * ;
318332 use meilisearch_test_macro:: meilisearch_test;
319333 use serde_json:: json;
320-
321334 #[ meilisearch_test]
322335 async fn chat_workspace_lifecycle ( client : Client , name : String ) -> Result < ( ) , Error > {
323336 let _: serde_json:: Value = client
@@ -389,4 +402,100 @@ mod tests {
389402
390403 Ok ( ( ) )
391404 }
405+
406+ #[ test]
407+ fn chat_prompts_builder_helpers ( ) {
408+ let mut prompts = ChatPrompts :: new ( ) ;
409+ prompts
410+ . set_system ( "system" )
411+ . set_search_description ( "desc" )
412+ . set_search_q_param ( "q" )
413+ . set_search_index_uid_param ( "idx" )
414+ . insert ( "custom" , "value" ) ;
415+
416+ assert_eq ! ( prompts. system. as_deref( ) , Some ( "system" ) ) ;
417+ assert_eq ! ( prompts. search_description. as_deref( ) , Some ( "desc" ) ) ;
418+ assert_eq ! ( prompts. search_q_param. as_deref( ) , Some ( "q" ) ) ;
419+ assert_eq ! ( prompts. search_index_uid_param. as_deref( ) , Some ( "idx" ) ) ;
420+ assert_eq ! (
421+ prompts. extra. get( "custom" ) . map( String :: as_str) ,
422+ Some ( "value" )
423+ ) ;
424+ }
425+
426+ #[ test]
427+ fn chat_workspace_settings_builder_helpers ( ) {
428+ let mut settings = ChatWorkspaceSettings :: new ( ) ;
429+ settings
430+ . set_source ( "openAi" )
431+ . set_org_id ( "org" )
432+ . set_project_id ( "project" )
433+ . set_api_version ( "2024-01-01" )
434+ . set_deployment_id ( "deployment" )
435+ . set_base_url ( "http://example.com" )
436+ . set_api_key ( "secret" )
437+ . set_prompts ( {
438+ let mut prompts = ChatPrompts :: new ( ) ;
439+ prompts. set_system ( "hi" ) ;
440+ prompts
441+ } ) ;
442+
443+ assert_eq ! ( settings. source. as_deref( ) , Some ( "openAi" ) ) ;
444+ assert_eq ! ( settings. org_id. as_deref( ) , Some ( "org" ) ) ;
445+ assert_eq ! ( settings. project_id. as_deref( ) , Some ( "project" ) ) ;
446+ assert_eq ! ( settings. api_version. as_deref( ) , Some ( "2024-01-01" ) ) ;
447+ assert_eq ! ( settings. deployment_id. as_deref( ) , Some ( "deployment" ) ) ;
448+ assert_eq ! ( settings. base_url. as_deref( ) , Some ( "http://example.com" ) ) ;
449+ assert_eq ! ( settings. api_key. as_deref( ) , Some ( "secret" ) ) ;
450+ assert_eq ! (
451+ settings. prompts. and_then( |p| p. system) . as_deref( ) ,
452+ Some ( "hi" )
453+ ) ;
454+ }
455+
456+ #[ test]
457+ #[ cfg( feature = "reqwest" ) ]
458+ fn stream_chat_completion_request_includes_expected_headers ( ) {
459+ use reqwest:: header:: { AUTHORIZATION , CONTENT_TYPE } ;
460+
461+ let client = Client :: new ( "http://localhost:7700" , Some ( "secret" ) ) . unwrap ( ) ;
462+ let body = json ! ( {
463+ "model" : "gpt-3.5-turbo" ,
464+ "messages" : [ { "role" : "user" , "content" : "Hello" } ] ,
465+ "stream" : true
466+ } ) ;
467+
468+ let request = client
469+ . build_stream_chat_request ( "workspace" , & body)
470+ . expect ( "request should be built" ) ;
471+
472+ assert_eq ! ( request. method( ) , reqwest:: Method :: POST ) ;
473+ assert_eq ! (
474+ request. url( ) . as_str( ) ,
475+ "http://localhost:7700/chats/workspace/chat/completions"
476+ ) ;
477+
478+ let headers = request. headers ( ) ;
479+ assert_eq ! (
480+ headers
481+ . get( reqwest:: header:: ACCEPT )
482+ . map( |h| h. to_str( ) . unwrap( ) ) ,
483+ Some ( "text/event-stream" )
484+ ) ;
485+ assert_eq ! (
486+ headers. get( CONTENT_TYPE ) . map( |h| h. to_str( ) . unwrap( ) ) ,
487+ Some ( "application/json" )
488+ ) ;
489+ assert_eq ! (
490+ headers. get( AUTHORIZATION ) . map( |h| h. to_str( ) . unwrap( ) ) ,
491+ Some ( "Bearer secret" )
492+ ) ;
493+
494+ let expected_body = body. to_string ( ) ;
495+ let request_body = request
496+ . body ( )
497+ . and_then ( |b| b. as_bytes ( ) )
498+ . expect ( "request has body" ) ;
499+ assert_eq ! ( request_body, expected_body. as_bytes( ) ) ;
500+ }
392501}
0 commit comments