Skip to content

Commit 6013669

Browse files
authored
feat: configurable per request path (#479)
* refactor to avoid repetition * configurable per request path
1 parent 340aa07 commit 6013669

File tree

3 files changed

+83
-149
lines changed

3 files changed

+83
-149
lines changed

async-openai/src/client.rs

Lines changed: 61 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,35 @@ impl<C: Config> Client<C> {
197197
&self.config
198198
}
199199

200+
/// Helper function to build a request builder with common configuration
201+
fn build_request_builder(
202+
&self,
203+
method: reqwest::Method,
204+
path: &str,
205+
request_options: &RequestOptions,
206+
) -> reqwest::RequestBuilder {
207+
let mut request_builder = if let Some(path) = request_options.path() {
208+
self.http_client
209+
.request(method, self.config.url(path.as_str()))
210+
} else {
211+
self.http_client.request(method, self.config.url(path))
212+
};
213+
214+
request_builder = request_builder
215+
.query(&self.config.query())
216+
.headers(self.config.headers());
217+
218+
if let Some(headers) = request_options.headers() {
219+
request_builder = request_builder.headers(headers.clone());
220+
}
221+
222+
if !request_options.query().is_empty() {
223+
request_builder = request_builder.query(request_options.query());
224+
}
225+
226+
request_builder
227+
}
228+
200229
/// Make a GET request to {path} and deserialize the response body
201230
pub(crate) async fn get<O>(
202231
&self,
@@ -207,21 +236,9 @@ impl<C: Config> Client<C> {
207236
O: DeserializeOwned,
208237
{
209238
let request_maker = || async {
210-
let mut request_builder = self
211-
.http_client
212-
.get(self.config.url(path))
213-
.query(&self.config.query())
214-
.headers(self.config.headers());
215-
216-
if let Some(headers) = request_options.headers() {
217-
request_builder = request_builder.headers(headers.clone());
218-
}
219-
220-
if !request_options.query().is_empty() {
221-
request_builder = request_builder.query(request_options.query());
222-
}
223-
224-
Ok(request_builder.build()?)
239+
Ok(self
240+
.build_request_builder(reqwest::Method::GET, path, request_options)
241+
.build()?)
225242
};
226243

227244
self.execute(request_maker).await
@@ -237,21 +254,9 @@ impl<C: Config> Client<C> {
237254
O: DeserializeOwned,
238255
{
239256
let request_maker = || async {
240-
let mut request_builder = self
241-
.http_client
242-
.delete(self.config.url(path))
243-
.query(&self.config.query())
244-
.headers(self.config.headers());
245-
246-
if let Some(headers) = request_options.headers() {
247-
request_builder = request_builder.headers(headers.clone());
248-
}
249-
250-
if !request_options.query().is_empty() {
251-
request_builder = request_builder.query(request_options.query());
252-
}
253-
254-
Ok(request_builder.build()?)
257+
Ok(self
258+
.build_request_builder(reqwest::Method::DELETE, path, request_options)
259+
.build()?)
255260
};
256261

257262
self.execute(request_maker).await
@@ -264,21 +269,9 @@ impl<C: Config> Client<C> {
264269
request_options: &RequestOptions,
265270
) -> Result<(Bytes, HeaderMap), OpenAIError> {
266271
let request_maker = || async {
267-
let mut request_builder = self
268-
.http_client
269-
.get(self.config.url(path))
270-
.query(&self.config.query())
271-
.headers(self.config.headers());
272-
273-
if let Some(headers) = request_options.headers() {
274-
request_builder = request_builder.headers(headers.clone());
275-
}
276-
277-
if !request_options.query().is_empty() {
278-
request_builder = request_builder.query(request_options.query());
279-
}
280-
281-
Ok(request_builder.build()?)
272+
Ok(self
273+
.build_request_builder(reqwest::Method::GET, path, request_options)
274+
.build()?)
282275
};
283276

284277
self.execute_raw(request_maker).await
@@ -295,22 +288,10 @@ impl<C: Config> Client<C> {
295288
I: Serialize,
296289
{
297290
let request_maker = || async {
298-
let mut request_builder = self
299-
.http_client
300-
.post(self.config.url(path))
301-
.query(&self.config.query())
302-
.headers(self.config.headers())
303-
.json(&request);
304-
305-
if let Some(headers) = request_options.headers() {
306-
request_builder = request_builder.headers(headers.clone());
307-
}
308-
309-
if !request_options.query().is_empty() {
310-
request_builder = request_builder.query(request_options.query());
311-
}
312-
313-
Ok(request_builder.build()?)
291+
Ok(self
292+
.build_request_builder(reqwest::Method::POST, path, request_options)
293+
.json(&request)
294+
.build()?)
314295
};
315296

316297
self.execute_raw(request_maker).await
@@ -328,22 +309,10 @@ impl<C: Config> Client<C> {
328309
O: DeserializeOwned,
329310
{
330311
let request_maker = || async {
331-
let mut request_builder = self
332-
.http_client
333-
.post(self.config.url(path))
334-
.query(&self.config.query())
335-
.headers(self.config.headers())
336-
.json(&request);
337-
338-
if let Some(headers) = request_options.headers() {
339-
request_builder = request_builder.headers(headers.clone());
340-
}
341-
342-
if !request_options.query().is_empty() {
343-
request_builder = request_builder.query(request_options.query());
344-
}
345-
346-
Ok(request_builder.build()?)
312+
Ok(self
313+
.build_request_builder(reqwest::Method::POST, path, request_options)
314+
.json(&request)
315+
.build()?)
347316
};
348317

349318
self.execute(request_maker).await
@@ -361,22 +330,10 @@ impl<C: Config> Client<C> {
361330
F: Clone,
362331
{
363332
let request_maker = || async {
364-
let mut request_builder = self
365-
.http_client
366-
.post(self.config.url(path))
367-
.query(&self.config.query())
368-
.headers(self.config.headers())
369-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
370-
371-
if let Some(headers) = request_options.headers() {
372-
request_builder = request_builder.headers(headers.clone());
373-
}
374-
375-
if !request_options.query().is_empty() {
376-
request_builder = request_builder.query(request_options.query());
377-
}
378-
379-
Ok(request_builder.build()?)
333+
Ok(self
334+
.build_request_builder(reqwest::Method::POST, path, request_options)
335+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
336+
.build()?)
380337
};
381338

382339
self.execute_raw(request_maker).await
@@ -395,22 +352,10 @@ impl<C: Config> Client<C> {
395352
F: Clone,
396353
{
397354
let request_maker = || async {
398-
let mut request_builder = self
399-
.http_client
400-
.post(self.config.url(path))
401-
.query(&self.config.query())
402-
.headers(self.config.headers())
403-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
404-
405-
if let Some(headers) = request_options.headers() {
406-
request_builder = request_builder.headers(headers.clone());
407-
}
408-
409-
if !request_options.query().is_empty() {
410-
request_builder = request_builder.query(request_options.query());
411-
}
412-
413-
Ok(request_builder.build()?)
355+
Ok(self
356+
.build_request_builder(reqwest::Method::POST, path, request_options)
357+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
358+
.build()?)
414359
};
415360

416361
self.execute(request_maker).await
@@ -429,20 +374,9 @@ impl<C: Config> Client<C> {
429374
{
430375
// Build and execute request manually since multipart::Form is not Clone
431376
// and .eventsource() requires cloneability
432-
let mut request_builder = self
433-
.http_client
434-
.post(self.config.url(path))
435-
.query(&self.config.query())
436-
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
437-
.headers(self.config.headers());
438-
439-
if let Some(headers) = request_options.headers() {
440-
request_builder = request_builder.headers(headers.clone());
441-
}
442-
443-
if !request_options.query().is_empty() {
444-
request_builder = request_builder.query(request_options.query());
445-
}
377+
let request_builder = self
378+
.build_request_builder(reqwest::Method::POST, path, request_options)
379+
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?);
446380

447381
let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
448382

@@ -580,21 +514,10 @@ impl<C: Config> Client<C> {
580514
I: Serialize,
581515
O: DeserializeOwned + std::marker::Send + 'static,
582516
{
583-
let mut request_builder = self
584-
.http_client
585-
.post(self.config.url(path))
586-
.query(&self.config.query())
587-
.headers(self.config.headers())
517+
let request_builder = self
518+
.build_request_builder(reqwest::Method::POST, path, request_options)
588519
.json(&request);
589520

590-
if let Some(headers) = request_options.headers() {
591-
request_builder = request_builder.headers(headers.clone());
592-
}
593-
594-
if !request_options.query().is_empty() {
595-
request_builder = request_builder.query(request_options.query());
596-
}
597-
598521
let event_source = request_builder.eventsource().unwrap();
599522

600523
stream(event_source).await
@@ -611,21 +534,10 @@ impl<C: Config> Client<C> {
611534
I: Serialize,
612535
O: DeserializeOwned + std::marker::Send + 'static,
613536
{
614-
let mut request_builder = self
615-
.http_client
616-
.post(self.config.url(path))
617-
.query(&self.config.query())
618-
.headers(self.config.headers())
537+
let request_builder = self
538+
.build_request_builder(reqwest::Method::POST, path, request_options)
619539
.json(&request);
620540

621-
if let Some(headers) = request_options.headers() {
622-
request_builder = request_builder.headers(headers.clone());
623-
}
624-
625-
if !request_options.query().is_empty() {
626-
request_builder = request_builder.query(request_options.query());
627-
}
628-
629541
let event_source = request_builder.eventsource().unwrap();
630542

631543
stream_mapped_raw_events(event_source, event_mapper).await

async-openai/src/request_options.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,28 @@ use crate::{config::OPENAI_API_BASE, error::OpenAIError};
88
pub struct RequestOptions {
99
query: Option<Vec<(String, String)>>,
1010
headers: Option<HeaderMap>,
11+
path: Option<String>,
1112
}
1213

1314
impl RequestOptions {
1415
pub(crate) fn new() -> Self {
1516
Self {
1617
query: None,
1718
headers: None,
19+
path: None,
1820
}
1921
}
2022

23+
pub(crate) fn with_path(&mut self, path: &str) -> Result<(), OpenAIError> {
24+
if path.is_empty() {
25+
return Err(OpenAIError::InvalidArgument(
26+
"Path cannot be empty".to_string(),
27+
));
28+
}
29+
self.path = Some(path.to_string());
30+
Ok(())
31+
}
32+
2133
pub(crate) fn with_headers(&mut self, headers: HeaderMap) {
2234
// merge with existing headers or update with new headers
2335
if let Some(existing_headers) = &mut self.headers {
@@ -81,4 +93,8 @@ impl RequestOptions {
8193
pub(crate) fn headers(&self) -> Option<&HeaderMap> {
8294
self.headers.as_ref()
8395
}
96+
97+
pub(crate) fn path(&self) -> Option<&String> {
98+
self.path.as_ref()
99+
}
84100
}

async-openai/src/traits.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,10 @@ pub trait RequestOptionsBuilder: Sized {
5353
self.options_mut().with_query(query)?;
5454
Ok(self)
5555
}
56+
57+
/// Add a path to RequestOptions
58+
fn path<P: Into<String>>(mut self, path: P) -> Result<Self, OpenAIError> {
59+
self.options_mut().with_path(path.into().as_str())?;
60+
Ok(self)
61+
}
5662
}

0 commit comments

Comments
 (0)