From 1d0b6b07b3490977a02947ce7f2ad8f10f19b821 Mon Sep 17 00:00:00 2001 From: Kim Vandry Date: Thu, 20 Nov 2025 22:45:07 +0000 Subject: [PATCH 1/9] Refactor to aid upcoming change to HTTP request backend API. The RequestImpl type alias for 1 of the 3 backends is now referenced in only one place, a new Bucket::make_request method which everything else that previously called RequestImpl::new now calls. The constructors for each of these backends are sync, as constructors should be. This will make it easier to change to a different backend factory API. The only reason these were async anyway is so that credentials_refresh could be called, but make_request can just do that. --- s3/src/bucket.rs | 135 +++++++++++++++------------- s3/src/put_object_request.rs | 9 +- s3/src/request/async_std_backend.rs | 19 ++-- s3/src/request/blocking.rs | 1 - s3/src/request/tokio_backend.rs | 21 ++--- 5 files changed, 86 insertions(+), 99 deletions(-) diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index 9d240f43ce..ae3e56eaa5 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -170,6 +170,16 @@ impl Bucket { pub fn http_client(&self) -> reqwest::Client { self.http_client.clone() } + + #[maybe_async::maybe_async] + pub(crate) async fn make_request<'a>( + &'a self, + path: &'a str, + command: Command<'a>, + ) -> Result, S3Error> { + self.credentials_refresh().await?; + RequestImpl::new(self, path, command) + } } fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> { @@ -220,15 +230,15 @@ impl Bucket { custom_queries: Option>, ) -> Result { validate_expiry(expiry_secs)?; - let request = RequestImpl::new( - self, - path.as_ref(), - Command::PresignGet { - expiry_secs, - custom_queries, - }, - ) - .await?; + let request = self + .make_request( + path.as_ref(), + Command::PresignGet { + expiry_secs, + custom_queries, + }, + ) + .await?; request.presigned().await } @@ -303,16 +313,16 @@ impl Bucket { custom_queries: Option>, ) -> Result { validate_expiry(expiry_secs)?; - let request = RequestImpl::new( - self, - path.as_ref(), - Command::PresignPut { - expiry_secs, - custom_headers, - custom_queries, - }, - ) - .await?; + let request = self + .make_request( + path.as_ref(), + Command::PresignPut { + expiry_secs, + custom_headers, + custom_queries, + }, + ) + .await?; request.presigned().await } @@ -343,8 +353,9 @@ impl Bucket { expiry_secs: u32, ) -> Result { validate_expiry(expiry_secs)?; - let request = - RequestImpl::new(self, path.as_ref(), Command::PresignDelete { expiry_secs }).await?; + let request = self + .make_request(path.as_ref(), Command::PresignDelete { expiry_secs }) + .await?; request.presigned().await } @@ -402,7 +413,7 @@ impl Bucket { let command = Command::CreateBucket { config }; let bucket = Bucket::new(name, region, credentials)?; - let request = RequestImpl::new(&bucket, "", command).await?; + let request = bucket.make_request("", command).await?; let response_data = request.response_data(false).await?; let response_text = response_data.as_str()?; Ok(CreateBucketResponse { @@ -458,7 +469,7 @@ impl Bucket { /// Used by the public `list_buckets` method to retrieve the list of buckets for the configured client. 
#[maybe_async::maybe_async] async fn _list_buckets(&self) -> Result { - let request = RequestImpl::new(self, "", Command::ListBuckets).await?; + let request = self.make_request("", Command::ListBuckets).await?; let response = request.response_data(false).await?; Ok(quick_xml::de::from_str::< @@ -564,7 +575,7 @@ impl Bucket { let command = Command::CreateBucket { config }; let bucket = Bucket::new(name, region, credentials)?.with_path_style(); - let request = RequestImpl::new(&bucket, "", command).await?; + let request = bucket.make_request("", command).await?; let response_data = request.response_data(false).await?; let response_text = response_data.to_string()?; @@ -608,7 +619,7 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn delete(&self) -> Result { let command = Command::DeleteBucket; - let request = RequestImpl::new(self, "", command).await?; + let request = self.make_request("", command).await?; let response_data = request.response_data(false).await?; Ok(response_data.status_code()) } @@ -915,7 +926,7 @@ impl Bucket { let command = Command::CopyObject { from: from.as_ref(), }; - let request = RequestImpl::new(self, to.as_ref(), command).await?; + let request = self.make_request(to.as_ref(), command).await?; let response_data = request.response_data(false).await?; Ok(response_data.status_code()) } @@ -954,7 +965,7 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn get_object>(&self, path: S) -> Result { let command = Command::GetObject; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(false).await } @@ -969,7 +980,7 @@ impl Bucket { expected_bucket_owner: expected_bucket_owner.to_string(), version_id, }; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; let response = request.response_data(false).await?; @@ -1023,7 +1034,7 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn object_exists>(&self, path: S) -> Result { let command = Command::HeadObject; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; let response_data = match request.response_data(false).await { Ok(response_data) => response_data, Err(S3Error::HttpFailWithBody(status_code, error)) => { @@ -1047,7 +1058,7 @@ impl Bucket { expected_bucket_owner: expected_bucket_owner.to_string(), configuration: cors_config.clone(), }; - let request = RequestImpl::new(self, "", command).await?; + let request = self.make_request("", command).await?; request.response_data(false).await } @@ -1059,7 +1070,7 @@ impl Bucket { let command = Command::GetBucketCors { expected_bucket_owner: expected_bucket_owner.to_string(), }; - let request = RequestImpl::new(self, "", command).await?; + let request = self.make_request("", command).await?; let response = request.response_data(false).await?; Ok(quick_xml::de::from_str::( response.as_str()?, @@ -1074,13 +1085,13 @@ impl Bucket { let command = Command::DeleteBucketCors { expected_bucket_owner: expected_bucket_owner.to_string(), }; - let request = RequestImpl::new(self, "", command).await?; + let request = self.make_request("", command).await?; request.response_data(false).await } #[maybe_async::maybe_async] pub async fn get_bucket_lifecycle(&self) -> Result { - let request = RequestImpl::new(self, "", Command::GetBucketLifecycle).await?; + let request = self.make_request("", 
Command::GetBucketLifecycle).await?; let response = request.response_data(false).await?; Ok(quick_xml::de::from_str::( response.as_str()?, @@ -1095,13 +1106,15 @@ impl Bucket { let command = Command::PutBucketLifecycle { configuration: lifecycle_config, }; - let request = RequestImpl::new(self, "", command).await?; + let request = self.make_request("", command).await?; request.response_data(false).await } #[maybe_async::maybe_async] pub async fn delete_bucket_lifecycle(&self) -> Result { - let request = RequestImpl::new(self, "", Command::DeleteBucketLifecycle).await?; + let request = self + .make_request("", Command::DeleteBucketLifecycle) + .await?; request.response_data(false).await } @@ -1142,7 +1155,7 @@ impl Bucket { path: S, ) -> Result { let command = Command::GetObjectTorrent; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(false).await } @@ -1190,7 +1203,7 @@ impl Bucket { } let command = Command::GetObjectRange { start, end }; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(false).await } @@ -1251,7 +1264,7 @@ impl Bucket { } let command = Command::GetObjectRange { start, end }; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data_to_writer(writer).await } @@ -1268,7 +1281,7 @@ impl Bucket { } let command = Command::GetObjectRange { start, end }; - let request = RequestImpl::new(self, path.as_ref(), command)?; + let request = self.make_request(path.as_ref(), command)?; request.response_data_to_writer(writer) } @@ -1316,7 +1329,7 @@ impl Bucket { writer: &mut T, ) -> Result { let command = Command::GetObject; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data_to_writer(writer).await } @@ -1327,7 +1340,7 @@ impl Bucket { writer: &mut T, ) -> Result { let command = Command::GetObject; - let request = RequestImpl::new(self, path.as_ref(), command)?; + let request = self.make_request(path.as_ref(), command)?; request.response_data_to_writer(writer) } @@ -1378,7 +1391,7 @@ impl Bucket { path: S, ) -> Result { let command = Command::GetObject; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data_to_stream().await } @@ -1581,7 +1594,7 @@ impl Bucket { custom_headers: None, content_type, }; - let request = RequestImpl::new(self, path, command).await?; + let request = self.make_request(path, command).await?; request.response_data(true).await } @@ -1861,7 +1874,7 @@ impl Bucket { content_type: &str, ) -> Result { let command = Command::InitiateMultipartUpload { content_type }; - let request = RequestImpl::new(self, s3_path, command).await?; + let request = self.make_request(s3_path, command).await?; let response_data = request.response_data(false).await?; if response_data.status_code() >= 300 { return Err(error_from_response_data(response_data)?); @@ -1879,7 +1892,7 @@ impl Bucket { content_type: &str, ) -> Result { let command = Command::InitiateMultipartUpload { content_type }; - let request = RequestImpl::new(self, s3_path, command)?; + let request = self.make_request(s3_path, command)?; let response_data = request.response_data(false)?; if 
response_data.status_code() >= 300 { return Err(error_from_response_data(response_data)?); @@ -1935,7 +1948,7 @@ impl Bucket { custom_headers: None, content_type, }; - let request = RequestImpl::new(self, path, command).await?; + let request = self.make_request(path, command).await?; let response_data = request.response_data(true).await?; if !(200..300).contains(&response_data.status_code()) { // if chunk upload failed - abort the upload @@ -1971,7 +1984,7 @@ impl Bucket { custom_headers: None, content_type, }; - let request = RequestImpl::new(self, path, command)?; + let request = self.make_request(path, command)?; let response_data = request.response_data(true)?; if !(200..300).contains(&response_data.status_code()) { // if chunk upload failed - abort the upload @@ -2001,7 +2014,7 @@ impl Bucket { ) -> Result { let data = CompleteMultipartUploadData { parts }; let complete = Command::CompleteMultipartUpload { upload_id, data }; - let complete_request = RequestImpl::new(self, path, complete).await?; + let complete_request = self.make_request(path, complete).await?; complete_request.response_data(false).await } @@ -2014,7 +2027,7 @@ impl Bucket { ) -> Result { let data = CompleteMultipartUploadData { parts }; let complete = Command::CompleteMultipartUpload { upload_id, data }; - let complete_request = RequestImpl::new(self, path, complete)?; + let complete_request = self.make_request(path, complete)?; complete_request.response_data(false) } @@ -2052,7 +2065,9 @@ impl Bucket { /// ``` #[maybe_async::maybe_async] pub async fn location(&self) -> Result<(Region, u16), S3Error> { - let request = RequestImpl::new(self, "?location", Command::GetBucketLocation).await?; + let request = self + .make_request("?location", Command::GetBucketLocation) + .await?; let response_data = request.response_data(false).await?; let region_string = String::from_utf8_lossy(response_data.as_slice()); let region = match quick_xml::de::from_reader(region_string.as_bytes()) { @@ -2112,7 +2127,7 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn delete_object>(&self, path: S) -> Result { let command = Command::DeleteObject; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(false).await } @@ -2154,7 +2169,7 @@ impl Bucket { path: S, ) -> Result<(HeadObjectResult, u16), S3Error> { let command = Command::HeadObject; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; let (headers, status) = request.response_header().await?; let header_object = HeadObjectResult::from(&headers); Ok((header_object, status)) @@ -2206,7 +2221,7 @@ impl Bucket { custom_headers: None, multipart: None, }; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(true).await } @@ -2266,7 +2281,7 @@ impl Bucket { custom_headers, multipart: None, }; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(true).await } @@ -2460,7 +2475,7 @@ impl Bucket { ) -> Result { let content = self._tags_xml(tags); let command = Command::PutObjectTagging { tags: &content }; - let request = RequestImpl::new(self, path, command).await?; + let request = self.make_request(path, command).await?; request.response_data(false).await } @@ -2502,7 +2517,7 @@ impl Bucket { 
path: S, ) -> Result { let command = Command::DeleteObjectTagging; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; request.response_data(false).await } @@ -2545,7 +2560,7 @@ impl Bucket { path: S, ) -> Result<(Vec, u16), S3Error> { let command = Command::GetObjectTagging {}; - let request = RequestImpl::new(self, path.as_ref(), command).await?; + let request = self.make_request(path.as_ref(), command).await?; let result = request.response_data(false).await?; let mut tags = Vec::new(); @@ -2618,7 +2633,7 @@ impl Bucket { max_keys, } }; - let request = RequestImpl::new(self, "/", command).await?; + let request = self.make_request("/", command).await?; let response_data = request.response_data(false).await?; let list_bucket_result = quick_xml::de::from_reader(response_data.as_slice())?; @@ -2702,7 +2717,7 @@ impl Bucket { key_marker, max_uploads, }; - let request = RequestImpl::new(self, "/", command).await?; + let request = self.make_request("/", command).await?; let response_data = request.response_data(false).await?; let list_bucket_result = quick_xml::de::from_reader(response_data.as_slice())?; @@ -2806,7 +2821,7 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn abort_upload(&self, key: &str, upload_id: &str) -> Result<(), S3Error> { let abort = Command::AbortMultipartUpload { upload_id }; - let abort_request = RequestImpl::new(self, key, abort).await?; + let abort_request = self.make_request(key, abort).await?; let response_data = abort_request.response_data(false).await?; if (200..300).contains(&response_data.status_code()) { diff --git a/s3/src/put_object_request.rs b/s3/src/put_object_request.rs index 1b48227661..014b90f808 100644 --- a/s3/src/put_object_request.rs +++ b/s3/src/put_object_request.rs @@ -14,13 +14,6 @@ use tokio::io::AsyncRead; #[cfg(feature = "with-async-std")] use async_std::io::Read as AsyncRead; -#[cfg(feature = "with-async-std")] -use crate::request::async_std_backend::SurfRequest as RequestImpl; -#[cfg(feature = "sync")] -use crate::request::blocking::AttoRequest as RequestImpl; -#[cfg(feature = "with-tokio")] -use crate::request::tokio_backend::ReqwestRequest as RequestImpl; - /// Builder for constructing S3 PUT object requests with custom options /// /// # Example @@ -195,7 +188,7 @@ impl<'a> PutObjectRequest<'a> { multipart: None, }; - let request = RequestImpl::new(self.bucket, &self.path, command).await?; + let request = self.bucket.make_request(&self.path, command).await?; request.response_data(true).await } } diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index c0a345d93c..36f1385631 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -184,12 +184,11 @@ impl<'a> Request for SurfRequest<'a> { } impl<'a> SurfRequest<'a> { - pub async fn new<'b>( + pub fn new<'b>( bucket: &'b Bucket, path: &'b str, command: Command<'b>, ) -> Result, S3Error> { - bucket.credentials_refresh().await?; Ok(SurfRequest { bucket, path, @@ -222,9 +221,7 @@ mod tests { let region = "custom-region".parse()?; let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; let path = "/my-first/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url()?.scheme(), "https"); @@ -240,9 +237,7 @@ mod tests { let region = "custom-region".parse()?; let bucket = 
Bucket::new("my-first-bucket", region, fake_credentials())?.with_path_style(); let path = "/my-first/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "https"); @@ -258,9 +253,7 @@ mod tests { let region = "http://custom-region".parse()?; let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; let path = "/my-second/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "http"); @@ -275,9 +268,7 @@ mod tests { let region = "http://custom-region".parse()?; let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?.with_path_style(); let path = "/my-second/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "http"); diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index ac065fd898..7bd15b9063 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -148,7 +148,6 @@ impl<'a> AttoRequest<'a> { path: &'b str, command: Command<'b>, ) -> Result, S3Error> { - bucket.credentials_refresh()?; Ok(AttoRequest { bucket, path, diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index 498f19135b..f8211daae6 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -210,12 +210,11 @@ impl<'a> Request for ReqwestRequest<'a> { } impl<'a> ReqwestRequest<'a> { - pub async fn new( + pub fn new( bucket: &'a Bucket, path: &'a str, command: Command<'a>, ) -> Result, S3Error> { - bucket.credentials_refresh().await?; Ok(Self { bucket, path, @@ -248,9 +247,7 @@ mod tests { let region = "custom-region".parse().unwrap(); let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); let path = "/my-first/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "https"); @@ -267,9 +264,7 @@ mod tests { .unwrap() .with_path_style(); let path = "/my-first/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "https"); @@ -284,9 +279,7 @@ mod tests { let region = "http://custom-region".parse().unwrap(); let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap(); let path = "/my-second/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "http"); @@ -302,9 +295,7 @@ mod tests { .unwrap() .with_path_style(); let path = "/my-second/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); + let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); assert_eq!(request.url().unwrap().scheme(), "http"); @@ -329,7 +320,6 @@ mod tests { end: None, }, ) - .await .unwrap(); let headers = 
request.headers().await.unwrap(); let range = headers.get(RANGE).unwrap(); @@ -343,7 +333,6 @@ mod tests { end: Some(1), }, ) - .await .unwrap(); let headers = request.headers().await.unwrap(); let range = headers.get(RANGE).unwrap(); From a240b059da457d83ad8e1287e087279a0805a5b9 Mon Sep 17 00:00:00 2001 From: Kim Vandry Date: Fri, 21 Nov 2025 20:00:40 +0000 Subject: [PATCH 2/9] Deliver requests to all 3 backends in http::Request format. This will make it easier to offer pluggable backends in the future. --- s3/src/bucket.rs | 59 +++--- s3/src/error.rs | 1 - s3/src/request/async_std_backend.rs | 103 ++++------ s3/src/request/blocking.rs | 96 ++++----- s3/src/request/request_trait.rs | 296 +++++++++++++++++----------- s3/src/request/tokio_backend.rs | 121 ++++-------- 6 files changed, 323 insertions(+), 353 deletions(-) diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index ae3e56eaa5..28d658bee7 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -50,7 +50,7 @@ use crate::request::ResponseDataStream; use crate::request::tokio_backend::ClientOptions; #[cfg(feature = "with-tokio")] use crate::request::tokio_backend::client; -use crate::request::{Request as _, ResponseData}; +use crate::request::{Request as _, ResponseData, build_presigned, build_request}; use std::str::FromStr; use std::sync::Arc; @@ -173,12 +173,19 @@ impl Bucket { #[maybe_async::maybe_async] pub(crate) async fn make_request<'a>( - &'a self, - path: &'a str, + &self, + path: &str, command: Command<'a>, ) -> Result, S3Error> { self.credentials_refresh().await?; - RequestImpl::new(self, path, command) + let http_request = build_request(self, path, command).await?; + RequestImpl::new(http_request, self) + } + + #[maybe_async::maybe_async] + async fn make_presigned(&self, path: &str, command: Command<'_>) -> Result { + self.credentials_refresh().await?; + build_presigned(self, path, command).await } } @@ -230,16 +237,14 @@ impl Bucket { custom_queries: Option>, ) -> Result { validate_expiry(expiry_secs)?; - let request = self - .make_request( - path.as_ref(), - Command::PresignGet { - expiry_secs, - custom_queries, - }, - ) - .await?; - request.presigned().await + self.make_presigned( + path.as_ref(), + Command::PresignGet { + expiry_secs, + custom_queries, + }, + ) + .await } /// Get a presigned url for posting an object to a given path @@ -313,17 +318,15 @@ impl Bucket { custom_queries: Option>, ) -> Result { validate_expiry(expiry_secs)?; - let request = self - .make_request( - path.as_ref(), - Command::PresignPut { - expiry_secs, - custom_headers, - custom_queries, - }, - ) - .await?; - request.presigned().await + self.make_presigned( + path.as_ref(), + Command::PresignPut { + expiry_secs, + custom_headers, + custom_queries, + }, + ) + .await } /// Get a presigned url for deleting object on a given path @@ -353,10 +356,8 @@ impl Bucket { expiry_secs: u32, ) -> Result { validate_expiry(expiry_secs)?; - let request = self - .make_request(path.as_ref(), Command::PresignDelete { expiry_secs }) - .await?; - request.presigned().await + self.make_presigned(path.as_ref(), Command::PresignDelete { expiry_secs }) + .await } /// Create a new `Bucket` and instantiate it diff --git a/s3/src/error.rs b/s3/src/error.rs index a3191579a7..8e1c56ffac 100644 --- a/s3/src/error.rs +++ b/s3/src/error.rs @@ -21,7 +21,6 @@ pub enum S3Error { UrlParse(#[from] url::ParseError), #[error("io: {0}")] Io(#[from] std::io::Error), - #[cfg(feature = "with-tokio")] #[error("http: {0}")] Http(#[from] http::Error), #[cfg(feature = "with-tokio")] diff 
--git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index 36f1385631..c3fdcf9110 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -3,15 +3,12 @@ use async_std::io::{ReadExt, WriteExt}; use async_std::stream::StreamExt; use bytes::Bytes; use futures_util::FutureExt; +use std::borrow::Cow; use std::collections::HashMap; use crate::bucket::Bucket; -use crate::command::Command; use crate::error::S3Error; -use crate::utils::now_utc; -use time::OffsetDateTime; -use crate::command::HttpMethod; use crate::request::{Request, ResponseData, ResponseDataStream}; use http::HeaderMap; @@ -21,10 +18,7 @@ use surf::http::headers::{HeaderName, HeaderValue}; // Temporary structure for making a request pub struct SurfRequest<'a> { - pub bucket: &'a Bucket, - pub path: &'a str, - pub command: Command<'a>, - pub datetime: OffsetDateTime, + request: http::Request>, pub sync: bool, } @@ -33,37 +27,24 @@ impl<'a> Request for SurfRequest<'a> { type Response = surf::Response; type HeaderMap = HeaderMap; - fn datetime(&self) -> OffsetDateTime { - self.datetime - } - - fn bucket(&self) -> Bucket { - self.bucket.clone() - } - - fn command(&self) -> Command<'_> { - self.command.clone() - } - - fn path(&self) -> String { - self.path.to_string() - } - async fn response(&self) -> Result { - // Build headers - let headers = self.headers().await?; - - let request = match self.command.http_verb() { - HttpMethod::Get => surf::Request::builder(Method::Get, self.url()?), - HttpMethod::Delete => surf::Request::builder(Method::Delete, self.url()?), - HttpMethod::Put => surf::Request::builder(Method::Put, self.url()?), - HttpMethod::Post => surf::Request::builder(Method::Post, self.url()?), - HttpMethod::Head => surf::Request::builder(Method::Head, self.url()?), - }; - - let mut request = request.body(self.request_body()?); + let url = format!("{}", self.request.uri()).parse()?; + let mut request = match *self.request.method() { + http::Method::GET => surf::Request::builder(Method::Get, url), + http::Method::DELETE => surf::Request::builder(Method::Delete, url), + http::Method::PUT => surf::Request::builder(Method::Put, url), + http::Method::POST => surf::Request::builder(Method::Post, url), + http::Method::HEAD => surf::Request::builder(Method::Head, url), + ref m => surf::Request::builder( + m.as_str() + .parse() + .map_err(|e: surf::Error| S3Error::Surf(e.to_string()))?, + url, + ), + } + .body(self.request.body().clone().into_owned()); - for (name, value) in headers.iter() { + for (name, value) in self.request.headers().iter() { request = request.header( HeaderName::from_bytes(AsRef::<[u8]>::as_ref(&name).to_vec()) .expect("Could not parse heaeder name"), @@ -184,16 +165,9 @@ impl<'a> Request for SurfRequest<'a> { } impl<'a> SurfRequest<'a> { - pub fn new<'b>( - bucket: &'b Bucket, - path: &'b str, - command: Command<'b>, - ) -> Result, S3Error> { - Ok(SurfRequest { - bucket, - path, - command, - datetime: now_utc(), + pub fn new(request: http::Request>, _: &Bucket) -> Result { + Ok(Self { + request, sync: false, }) } @@ -203,8 +177,7 @@ impl<'a> SurfRequest<'a> { mod tests { use crate::bucket::Bucket; use crate::command::Command; - use crate::request::Request; - use crate::request::async_std_backend::SurfRequest; + use crate::request::request_trait::build_request; use anyhow::Result; use awscreds::Credentials; @@ -221,11 +194,13 @@ mod tests { let region = "custom-region".parse()?; let bucket = Bucket::new("my-first-bucket", region, 
fake_credentials())?; let path = "/my-first/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url()?.scheme(), "https"); + assert_eq!(request.uri().scheme_str().unwrap(), "https"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "my-first-bucket.custom-region".to_string()); @@ -237,11 +212,13 @@ mod tests { let region = "custom-region".parse()?; let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?.with_path_style(); let path = "/my-first/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "https"); + assert_eq!(request.uri().scheme_str().unwrap(), "https"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "custom-region".to_string()); @@ -253,11 +230,13 @@ mod tests { let region = "http://custom-region".parse()?; let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; let path = "/my-second/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "http"); + assert_eq!(request.uri().scheme_str().unwrap(), "http"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "my-second-bucket.custom-region".to_string()); Ok(()) @@ -268,11 +247,13 @@ mod tests { let region = "http://custom-region".parse()?; let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?.with_path_style(); let path = "/my-second/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "http"); + assert_eq!(request.uri().scheme_str().unwrap(), "http"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "custom-region".to_string()); diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index 7bd15b9063..433c5ac01f 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -7,22 +7,17 @@ use std::io::Write; use attohttpc::header::HeaderName; use crate::bucket::Bucket; -use crate::command::Command; use crate::error::S3Error; -use crate::utils::now_utc; use bytes::Bytes; +use std::borrow::Cow; use std::collections::HashMap; -use time::OffsetDateTime; -use crate::command::HttpMethod; use crate::request::{Request, ResponseData}; // Temporary structure for making a request pub struct AttoRequest<'a> { - pub bucket: &'a Bucket, - pub path: &'a str, - pub command: Command<'a>, - pub datetime: OffsetDateTime, + request: http::Request>, + request_timeout: Option, pub sync: bool, } @@ -30,45 +25,29 @@ impl<'a> Request for AttoRequest<'a> { type Response = attohttpc::Response; type HeaderMap = attohttpc::header::HeaderMap; - fn datetime(&self) -> OffsetDateTime { - self.datetime - } - - fn bucket(&self) -> Bucket { - self.bucket.clone() - } - - fn command(&self) -> 
Command<'_> { - self.command.clone() - } - - fn path(&self) -> String { - self.path.to_string() - } - fn response(&self) -> Result { - // Build headers - let headers = self.headers()?; - let mut session = attohttpc::Session::new(); - for (name, value) in headers.iter() { + for (name, value) in self.request.headers().iter() { session.header(HeaderName::from_bytes(name.as_ref())?, value.to_str()?); } - if let Some(timeout) = self.bucket.request_timeout { + if let Some(timeout) = self.request_timeout { session.timeout(timeout) } - let request = match self.command.http_verb() { - HttpMethod::Get => session.get(self.url()?), - HttpMethod::Delete => session.delete(self.url()?), - HttpMethod::Put => session.put(self.url()?), - HttpMethod::Post => session.post(self.url()?), - HttpMethod::Head => session.head(self.url()?), + let url = format!("{}", self.request.uri()); + let request = match *self.request.method() { + http::Method::GET => session.get(url), + http::Method::DELETE => session.delete(url), + http::Method::PUT => session.put(url), + http::Method::POST => session.post(url), + http::Method::HEAD => session.head(url), + _ => { + return Err(S3Error::HttpFailWithBody(405, "".into())); + } }; - - let response = request.bytes(&self.request_body()?).send()?; + let response = request.bytes(self.request.body().clone()).send()?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { let status = response.status().as_u16(); @@ -116,7 +95,7 @@ impl<'a> Request for AttoRequest<'a> { } } else { // HEAD requests don't have a response body - if self.command.http_verb() == HttpMethod::Head { + if *self.request.method() == http::Method::HEAD { Bytes::from("") } else { Bytes::from(response.bytes()?) @@ -143,16 +122,10 @@ impl<'a> Request for AttoRequest<'a> { } impl<'a> AttoRequest<'a> { - pub fn new<'b>( - bucket: &'b Bucket, - path: &'b str, - command: Command<'b>, - ) -> Result, S3Error> { - Ok(AttoRequest { - bucket, - path, - command, - datetime: now_utc(), + pub fn new(request: http::Request>, bucket: &Bucket) -> Result { + Ok(Self { + request, + request_timeout: bucket.request_timeout, sync: false, }) } @@ -162,8 +135,7 @@ impl<'a> AttoRequest<'a> { mod tests { use crate::bucket::Bucket; use crate::command::Command; - use crate::request::Request; - use crate::request::blocking::AttoRequest; + use crate::request::request_trait::build_request; use anyhow::Result; use awscreds::Credentials; @@ -180,11 +152,11 @@ mod tests { let region = "custom-region".parse()?; let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; let path = "/my-first/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject).unwrap(); - assert_eq!(request.url()?.scheme(), "https"); + assert_eq!(request.uri().scheme_str().unwrap(), "https"); - let headers = request.headers().unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "my-first-bucket.custom-region".to_string()); @@ -197,11 +169,11 @@ mod tests { let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; bucket.with_path_style(); let path = "/my-first/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject).unwrap(); - assert_eq!(request.url()?.scheme(), "https"); + assert_eq!(request.uri().scheme_str().unwrap(), "https"); - let headers = request.headers().unwrap(); + let headers = 
request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "custom-region".to_string()); @@ -213,11 +185,11 @@ mod tests { let region = "http://custom-region".parse()?; let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; let path = "/my-second/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject).unwrap(); - assert_eq!(request.url()?.scheme(), "http"); + assert_eq!(request.uri().scheme_str().unwrap(), "http"); - let headers = request.headers().unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "my-second-bucket.custom-region".to_string()); Ok(()) @@ -229,11 +201,11 @@ mod tests { let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; bucket.with_path_style(); let path = "/my-second/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject).unwrap(); - assert_eq!(request.url()?.scheme(), "http"); + assert_eq!(request.uri().scheme_str().unwrap(), "http"); - let headers = request.headers().unwrap(); + let headers = request.headers(); let host = headers.get("Host").unwrap(); assert_eq!(*host, "custom-region".to_string()); diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index f5c542b5e6..a40d932f72 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -1,7 +1,9 @@ use base64::Engine; use base64::engine::general_purpose; use hmac::Mac; +use http::Method; use quick_xml::se::to_string; +use std::borrow::Cow; use std::collections::HashMap; #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] use std::pin::Pin; @@ -11,9 +13,10 @@ use url::Url; use crate::LONG_DATETIME; use crate::bucket::Bucket; -use crate::command::Command; +use crate::command::{Command, HttpMethod}; use crate::error::S3Error; use crate::signing; +use crate::utils::now_utc; use bytes::Bytes; use http::HeaderMap; use http::header::{ @@ -210,94 +213,76 @@ pub trait Request { #[cfg(any(feature = "with-async-std", feature = "with-tokio"))] async fn response_data_to_stream(&self) -> Result; async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>; - fn datetime(&self) -> OffsetDateTime; - fn bucket(&self) -> Bucket; - fn command(&self) -> Command<'_>; - fn path(&self) -> String; +} + +struct BuildHelper<'temp, 'body> { + bucket: &'temp Bucket, + path: &'temp str, + command: Command<'body>, + datetime: OffsetDateTime, +} +#[maybe_async::maybe_async] +impl<'temp, 'body> BuildHelper<'temp, 'body> { async fn signing_key(&self) -> Result, S3Error> { signing::signing_key( - &self.datetime(), + &self.datetime, &self - .bucket() + .bucket .secret_key() .await? .expect("Secret key must be provided to sign headers, found None"), - &self.bucket().region(), + &self.bucket.region(), "s3", ) } - fn request_body(&self) -> Result, S3Error> { - let result = if let Command::PutObject { content, .. } = self.command() { - Vec::from(content) - } else if let Command::PutObjectTagging { tags } = self.command() { - Vec::from(tags) - } else if let Command::UploadPart { content, .. } = self.command() { - Vec::from(content) - } else if let Command::CompleteMultipartUpload { data, .. 
} = &self.command() { - let body = data.to_string(); - body.as_bytes().to_vec() - } else if let Command::CreateBucket { config } = &self.command() { - if let Some(payload) = config.location_constraint_payload() { - Vec::from(payload) - } else { - Vec::new() - } - } else if let Command::PutBucketLifecycle { configuration, .. } = &self.command() { - quick_xml::se::to_string(configuration)?.as_bytes().to_vec() - } else if let Command::PutBucketCors { configuration, .. } = &self.command() { - let cors = configuration.to_string(); - cors.as_bytes().to_vec() - } else { - Vec::new() - }; - Ok(result) - } - fn long_date(&self) -> Result { - Ok(self.datetime().format(LONG_DATETIME)?) + Ok(self.datetime.format(LONG_DATETIME)?) } fn string_to_sign(&self, request: &str) -> Result { - signing::string_to_sign(&self.datetime(), &self.bucket().region(), request) + signing::string_to_sign(&self.datetime, &self.bucket.region(), request) } fn host_header(&self) -> String { - self.bucket().host() + self.bucket.host() } #[maybe_async::async_impl] async fn presigned(&self) -> Result { - let (expiry, custom_headers, custom_queries) = match self.command() { + let (expiry, custom_headers, custom_queries) = match &self.command { Command::PresignGet { expiry_secs, custom_queries, - } => (expiry_secs, None, custom_queries), + } => (expiry_secs, None, custom_queries.as_ref()), Command::PresignPut { expiry_secs, custom_headers, custom_queries, - } => (expiry_secs, custom_headers, custom_queries), + } => ( + expiry_secs, + custom_headers.as_ref(), + custom_queries.as_ref(), + ), Command::PresignDelete { expiry_secs } => (expiry_secs, None, None), _ => unreachable!(), }; let url = self - .presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref()) + .presigned_url_no_sig(*expiry, custom_headers, custom_queries) .await?; // Build the URL string preserving the original host (including standard ports) // The Url type drops standard ports when converting to string, but we need them // for signature validation - let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket().region() - { + let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket.region() { // Check if we need to preserve a standard port if (endpoint.contains(":80") && url.scheme() == "http" && url.port().is_none()) || (endpoint.contains(":443") && url.scheme() == "https" && url.port().is_none()) { // Rebuild the URL with the original host from the endpoint - let host = self.bucket().host(); + let host = self.bucket.host(); format!( "{}://{}{}{}", url.scheme(), @@ -315,41 +300,38 @@ pub trait Request { Ok(format!( "{}&X-Amz-Signature={}", url_str, - self.presigned_authorization(custom_headers.as_ref()) - .await? + self.presigned_authorization(custom_headers).await? )) } #[maybe_async::sync_impl] async fn presigned(&self) -> Result { - let (expiry, custom_headers, custom_queries) = match self.command() { + let (expiry, custom_headers, custom_queries) = match &self.command { Command::PresignGet { expiry_secs, custom_queries, - } => (expiry_secs, None, custom_queries), + } => (expiry_secs, None, custom_queries.as_ref()), Command::PresignPut { expiry_secs, custom_headers, .. 
- } => (expiry_secs, custom_headers, None), + } => (expiry_secs, custom_headers.as_ref(), None), Command::PresignDelete { expiry_secs } => (expiry_secs, None, None), _ => unreachable!(), }; - let url = - self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?; + let url = self.presigned_url_no_sig(*expiry, custom_headers, custom_queries)?; // Build the URL string preserving the original host (including standard ports) // The Url type drops standard ports when converting to string, but we need them // for signature validation - let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket().region() - { + let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket.region() { // Check if we need to preserve a standard port if (endpoint.contains(":80") && url.scheme() == "http" && url.port().is_none()) || (endpoint.contains(":443") && url.scheme() == "https" && url.port().is_none()) { // Rebuild the URL with the original host from the endpoint - let host = self.bucket().host(); + let host = self.bucket.host(); format!( "{}://{}{}{}", url.scheme(), @@ -367,7 +349,7 @@ pub trait Request { Ok(format!( "{}&X-Amz-Signature={}", url_str, - self.presigned_authorization(custom_headers.as_ref())? + self.presigned_authorization(custom_headers)? )) } @@ -393,24 +375,28 @@ pub trait Request { } async fn presigned_canonical_request(&self, headers: &HeaderMap) -> Result { - let (expiry, custom_headers, custom_queries) = match self.command() { + let (expiry, custom_headers, custom_queries) = match &self.command { Command::PresignGet { expiry_secs, custom_queries, - } => (expiry_secs, None, custom_queries), + } => (expiry_secs, None, custom_queries.as_ref()), Command::PresignPut { expiry_secs, custom_headers, custom_queries, - } => (expiry_secs, custom_headers, custom_queries), + } => ( + expiry_secs, + custom_headers.as_ref(), + custom_queries.as_ref(), + ), Command::PresignDelete { expiry_secs } => (expiry_secs, None, None), _ => unreachable!(), }; signing::canonical_request( - &self.command().http_verb().to_string(), + &self.command.http_verb().to_string(), &self - .presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref()) + .presigned_url_no_sig(*expiry, custom_headers, custom_queries) .await?, headers, "UNSIGNED-PAYLOAD", @@ -424,7 +410,7 @@ pub trait Request { custom_headers: Option<&HeaderMap>, custom_queries: Option<&HashMap>, ) -> Result { - let bucket = self.bucket(); + let bucket = self.bucket; let token = if let Some(security_token) = bucket.security_token().await? { Some(security_token) } else { @@ -434,9 +420,9 @@ pub trait Request { "{}{}{}", self.url()?, &signing::authorization_query_params_no_sig( - &self.bucket().access_key().await?.unwrap_or_default(), - &self.datetime(), - &self.bucket().region(), + &self.bucket.access_key().await?.unwrap_or_default(), + &self.datetime, + &self.bucket.region(), expiry, custom_headers, token.as_ref() @@ -454,7 +440,7 @@ pub trait Request { custom_headers: Option<&HeaderMap>, custom_queries: Option<&HashMap>, ) -> Result { - let bucket = self.bucket(); + let bucket = self.bucket; let token = if let Some(security_token) = bucket.security_token()? 
{ Some(security_token) } else { @@ -464,9 +450,9 @@ pub trait Request { "{}{}{}", self.url()?, &signing::authorization_query_params_no_sig( - &self.bucket().access_key()?.unwrap_or_default(), - &self.datetime(), - &self.bucket().region(), + &self.bucket.access_key()?.unwrap_or_default(), + &self.datetime, + &self.bucket.region(), expiry, custom_headers, token.as_ref() @@ -478,20 +464,20 @@ pub trait Request { } fn url(&self) -> Result { - let mut url_str = self.bucket().url(); + let mut url_str = self.bucket.url(); - if let Command::ListBuckets { .. } = self.command() { + if let Command::ListBuckets { .. } = self.command { return Ok(Url::parse(&url_str)?); } - if let Command::CreateBucket { .. } = self.command() { + if let Command::CreateBucket { .. } = self.command { return Ok(Url::parse(&url_str)?); } - let path = if self.path().starts_with('/') { - self.path()[1..].to_string() + let path = if self.path.starts_with('/') { + self.path[1..].to_string() } else { - self.path()[..].to_string() + self.path[..].to_string() }; url_str.push('/'); @@ -499,7 +485,7 @@ pub trait Request { // Append to url_path #[allow(clippy::collapsible_match)] - match self.command() { + match &self.command { Command::InitiateMultipartUpload { .. } | Command::ListMultipartUploads { .. } => { url_str.push_str("?uploads") } @@ -554,7 +540,7 @@ pub trait Request { let mut url = Url::parse(&url_str)?; - for (key, value) in &self.bucket().extra_query { + for (key, value) in &self.bucket.extra_query { url.query_pairs_mut().append_pair(key, value); } @@ -564,7 +550,7 @@ pub trait Request { continuation_token, start_after, max_keys, - } = self.command().clone() + } = self.command.clone() { let mut query_pairs = url.query_pairs_mut(); delimiter.map(|d| query_pairs.append_pair("delimiter", &d)); @@ -587,7 +573,7 @@ pub trait Request { delimiter, marker, max_keys, - } = self.command().clone() + } = self.command.clone() { let mut query_pairs = url.query_pairs_mut(); delimiter.map(|d| query_pairs.append_pair("delimiter", &d)); @@ -601,7 +587,7 @@ pub trait Request { } } - match self.command() { + match &self.command { Command::ListMultipartUploads { prefix, delimiter, @@ -614,7 +600,7 @@ pub trait Request { query_pairs.append_pair("prefix", prefix); } if let Some(key_marker) = key_marker { - query_pairs.append_pair("key-marker", &key_marker); + query_pairs.append_pair("key-marker", key_marker); } if let Some(max_uploads) = max_uploads { query_pairs.append_pair("max-uploads", max_uploads.to_string().as_str()); @@ -633,10 +619,10 @@ pub trait Request { fn canonical_request(&self, headers: &HeaderMap) -> Result { signing::canonical_request( - &self.command().http_verb().to_string(), + &self.command.http_verb().to_string(), &self.url()?, headers, - &self.command().sha256()?, + &self.command.sha256()?, ) } @@ -650,31 +636,29 @@ pub trait Request { let signed_header = signing::signed_header_string(headers); signing::authorization_header( &self - .bucket() + .bucket .access_key() .await? .expect("No access_key provided"), - &self.datetime(), - &self.bucket().region(), + &self.datetime, + &self.bucket.region(), &signed_header, &signature, ) } #[maybe_async::maybe_async] - async fn headers(&self) -> Result { + async fn add_headers(&self, headers: &mut HeaderMap) -> Result<(), S3Error> { // Generate this once, but it's used in more than one place. - let sha256 = self.command().sha256()?; + let sha256 = self.command.sha256()?; // Start with extra_headers, that way our headers replace anything with // the same name. 
- let mut headers = HeaderMap::new(); - - for (k, v) in self.bucket().extra_headers.iter() { + for (k, v) in self.bucket.extra_headers.iter() { if k.as_str().starts_with("x-amz-meta-") { // metadata is invalid on any multipart command other than initiate - match self.command() { + match self.command { Command::UploadPart { .. } | Command::AbortMultipartUpload { .. } | Command::CompleteMultipartUpload { .. } @@ -688,7 +672,7 @@ pub trait Request { } // Append custom headers for PUT request if any - if let Command::PutObject { custom_headers, .. } = self.command() + if let Command::PutObject { custom_headers, .. } = &self.command && let Some(custom_headers) = custom_headers { for (k, v) in custom_headers.iter() { @@ -700,7 +684,7 @@ pub trait Request { headers.insert(HOST, host_header.parse()?); - match self.command() { + match self.command { Command::CopyObject { from } => { headers.insert(HeaderName::from_static("x-amz-copy-source"), from.parse()?); } @@ -713,9 +697,9 @@ pub trait Request { _ => { headers.insert( CONTENT_LENGTH, - self.command().content_length()?.to_string().parse()?, + self.command.content_length()?.to_string().parse()?, ); - headers.insert(CONTENT_TYPE, self.command().content_type().parse()?); + headers.insert(CONTENT_TYPE, self.command.content_type().parse()?); } } headers.insert( @@ -727,34 +711,34 @@ pub trait Request { self.long_date()?.parse()?, ); - if let Some(session_token) = self.bucket().session_token().await? { + if let Some(session_token) = self.bucket.session_token().await? { headers.insert( HeaderName::from_static("x-amz-security-token"), session_token.parse()?, ); - } else if let Some(security_token) = self.bucket().security_token().await? { + } else if let Some(security_token) = self.bucket.security_token().await? { headers.insert( HeaderName::from_static("x-amz-security-token"), security_token.parse()?, ); } - if let Command::PutObjectTagging { tags } = self.command() { + if let Command::PutObjectTagging { tags } = self.command { let digest = md5::compute(tags); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); - } else if let Command::PutObject { content, .. } = self.command() { + } else if let Command::PutObject { content, .. } = self.command { let digest = md5::compute(content); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); - } else if let Command::UploadPart { content, .. } = self.command() { + } else if let Command::UploadPart { content, .. 
} = self.command { let digest = md5::compute(content); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); - } else if let Command::GetObject {} = self.command() { + } else if let Command::GetObject {} = self.command { headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?); // headers.insert(header::ACCEPT_CHARSET, HeaderValue::from_str("UTF-8")?); - } else if let Command::GetObjectRange { start, end } = self.command() { + } else if let Command::GetObjectRange { start, end } = self.command { headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?); let mut range = format!("bytes={}-", start); @@ -764,9 +748,9 @@ pub trait Request { } headers.insert(RANGE, range.parse()?); - } else if let Command::CreateBucket { ref config } = self.command() { - config.add_headers(&mut headers)?; - } else if let Command::PutBucketLifecycle { ref configuration } = self.command() { + } else if let Command::CreateBucket { ref config } = self.command { + config.add_headers(headers)?; + } else if let Command::PutBucketLifecycle { ref configuration } = self.command { let digest = md5::compute(to_string(configuration)?.as_bytes()); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); @@ -775,7 +759,7 @@ pub trait Request { expected_bucket_owner, configuration, .. - } = self.command() + } = &self.command { let digest = md5::compute(configuration.to_string().as_bytes()); let hash = general_purpose::STANDARD.encode(digest.as_ref()); @@ -787,7 +771,7 @@ pub trait Request { ); } else if let Command::GetBucketCors { expected_bucket_owner, - } = self.command() + } = &self.command { headers.insert( HeaderName::from_static("x-amz-expected-bucket-owner"), @@ -795,7 +779,7 @@ pub trait Request { ); } else if let Command::DeleteBucketCors { expected_bucket_owner, - } = self.command() + } = &self.command { headers.insert( HeaderName::from_static("x-amz-expected-bucket-owner"), @@ -804,7 +788,7 @@ pub trait Request { } else if let Command::GetObjectAttributes { expected_bucket_owner, .. - } = self.command() + } = &self.command { headers.insert( HeaderName::from_static("x-amz-expected-bucket-owner"), @@ -817,8 +801,8 @@ pub trait Request { } // This must be last, as it signs the other headers, omitted if no secret key is provided - if self.bucket().secret_key().await?.is_some() { - let authorization = self.authorization(&headers).await?; + if self.bucket.secret_key().await?.is_some() { + let authorization = self.authorization(headers).await?; headers.insert(AUTHORIZATION, authorization.parse()?); } @@ -828,10 +812,94 @@ pub trait Request { // range and can't be used again e.g. reply attacks. Adding this header // after the generation of the Authorization header leaves it out of // the signed headers. - headers.insert(DATE, self.datetime().format(&Rfc2822)?.parse()?); + headers.insert(DATE, self.datetime.format(&Rfc2822)?.parse()?); + + Ok(()) + } +} + +fn make_body(command: Command<'_>) -> Result, S3Error> { + let result = if let Command::PutObject { content, .. } = command { + content.into() + } else if let Command::PutObjectTagging { tags } = command { + tags.as_bytes().into() + } else if let Command::UploadPart { content, .. } = command { + content.into() + } else if let Command::CompleteMultipartUpload { data, .. 
} = &command { + let body = data.to_string(); + body.as_bytes().to_vec().into() + } else if let Command::CreateBucket { config } = &command { + if let Some(payload) = config.location_constraint_payload() { + payload.as_bytes().to_vec().into() + } else { + b"".into() + } + } else if let Command::PutBucketLifecycle { configuration, .. } = &command { + quick_xml::se::to_string(configuration)? + .as_bytes() + .to_vec() + .into() + } else if let Command::PutBucketCors { configuration, .. } = &command { + let cors = configuration.to_string(); + cors.as_bytes().to_vec().into() + } else { + b"".into() + }; + Ok(result) +} - Ok(headers) +#[maybe_async::maybe_async] +pub(crate) async fn build_request<'body>( + bucket: &Bucket, + path: &str, + command: Command<'body>, +) -> Result>, S3Error> { + let method = match command.http_verb() { + HttpMethod::Delete => Method::DELETE, + HttpMethod::Get => Method::GET, + HttpMethod::Post => Method::POST, + HttpMethod::Put => Method::PUT, + HttpMethod::Head => Method::HEAD, + }; + + let mut request_builder = http::Request::builder().method(method); + let headers_builder = BuildHelper { + bucket, + path, + command, + datetime: now_utc(), + }; + headers_builder + .add_headers(request_builder.headers_mut().unwrap()) + .await?; + let url = headers_builder.url()?; + let uri_builder = http::Uri::builder() + .scheme(url.scheme()) + .authority(url.authority()); + let uri_builder = match url.query() { + None => uri_builder.path_and_query(url.path()), + Some(query) => uri_builder.path_and_query(format!("{}?{}", url.path(), query)), + }; + let request = request_builder + .uri(uri_builder.build()?) + .body(make_body(headers_builder.command)?)?; + Ok(request) +} + +#[maybe_async::maybe_async] +pub(crate) async fn build_presigned( + bucket: &Bucket, + path: &str, + command: Command<'_>, +) -> Result { + BuildHelper { + bucket, + path, + command, + datetime: now_utc(), } + .presigned() + .await } #[cfg(all(test, feature = "with-tokio"))] diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index f8211daae6..7f2d9c9090 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -4,17 +4,13 @@ extern crate md5; use bytes::Bytes; use futures_util::TryStreamExt; use maybe_async::maybe_async; +use std::borrow::Cow; use std::collections::HashMap; -use std::str::FromStr as _; -use time::OffsetDateTime; use super::request_trait::{Request, ResponseData, ResponseDataStream}; use crate::bucket::Bucket; -use crate::command::Command; -use crate::command::HttpMethod; use crate::error::S3Error; use crate::retry; -use crate::utils::now_utc; use tokio_stream::StreamExt; @@ -60,10 +56,8 @@ pub(crate) fn client(options: &ClientOptions) -> Result { - pub bucket: &'a Bucket, - pub path: &'a str, - pub command: Command<'a>, - pub datetime: OffsetDateTime, + request: http::Request>, + client: reqwest::Client, pub sync: bool, } @@ -73,40 +67,8 @@ impl<'a> Request for ReqwestRequest<'a> { type HeaderMap = reqwest::header::HeaderMap; async fn response(&self) -> Result { - let headers = self - .headers() - .await? 
- .iter() - .map(|(k, v)| { - ( - reqwest::header::HeaderName::from_str(k.as_str()), - reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()), - ) - }) - .filter(|(k, v)| k.is_ok() && v.is_ok()) - .map(|(k, v)| (k.unwrap(), v.unwrap())) - .collect(); - - let client = self.bucket.http_client(); - - let method = match self.command.http_verb() { - HttpMethod::Delete => reqwest::Method::DELETE, - HttpMethod::Get => reqwest::Method::GET, - HttpMethod::Post => reqwest::Method::POST, - HttpMethod::Put => reqwest::Method::PUT, - HttpMethod::Head => reqwest::Method::HEAD, - }; - - let request = client - .request(method, self.url()?.as_str()) - .headers(headers) - .body(self.request_body()?); - - let request = request.build()?; - - // println!("Request: {:?}", request); - - let response = client.execute(request).await?; + let request = self.request.clone().map(|b| b.into_owned()).try_into()?; + let response = self.client.execute(request).await?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { let status = response.status().as_u16(); @@ -191,35 +153,13 @@ impl<'a> Request for ReqwestRequest<'a> { let headers = response.headers().clone(); Ok((headers, status_code)) } - - fn datetime(&self) -> OffsetDateTime { - self.datetime - } - - fn bucket(&self) -> Bucket { - self.bucket.clone() - } - - fn command(&self) -> Command<'_> { - self.command.clone() - } - - fn path(&self) -> String { - self.path.to_string() - } } impl<'a> ReqwestRequest<'a> { - pub fn new( - bucket: &'a Bucket, - path: &'a str, - command: Command<'a>, - ) -> Result, S3Error> { + pub fn new(request: http::Request>, bucket: &Bucket) -> Result { Ok(Self { - bucket, - path, - command, - datetime: now_utc(), + request, + client: bucket.http_client(), sync: false, }) } @@ -229,8 +169,7 @@ impl<'a> ReqwestRequest<'a> { mod tests { use crate::bucket::Bucket; use crate::command::Command; - use crate::request::Request; - use crate::request::tokio_backend::ReqwestRequest; + use crate::request::request_trait::build_request; use awscreds::Credentials; use http::header::{HOST, RANGE}; @@ -247,11 +186,13 @@ mod tests { let region = "custom-region".parse().unwrap(); let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); let path = "/my-first/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "https"); + assert_eq!(request.uri().scheme_str().unwrap(), "https"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get(HOST).unwrap(); assert_eq!(*host, "my-first-bucket.custom-region".to_string()); @@ -264,11 +205,13 @@ mod tests { .unwrap() .with_path_style(); let path = "/my-first/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "https"); + assert_eq!(request.uri().scheme_str().unwrap(), "https"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get(HOST).unwrap(); assert_eq!(*host, "custom-region".to_string()); @@ -279,11 +222,13 @@ mod tests { let region = "http://custom-region".parse().unwrap(); let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap(); let path = "/my-second/path"; - let request = 
ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "http"); + assert_eq!(request.uri().scheme_str().unwrap(), "http"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get(HOST).unwrap(); assert_eq!(*host, "my-second-bucket.custom-region".to_string()); } @@ -295,11 +240,13 @@ mod tests { .unwrap() .with_path_style(); let path = "/my-second/path"; - let request = ReqwestRequest::new(&bucket, path, Command::GetObject).unwrap(); + let request = build_request(&bucket, path, Command::GetObject) + .await + .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "http"); + assert_eq!(request.uri().scheme_str().unwrap(), "http"); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let host = headers.get(HOST).unwrap(); assert_eq!(*host, "custom-region".to_string()); } @@ -312,7 +259,7 @@ mod tests { .with_path_style(); let path = "/my-second/path"; - let request = ReqwestRequest::new( + let request = build_request( &bucket, path, Command::GetObjectRange { @@ -320,12 +267,13 @@ mod tests { end: None, }, ) + .await .unwrap(); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let range = headers.get(RANGE).unwrap(); assert_eq!(range, "bytes=0-"); - let request = ReqwestRequest::new( + let request = build_request( &bucket, path, Command::GetObjectRange { @@ -333,8 +281,9 @@ mod tests { end: Some(1), }, ) + .await .unwrap(); - let headers = request.headers().await.unwrap(); + let headers = request.headers(); let range = headers.get(RANGE).unwrap(); assert_eq!(range, "bytes=0-1"); } From d008813bd2f0e0e566f31456abe244759b88f52f Mon Sep 17 00:00:00 2001 From: Kim Vandry Date: Sat, 22 Nov 2025 02:17:07 +0000 Subject: [PATCH 3/9] Fix and deduplicate unit tests after the http::Request refactor. The unification of all 3 backends under http::Request left essentially 3 copies of the same unit tests. Deduplicate those and move them to request_trait.rs because that's the substance of what they were really testing. In their place add new tests for each backend that test the request translation from http::Request to the backend's own request representation. 
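
For illustration, the per-backend translation tests all share one shape; this
sketch abridges the new tokio test added below (bucket setup via the test's
fake_credentials() helper, and the body assertion, are elided):

    // Build a plain http::Request, hand it to the backend, and assert that
    // method, URL, and headers survive the translation unchanged.
    let http_request = http::Request::builder()
        .uri("https://example.com/foo?bar=1")
        .method(http::Method::POST)
        .header("h1", "v1")
        .body(b"sneaky".into())
        .unwrap();
    let r = ReqwestRequest::new(http_request, &bucket)
        .unwrap()
        .build()
        .unwrap();
    assert_eq!(r.method(), http::Method::POST);
    assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1");
    assert_eq!(r.headers().get("h1").unwrap(), "v1");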
--- s3/Cargo.toml | 1 + s3/src/request/async_std_backend.rs | 118 ++++++++-------------- s3/src/request/blocking.rs | 108 +++++++------------- s3/src/request/request_trait.rs | 149 ++++++++++++++++++++++++++++ s3/src/request/tokio_backend.rs | 132 +++++------------------- 5 files changed, 253 insertions(+), 255 deletions(-) diff --git a/s3/Cargo.toml b/s3/Cargo.toml index cba235af67..bab67b4081 100644 --- a/s3/Cargo.toml +++ b/s3/Cargo.toml @@ -120,6 +120,7 @@ sync-rustls-tls = ["attohttpc/tls-rustls", "aws-creds/rustls-tls", "sync"] [dev-dependencies] tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "fs"] } async-std = { version = "1", features = ["attributes"] } +http-body-util = "0.1.3" uuid = { version = "1", features = ["v4"] } env_logger = "0.11" anyhow = "1" diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index c3fdcf9110..8f5f0b55eb 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -22,12 +22,8 @@ pub struct SurfRequest<'a> { pub sync: bool, } -#[maybe_async] -impl<'a> Request for SurfRequest<'a> { - type Response = surf::Response; - type HeaderMap = HeaderMap; - - async fn response(&self) -> Result { +impl SurfRequest<'_> { + fn build(&self) -> Result { let url = format!("{}", self.request.uri()).parse()?; let mut request = match *self.request.method() { http::Method::GET => surf::Request::builder(Method::Get, url), @@ -53,7 +49,18 @@ impl<'a> Request for SurfRequest<'a> { ); } - let response = request + Ok(request) + } +} + +#[maybe_async] +impl<'a> Request for SurfRequest<'a> { + type Response = surf::Response; + type HeaderMap = HeaderMap; + + async fn response(&self) -> Result { + let response = self + .build()? .send() .await .map_err(|e| S3Error::Surf(e.to_string()))?; @@ -175,11 +182,9 @@ impl<'a> SurfRequest<'a> { #[cfg(test)] mod tests { - use crate::bucket::Bucket; - use crate::command::Command; - use crate::request::request_trait::build_request; - use anyhow::Result; - use awscreds::Credentials; + use super::*; + use crate::Bucket; + use crate::creds::Credentials; // Fake keys - otherwise using Credentials::default will use actual user // credentials if they exist. 
@@ -190,73 +195,28 @@ mod tests { } #[async_std::test] - async fn url_uses_https_by_default() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; - let path = "/my-first/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "https"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "my-first-bucket.custom-region".to_string()); - Ok(()) - } - - #[async_std::test] - async fn url_uses_https_by_default_path_style() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?.with_path_style(); - let path = "/my-first/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "https"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "custom-region".to_string()); - Ok(()) - } - - #[async_std::test] - async fn url_uses_scheme_from_custom_region_if_defined() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; - let path = "/my-second/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await + async fn test_build() { + let http_request = http::Request::builder() + .uri("https://example.com/foo?bar=1") + .method(http::Method::POST) + .header("h1", "v1") + .header("h2", "v2") + .body(b"sneaky".into()) .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "http"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "my-second-bucket.custom-region".to_string()); - Ok(()) - } - - #[async_std::test] - async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?.with_path_style(); - let path = "/my-second/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "http"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "custom-region".to_string()); - - Ok(()) + let region = "custom-region".parse().unwrap(); + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); + + let mut r = SurfRequest::new(http_request, &bucket) + .unwrap() + .build() + .unwrap() + .build(); + + assert_eq!(r.method(), Method::Post); + assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1"); + assert_eq!(r.header("h1").unwrap(), "v1"); + assert_eq!(r.header("h2").unwrap(), "v2"); + let body = r.take_body().into_bytes().await.unwrap(); + assert_eq!(body.as_slice(), b"sneaky"); } } diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index 433c5ac01f..ad4fcbab90 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -21,11 +21,10 @@ pub struct AttoRequest<'a> { pub sync: bool, } -impl<'a> Request for AttoRequest<'a> { - type Response = attohttpc::Response; - type HeaderMap = attohttpc::header::HeaderMap; - - fn response(&self) -> Result { +impl AttoRequest<'_> { + fn build( + &self, + ) -> Result>>, S3Error> { let mut session = attohttpc::Session::new(); for (name, value) in 
self.request.headers().iter() { @@ -47,7 +46,17 @@ impl<'a> Request for AttoRequest<'a> { return Err(S3Error::HttpFailWithBody(405, "".into())); } }; - let response = request.bytes(self.request.body().clone()).send()?; + + Ok(request.bytes(self.request.body().clone())) + } +} + +impl<'a> Request for AttoRequest<'a> { + type Response = attohttpc::Response; + type HeaderMap = attohttpc::header::HeaderMap; + + fn response(&self) -> Result { + let response = self.build()?.send()?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { let status = response.status().as_u16(); @@ -133,10 +142,8 @@ impl<'a> AttoRequest<'a> { #[cfg(test)] mod tests { + use super::*; use crate::bucket::Bucket; - use crate::command::Command; - use crate::request::request_trait::build_request; - use anyhow::Result; use awscreds::Credentials; // Fake keys - otherwise using Credentials::default will use actual user @@ -148,67 +155,26 @@ mod tests { } #[test] - fn url_uses_https_by_default() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; - let path = "/my-first/path"; - let request = build_request(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "https"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "my-first-bucket.custom-region".to_string()); - Ok(()) - } - - #[test] - fn url_uses_https_by_default_path_style() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; - bucket.with_path_style(); - let path = "/my-first/path"; - let request = build_request(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "https"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "custom-region".to_string()); - Ok(()) - } - - #[test] - fn url_uses_scheme_from_custom_region_if_defined() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; - let path = "/my-second/path"; - let request = build_request(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "http"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "my-second-bucket.custom-region".to_string()); - Ok(()) - } - - #[test] - fn url_uses_scheme_from_custom_region_if_defined_with_path_style() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; - bucket.with_path_style(); - let path = "/my-second/path"; - let request = build_request(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "http"); - - let headers = request.headers(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "custom-region".to_string()); - - Ok(()) + fn test_build() { + let http_request = http::Request::builder() + .uri("https://example.com/foo?bar=1") + .method(http::Method::POST) + .header("h1", "v1") + .header("h2", "v2") + .body(b"sneaky".into()) + .unwrap(); + let region = "custom-region".parse().unwrap(); + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); + + let req = AttoRequest::new(http_request, &bucket).unwrap(); + let mut r = req.build().unwrap(); + + 
assert_eq!(r.inspect().method(), http::Method::POST); + assert_eq!(r.inspect().url().as_str(), "https://example.com/foo?bar=1"); + assert_eq!(r.headers().get("h1").unwrap(), "v1"); + assert_eq!(r.headers().get("h2").unwrap(), "v2"); + let mut i = r.inspect(); + let body = &i.body().0; + assert_eq!(&**body, b"sneaky"); } } diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index a40d932f72..1747500102 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -1135,3 +1135,152 @@ mod async_std_tests { assert_eq!(output, b"First chunk\nSecond chunk\nThird chunk\n"); } } + +#[cfg(test)] +mod request_tests { + use super::build_request; + use crate::bucket::Bucket; + use crate::command::Command; + use awscreds::Credentials; + use http::header::{HOST, RANGE}; + + /// A trivial spinning async executor we can use to be independent + /// of tokio or async-std + #[cfg(not(feature = "sync"))] + mod test_executor { + #[derive(Default)] + pub(super) struct TestWaker(pub std::sync::atomic::AtomicBool); + + impl std::task::Wake for TestWaker { + fn wake(self: std::sync::Arc) { + self.0.store(true, std::sync::atomic::Ordering::Release); + } + } + } + + #[cfg(not(feature = "sync"))] + fn testawait(fut: F) -> F::Output { + let w = std::sync::Arc::new(test_executor::TestWaker::default()); + let waker: std::task::Waker = w.clone().into(); + let mut cx = std::task::Context::from_waker(&waker); + let mut fut = std::pin::pin!(fut); + loop { + if let std::task::Poll::Ready(o) = fut.as_mut().poll(&mut cx) { + return o; + } + if !w.0.swap(false, std::sync::atomic::Ordering::AcqRel) { + panic!("Future made no progress"); + } + } + } + + #[cfg(feature = "sync")] + fn testawait(val: T) -> T { + val + } + + // Fake keys - otherwise using Credentials::default will use actual user + // credentials if they exist. 
+ fn fake_credentials() -> Credentials { + let access_key = "AKIAIOSFODNN7EXAMPLE"; + let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; + Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap() + } + + #[test] + fn url_uses_https_by_default() { + let region = "custom-region".parse().unwrap(); + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); + let path = "/my-first/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "https"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + + assert_eq!(*host, "my-first-bucket.custom-region".to_string()); + } + + #[test] + fn url_uses_https_by_default_path_style() { + let region = "custom-region".parse().unwrap(); + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()) + .unwrap() + .with_path_style(); + let path = "/my-first/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "https"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + + assert_eq!(*host, "custom-region".to_string()); + } + + #[test] + fn url_uses_scheme_from_custom_region_if_defined() { + let region = "http://custom-region".parse().unwrap(); + let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap(); + let path = "/my-second/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "http"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + assert_eq!(*host, "my-second-bucket.custom-region".to_string()); + } + + #[test] + fn url_uses_scheme_from_custom_region_if_defined_with_path_style() { + let region = "http://custom-region".parse().unwrap(); + let bucket = Bucket::new("my-second-bucket", region, fake_credentials()) + .unwrap() + .with_path_style(); + let path = "/my-second/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "http"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + assert_eq!(*host, "custom-region".to_string()); + } + + #[test] + fn test_get_object_range_header() { + let region = "http://custom-region".parse().unwrap(); + let bucket = Bucket::new("my-second-bucket", region, fake_credentials()) + .unwrap() + .with_path_style(); + let path = "/my-second/path"; + + let request = testawait(build_request( + &bucket, + path, + Command::GetObjectRange { + start: 0, + end: None, + }, + )) + .unwrap(); + let headers = request.headers(); + let range = headers.get(RANGE).unwrap(); + assert_eq!(range, "bytes=0-"); + + let request = testawait(build_request( + &bucket, + path, + Command::GetObjectRange { + start: 0, + end: Some(1), + }, + )) + .unwrap(); + let headers = request.headers(); + let range = headers.get(RANGE).unwrap(); + assert_eq!(range, "bytes=0-1"); + } +} diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index 7f2d9c9090..995c238e07 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -61,14 +61,19 @@ pub struct ReqwestRequest<'a> { pub sync: bool, } +impl ReqwestRequest<'_> { + fn build(&self) -> Result { + Ok(self.request.clone().map(|b| b.into_owned()).try_into()?) 
+ } +} + #[maybe_async] impl<'a> Request for ReqwestRequest<'a> { type Response = reqwest::Response; type HeaderMap = reqwest::header::HeaderMap; async fn response(&self) -> Result { - let request = self.request.clone().map(|b| b.into_owned()).try_into()?; - let response = self.client.execute(request).await?; + let response = self.client.execute(self.build()?).await?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { let status = response.status().as_u16(); @@ -167,11 +172,10 @@ impl<'a> ReqwestRequest<'a> { #[cfg(test)] mod tests { - use crate::bucket::Bucket; - use crate::command::Command; - use crate::request::request_trait::build_request; - use awscreds::Credentials; - use http::header::{HOST, RANGE}; + use super::*; + use crate::Bucket; + use crate::creds::Credentials; + use http_body_util::BodyExt; // Fake keys - otherwise using Credentials::default will use actual user // credentials if they exist. @@ -182,109 +186,27 @@ mod tests { } #[tokio::test] - async fn url_uses_https_by_default() { - let region = "custom-region".parse().unwrap(); - let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); - let path = "/my-first/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await + async fn test_build() { + let http_request = http::Request::builder() + .uri("https://example.com/foo?bar=1") + .method(http::Method::POST) + .header("h1", "v1") + .header("h2", "v2") + .body(b"sneaky".into()) .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "https"); - - let headers = request.headers(); - let host = headers.get(HOST).unwrap(); - - assert_eq!(*host, "my-first-bucket.custom-region".to_string()); - } - - #[tokio::test] - async fn url_uses_https_by_default_path_style() { let region = "custom-region".parse().unwrap(); - let bucket = Bucket::new("my-first-bucket", region, fake_credentials()) - .unwrap() - .with_path_style(); - let path = "/my-first/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "https"); - - let headers = request.headers(); - let host = headers.get(HOST).unwrap(); - - assert_eq!(*host, "custom-region".to_string()); - } - - #[tokio::test] - async fn url_uses_scheme_from_custom_region_if_defined() { - let region = "http://custom-region".parse().unwrap(); - let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap(); - let path = "/my-second/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.uri().scheme_str().unwrap(), "http"); - - let headers = request.headers(); - let host = headers.get(HOST).unwrap(); - assert_eq!(*host, "my-second-bucket.custom-region".to_string()); - } + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); - #[tokio::test] - async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() { - let region = "http://custom-region".parse().unwrap(); - let bucket = Bucket::new("my-second-bucket", region, fake_credentials()) + let mut r = ReqwestRequest::new(http_request, &bucket) .unwrap() - .with_path_style(); - let path = "/my-second/path"; - let request = build_request(&bucket, path, Command::GetObject) - .await + .build() .unwrap(); - assert_eq!(request.uri().scheme_str().unwrap(), "http"); - - let headers = request.headers(); - let host = headers.get(HOST).unwrap(); - assert_eq!(*host, "custom-region".to_string()); - } - - #[tokio::test] - async fn 
test_get_object_range_header() { - let region = "http://custom-region".parse().unwrap(); - let bucket = Bucket::new("my-second-bucket", region, fake_credentials()) - .unwrap() - .with_path_style(); - let path = "/my-second/path"; - - let request = build_request( - &bucket, - path, - Command::GetObjectRange { - start: 0, - end: None, - }, - ) - .await - .unwrap(); - let headers = request.headers(); - let range = headers.get(RANGE).unwrap(); - assert_eq!(range, "bytes=0-"); - - let request = build_request( - &bucket, - path, - Command::GetObjectRange { - start: 0, - end: Some(1), - }, - ) - .await - .unwrap(); - let headers = request.headers(); - let range = headers.get(RANGE).unwrap(); - assert_eq!(range, "bytes=0-1"); + assert_eq!(r.method(), http::Method::POST); + assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1"); + assert_eq!(r.headers().get("h1").unwrap(), "v1"); + assert_eq!(r.headers().get("h2").unwrap(), "v2"); + let body = r.body_mut().take().unwrap().collect().await; + assert_eq!(body.unwrap().to_bytes().as_ref(), b"sneaky"); } } From 758216fdec4ef3a84522ac20beddfa0384a77425 Mon Sep 17 00:00:00 2001 From: Kim Vandry Date: Sat, 22 Nov 2025 21:31:57 +0000 Subject: [PATCH 4/9] Deliver responses from all 3 backends in http::Response format. This will make it easier to offer pluggable backends in the future. --- s3/Cargo.toml | 7 +- s3/src/error.rs | 2 + s3/src/request/async_std_backend.rs | 150 +++++++---------------- s3/src/request/blocking.rs | 79 ++---------- s3/src/request/request_trait.rs | 182 +++++++++++++++++++++++++--- s3/src/request/tokio_backend.rs | 90 +------------- 6 files changed, 223 insertions(+), 287 deletions(-) diff --git a/s3/Cargo.toml b/s3/Cargo.toml index bab67b4081..5a9f70bbb1 100644 --- a/s3/Cargo.toml +++ b/s3/Cargo.toml @@ -55,6 +55,8 @@ futures-util = { version = "0.3", optional = true, default-features = false } hex = "0.4" hmac = "0.12" http = "1" +http-body = { version = "1.0.1", optional = true } +http-body-util = { version = "0.1.3", optional = true } log = "0.4" maybe-async = { version = "0.2" } md5 = "0.8" @@ -86,8 +88,8 @@ default = ["fail-on-err", "tags", "tokio-native-tls"] sync = ["attohttpc", "maybe-async/is_sync"] with-async-std-hyper = ["with-async-std", "surf/hyper-client"] -with-async-std = ["async-std", "futures-util", "surf", "sysinfo"] -with-tokio = ["futures-util", "reqwest", "tokio", "tokio/fs", "tokio-stream", "sysinfo"] +with-async-std = ["dep:http-body", "dep:http-body-util", "async-std", "futures-util", "surf", "sysinfo"] +with-tokio = ["dep:http-body", "dep:http-body-util", "futures-util", "reqwest", "tokio", "tokio/fs", "tokio-stream", "sysinfo"] blocking = ["block_on_proc", "tokio/rt", "tokio/rt-multi-thread"] fail-on-err = [] @@ -120,7 +122,6 @@ sync-rustls-tls = ["attohttpc/tls-rustls", "aws-creds/rustls-tls", "sync"] [dev-dependencies] tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "fs"] } async-std = { version = "1", features = ["attributes"] } -http-body-util = "0.1.3" uuid = { version = "1", features = ["v4"] } env_logger = "0.11" anyhow = "1" diff --git a/s3/src/error.rs b/s3/src/error.rs index 8e1c56ffac..5746867154 100644 --- a/s3/src/error.rs +++ b/s3/src/error.rs @@ -44,6 +44,8 @@ pub enum S3Error { #[cfg(feature = "with-async-std")] #[error("surf: {0}")] Surf(String), + #[error("{0}")] + InvalidStatusCode(#[from] http::status::InvalidStatusCode), #[cfg(feature = "sync")] #[error("attohttpc: {0}")] Atto(#[from] attohttpc::Error), diff --git a/s3/src/request/async_std_backend.rs 
b/s3/src/request/async_std_backend.rs index 8f5f0b55eb..6de73aca4e 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -1,17 +1,15 @@ -use async_std::io::Write as AsyncWrite; -use async_std::io::{ReadExt, WriteExt}; -use async_std::stream::StreamExt; use bytes::Bytes; -use futures_util::FutureExt; +use futures_util::AsyncBufRead as _; use std::borrow::Cow; -use std::collections::HashMap; +use std::pin::Pin; +use std::task::{Context, Poll}; use crate::bucket::Bucket; use crate::error::S3Error; -use crate::request::{Request, ResponseData, ResponseDataStream}; +use crate::request::Request; -use http::HeaderMap; +use http_body::Frame; use maybe_async::maybe_async; use surf::http::Method; use surf::http::headers::{HeaderName, HeaderValue}; @@ -53,13 +51,40 @@ impl SurfRequest<'_> { } } +pub struct SurfBody(surf::Body); + +impl http_body::Body for SurfBody { + type Data = Bytes; + type Error = std::io::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, std::io::Error>>> { + let mut inner = Pin::new(&mut self.0); + match inner.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(sliceu8)) => { + if sliceu8.is_empty() { + Poll::Ready(None) + } else { + let len = sliceu8.len(); + let frame = Frame::data(Bytes::copy_from_slice(sliceu8)); + inner.as_mut().consume(len); + Poll::Ready(Some(Ok(frame))) + } + } + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } + } +} + #[maybe_async] impl<'a> Request for SurfRequest<'a> { - type Response = surf::Response; - type HeaderMap = HeaderMap; + type ResponseBody = SurfBody; - async fn response(&self) -> Result { - let response = self + async fn response(&self) -> Result, S3Error> { + let mut response = self .build()? .send() .await @@ -69,105 +94,14 @@ impl<'a> Request for SurfRequest<'a> { return Err(S3Error::HttpFail); } - Ok(response) - } - - async fn response_data(&self, etag: bool) -> Result { - let mut response = crate::retry! {self.response().await}?; - let status_code = response.status(); - - let response_headers = response - .header_names() - .zip(response.header_values()) - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect::>(); - - // When etag=true, we extract the ETag header and return it as the body. - // This is used for PUT operations (regular puts, multipart chunks) where: - // 1. S3 returns an empty or non-useful response body - // 2. The ETag header contains the essential information we need - // 3. The calling code expects to get the ETag via response_data.as_str() - // - // Note: This approach means we discard any actual response body when etag=true, - // but for the operations that use this (PUTs), the body is typically empty - // or contains redundant information already available in headers. - // - // TODO: Refactor this to properly return the response body and access ETag - // from headers instead of replacing the body. This would be a breaking change. - let body_vec = if etag { - if let Some(etag) = response.header("ETag") { - Bytes::from(etag.as_str().to_string()) - } else { - Bytes::from("") + let mut builder = + http::Response::builder().status(http::StatusCode::from_u16(response.status().into())?); + for (name, values) in response.iter() { + for value in values { + builder = builder.header(name.as_str(), value.as_str()); } - } else { - let body = match response.body_bytes().await { - Ok(bytes) => Ok(Bytes::from(bytes)), - Err(e) => Err(S3Error::Surf(e.to_string())), - }; - body? 
- }; - Ok(ResponseData::new( - body_vec, - status_code.into(), - response_headers, - )) - } - - async fn response_data_to_writer( - &self, - writer: &mut T, - ) -> Result { - let mut buffer = Vec::new(); - - let response = crate::retry! {self.response().await}?; - - let status_code = response.status(); - - let mut stream = surf::http::Body::from_reader(response, None); - - stream.read_to_end(&mut buffer).await?; - - writer.write_all(&buffer).await?; - - Ok(status_code.into()) - } - - async fn response_header(&self) -> Result<(HeaderMap, u16), S3Error> { - let mut header_map = HeaderMap::new(); - let response = crate::retry! {self.response().await}?; - let status_code = response.status(); - - for (name, value) in response.iter() { - header_map.insert( - http::header::HeaderName::from_lowercase( - name.to_string().to_ascii_lowercase().as_ref(), - )?, - value.as_str().parse()?, - ); } - Ok((header_map, status_code.into())) - } - - async fn response_data_to_stream(&self) -> Result { - let mut response = crate::retry! {self.response().await}?; - let status_code = response.status(); - - let body = response - .take_body() - .bytes() - .filter_map(|n| n.ok()) - .fold(vec![], |mut b, n| { - b.push(n); - b - }) - .then(|b| async move { Ok(Bytes::from(b)) }) - .into_stream(); - - Ok(ResponseDataStream { - bytes: Box::pin(body), - status_code: status_code.into(), - }) + Ok(builder.body(SurfBody(response.take_body()))?) } } diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index ad4fcbab90..dd4e72b23f 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -1,18 +1,13 @@ extern crate base64; extern crate md5; -use std::io; -use std::io::Write; - use attohttpc::header::HeaderName; use crate::bucket::Bucket; use crate::error::S3Error; -use bytes::Bytes; use std::borrow::Cow; -use std::collections::HashMap; -use crate::request::{Request, ResponseData}; +use crate::request::Request; // Temporary structure for making a request pub struct AttoRequest<'a> { @@ -52,10 +47,9 @@ impl AttoRequest<'_> { } impl<'a> Request for AttoRequest<'a> { - type Response = attohttpc::Response; - type HeaderMap = attohttpc::header::HeaderMap; + type ResponseBody = attohttpc::ResponseReader; - fn response(&self) -> Result { + fn response(&self) -> Result, S3Error> { let response = self.build()?.send()?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { @@ -64,69 +58,12 @@ impl<'a> Request for AttoRequest<'a> { return Err(S3Error::HttpFailWithBody(status, text)); } - Ok(response) - } - - fn response_data(&self, etag: bool) -> Result { - let response = crate::retry! {self.response()}?; - let status_code = response.status().as_u16(); - - let response_headers = response - .headers() - .iter() - .map(|(k, v)| { - ( - k.to_string(), - v.to_str() - .unwrap_or("could-not-decode-header-value") - .to_string(), - ) - }) - .collect::>(); - - // When etag=true, we extract the ETag header and return it as the body. - // This is used for PUT operations (regular puts, multipart chunks) where: - // 1. S3 returns an empty or non-useful response body - // 2. The ETag header contains the essential information we need - // 3. The calling code expects to get the ETag via response_data.as_str() - // - // Note: This approach means we discard any actual response body when etag=true, - // but for the operations that use this (PUTs), the body is typically empty - // or contains redundant information already available in headers. 
- // - // TODO: Refactor this to properly return the response body and access ETag - // from headers instead of replacing the body. This would be a breaking change. - let body_vec = if etag { - if let Some(etag) = response.headers().get("ETag") { - Bytes::from(etag.to_str()?.to_string()) - } else { - Bytes::from("") - } - } else { - // HEAD requests don't have a response body - if *self.request.method() == http::Method::HEAD { - Bytes::from("") - } else { - Bytes::from(response.bytes()?) - } - }; - Ok(ResponseData::new(body_vec, status_code, response_headers)) - } - - fn response_data_to_writer(&self, writer: &mut T) -> Result { - let mut response = crate::retry! {self.response()}?; - - let status_code = response.status(); - io::copy(&mut response, writer)?; - - Ok(status_code.as_u16()) - } + let (status, headers, body) = response.split(); + let mut builder = + http::Response::builder().status(http::StatusCode::from_u16(status.into())?); + *builder.headers_mut().unwrap() = headers; - fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> { - let response = crate::retry! {self.response()}?; - let status_code = response.status().as_u16(); - let headers = response.headers().clone(); - Ok((headers, status_code)) + Ok(builder.body(body)?) } } diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index 1747500102..6578da1bc6 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -2,11 +2,13 @@ use base64::Engine; use base64::engine::general_purpose; use hmac::Mac; use http::Method; +#[cfg(any(feature = "with-tokio", feature = "with-async-std"))] +use http_body_util::BodyExt; use quick_xml::se::to_string; use std::borrow::Cow; use std::collections::HashMap; #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] -use std::pin::Pin; +use std::pin::{Pin, pin}; use time::OffsetDateTime; use time::format_description::well_known::Rfc2822; use url::Url; @@ -15,8 +17,8 @@ use crate::LONG_DATETIME; use crate::bucket::Bucket; use crate::command::{Command, HttpMethod}; use crate::error::S3Error; -use crate::signing; use crate::utils::now_utc; +use crate::{retry, signing}; use bytes::Bytes; use http::HeaderMap; use http::header::{ @@ -25,10 +27,16 @@ use http::header::{ use std::fmt::Write as _; #[cfg(feature = "with-async-std")] -use async_std::stream::Stream; +use async_std::stream::{Stream, StreamExt}; + +#[cfg(feature = "with-tokio")] +use tokio_stream::{Stream, StreamExt}; + +#[cfg(feature = "with-async-std")] +use async_std::io::{Write as AsyncWrite, WriteExt as _}; #[cfg(feature = "with-tokio")] -use tokio_stream::Stream; +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; #[derive(Debug)] @@ -188,31 +196,167 @@ impl async_std::io::Read for ResponseDataStream { } } +mod sealed { + pub struct Sealed; +} + +#[cfg(not(feature = "sync"))] +pub trait ResponseBody: + http_body::Body + AsRef<[u8]> + Send, Error: Into + Send> + + Send + + 'static +{ + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed; + + fn into_bytes(self) -> impl Future> + Send + where + Self: Sized, + { + use futures_util::TryFutureExt as _; + self.collect().map_ok(|c| c.to_bytes()) + } +} + +#[cfg(not(feature = "sync"))] +impl ResponseBody for T +where + T: http_body::Body + Send + 'static, + ::Data: Into + AsRef<[u8]> + Send, + ::Error: Into + Send, +{ + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed { + sealed::Sealed + } +} + +#[cfg(feature = "sync")] +pub trait ResponseBody: std::io::Read { + fn 
unused_trait_should_have_only_blanket_impl() -> sealed::Sealed; + + fn into_bytes(mut self) -> Result + where + Self: Sized, + { + let mut buf = Vec::new(); + self.read_to_end(&mut buf)?; + Ok(buf.into()) + } +} + +#[cfg(feature = "sync")] +impl ResponseBody for T { + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed { + sealed::Sealed + } +} + #[maybe_async::maybe_async] pub trait Request { - type Response; - type HeaderMap; + type ResponseBody: ResponseBody; - async fn response(&self) -> Result; - async fn response_data(&self, etag: bool) -> Result; - #[cfg(feature = "with-tokio")] - async fn response_data_to_writer( - &self, - writer: &mut T, - ) -> Result; - #[cfg(feature = "with-async-std")] - async fn response_data_to_writer( + async fn response(&self) -> Result, S3Error>; + + #[maybe_async::async_impl] + async fn response_with_retry(&self) -> Result, S3Error> { + retry! { self.response().await } + } + + #[maybe_async::sync_impl] + fn response_with_retry(&self) -> Result, S3Error> { + retry! { self.response() } + } + + async fn response_data(&self, etag: bool) -> Result { + let response = self.response_with_retry().await?; + let (mut head, body) = response.into_parts(); + let status_code = head.status.as_u16(); + // When etag=true, we extract the ETag header and return it as the body. + // This is used for PUT operations (regular puts, multipart chunks) where: + // 1. S3 returns an empty or non-useful response body + // 2. The ETag header contains the essential information we need + // 3. The calling code expects to get the ETag via response_data.as_str() + // + // Note: This approach means we discard any actual response body when etag=true, + // but for the operations that use this (PUTs), the body is typically empty + // or contains redundant information already available in headers. + // + // TODO: Refactor this to properly return the response body and access ETag + // from headers instead of replacing the body. This would be a breaking change. + let body_vec = if etag { + if let Some(etag) = head.headers.remove("ETag") { + Bytes::from(etag.to_str()?.to_string()) + } else { + Bytes::from("") + } + } else { + body.into_bytes().await.map_err(Into::::into)? 
+ }; + let response_headers = head + .headers + .into_iter() + .filter_map(|(k, v)| { + k.map(|k| { + ( + k.to_string(), + v.to_str() + .unwrap_or("could-not-decode-header-value") + .to_string(), + ) + }) + }) + .collect::>(); + Ok(ResponseData::new(body_vec, status_code, response_headers)) + } + + #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] + async fn response_data_to_writer( &self, writer: &mut T, - ) -> Result; + ) -> Result { + let response = self.response_with_retry().await?; + + let status_code = response.status(); + let mut stream = pin!(response.into_body().into_data_stream()); + + while let Some(item) = stream.next().await { + writer.write_all(item.map_err(Into::into)?.as_ref()).await?; + } + + Ok(status_code.as_u16()) + } + #[cfg(feature = "sync")] fn response_data_to_writer( &self, writer: &mut T, - ) -> Result; + ) -> Result { + let response = self.response_with_retry().await?; + let status_code = response.status(); + let mut body = response.into_body(); + std::io::copy(&mut body, writer)?; + Ok(status_code.as_u16()) + } + #[cfg(any(feature = "with-async-std", feature = "with-tokio"))] - async fn response_data_to_stream(&self) -> Result; - async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>; + async fn response_data_to_stream(&self) -> Result { + let response = self.response_with_retry().await?; + let status_code = response.status(); + let stream = response.into_body().into_data_stream().map(|i| match i { + Ok(data) => Ok(data.into()), + Err(e) => Err(e.into()), + }); + + Ok(ResponseDataStream { + bytes: Box::pin(stream), + status_code: status_code.as_u16(), + }) + } + + async fn response_header(&self) -> Result<(http::HeaderMap, u16), S3Error> { + let response = self.response_with_retry().await?; + let (head, _) = response.into_parts(); + Ok((head.headers, head.status.as_u16())) + } } struct BuildHelper<'temp, 'body> { diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index 995c238e07..93e5d956ee 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -1,18 +1,12 @@ extern crate base64; extern crate md5; -use bytes::Bytes; -use futures_util::TryStreamExt; use maybe_async::maybe_async; use std::borrow::Cow; -use std::collections::HashMap; -use super::request_trait::{Request, ResponseData, ResponseDataStream}; +use super::request_trait::Request; use crate::bucket::Bucket; use crate::error::S3Error; -use crate::retry; - -use tokio_stream::StreamExt; #[derive(Clone, Debug, Default)] pub(crate) struct ClientOptions { @@ -69,10 +63,9 @@ impl ReqwestRequest<'_> { #[maybe_async] impl<'a> Request for ReqwestRequest<'a> { - type Response = reqwest::Response; - type HeaderMap = reqwest::header::HeaderMap; + type ResponseBody = reqwest::Body; - async fn response(&self) -> Result { + async fn response(&self) -> Result, S3Error> { let response = self.client.execute(self.build()?).await?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { @@ -81,82 +74,7 @@ impl<'a> Request for ReqwestRequest<'a> { return Err(S3Error::HttpFailWithBody(status, text)); } - Ok(response) - } - - async fn response_data(&self, etag: bool) -> Result { - let response = retry! 
{self.response().await}?;
-        let status_code = response.status().as_u16();
-        let mut headers = response.headers().clone();
-        let response_headers = headers
-            .clone()
-            .iter()
-            .map(|(k, v)| {
-                (
-                    k.to_string(),
-                    v.to_str()
-                        .unwrap_or("could-not-decode-header-value")
-                        .to_string(),
-                )
-            })
-            .collect::<HashMap<String, String>>();
-        // When etag=true, we extract the ETag header and return it as the body.
-        // This is used for PUT operations (regular puts, multipart chunks) where:
-        // 1. S3 returns an empty or non-useful response body
-        // 2. The ETag header contains the essential information we need
-        // 3. The calling code expects to get the ETag via response_data.as_str()
-        //
-        // Note: This approach means we discard any actual response body when etag=true,
-        // but for the operations that use this (PUTs), the body is typically empty
-        // or contains redundant information already available in headers.
-        //
-        // TODO: Refactor this to properly return the response body and access ETag
-        // from headers instead of replacing the body. This would be a breaking change.
-        let body_vec = if etag {
-            if let Some(etag) = headers.remove("ETag") {
-                Bytes::from(etag.to_str()?.to_string())
-            } else {
-                Bytes::from("")
-            }
-        } else {
-            response.bytes().await?
-        };
-        Ok(ResponseData::new(body_vec, status_code, response_headers))
-    }
-
-    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin + ?Sized>(
-        &self,
-        writer: &mut T,
-    ) -> Result<u16, S3Error> {
-        use tokio::io::AsyncWriteExt;
-        let response = retry! {self.response().await}?;
-
-        let status_code = response.status();
-        let mut stream = response.bytes_stream();
-
-        while let Some(item) = stream.next().await {
-            writer.write_all(&item?).await?;
-        }
-
-        Ok(status_code.as_u16())
-    }
-
-    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> {
-        let response = retry! {self.response().await}?;
-        let status_code = response.status();
-        let stream = response.bytes_stream().map_err(S3Error::Reqwest);
-
-        Ok(ResponseDataStream {
-            bytes: Box::pin(stream),
-            status_code: status_code.as_u16(),
-        })
-    }
-
-    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> {
-        let response = retry! {self.response().await}?;
-        let status_code = response.status().as_u16();
-        let headers = response.headers().clone();
-        Ok((headers, status_code))
+        Ok(response.into())
     }
 }

From a8765465abee60228202b224b148976bfa3b04bd Mon Sep 17 00:00:00 2001
From: Kim Vandry
Date: Sun, 23 Nov 2025 10:29:10 +0000
Subject: [PATCH 5/9] Move backend-specific options from Bucket to individual backends.

This is a breaking change. What used to be expressed as:

bucket.with_request_timeout(...)
bucket.request_timeout()
bucket.set_dangereous_config(...)
bucket.set_proxy(...)

is now written as:

bucket.with_backend(bucket.backend().with_request_timeout(...)?)
bucket.backend().request_timeout()
bucket.with_backend(bucket.backend().with_dangereous_config(...)?)
bucket.with_backend(bucket.backend().with_proxy(...)?)

Bucket no longer contains any assumptions about backends except that they
accept http::Request and return http::Response.
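
For example, setting a request timeout now composes like this (sketch based on
the updated test_builder_composition test below; error handling abridged):

    use std::time::Duration;

    // Derive a reconfigured backend from the current one, then swap it
    // into a new Bucket handle; the original handle is left unchanged.
    let bucket = bucket.with_backend(
        bucket.backend().with_request_timeout(Some(Duration::from_secs(10)))?,
    );
    assert_eq!(
        bucket.backend().request_timeout(),
        Some(Duration::from_secs(10)),
    );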
--- s3/src/bucket.rs | 212 ++++------------------------ s3/src/request/async_std_backend.rs | 28 ++-- s3/src/request/backend.rs | 14 ++ s3/src/request/blocking.rs | 51 ++++--- s3/src/request/mod.rs | 1 + s3/src/request/tokio_backend.rs | 112 ++++++++++++--- 6 files changed, 176 insertions(+), 242 deletions(-) create mode 100644 s3/src/request/backend.rs diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index 28d658bee7..67c4252545 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -38,7 +38,6 @@ use block_on_proc::block_on; #[cfg(feature = "tags")] use minidom::Element; use std::collections::HashMap; -use std::time::Duration; use crate::bucket_ops::{BucketConfiguration, CreateBucketResponse}; use crate::command::{Command, Multipart}; @@ -46,11 +45,8 @@ use crate::creds::Credentials; use crate::region::Region; #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] use crate::request::ResponseDataStream; -#[cfg(feature = "with-tokio")] -use crate::request::tokio_backend::ClientOptions; -#[cfg(feature = "with-tokio")] -use crate::request::tokio_backend::client; -use crate::request::{Request as _, ResponseData, build_presigned, build_request}; +use crate::request::backend::DefaultBackend; +use crate::request::{Request, ResponseData, build_presigned, build_request}; use std::str::FromStr; use std::sync::Arc; @@ -65,18 +61,11 @@ use std::sync::RwLock; pub type Query = HashMap; -#[cfg(feature = "with-async-std")] -use crate::request::async_std_backend::SurfRequest as RequestImpl; -#[cfg(feature = "with-tokio")] -use crate::request::tokio_backend::ReqwestRequest as RequestImpl; - #[cfg(feature = "with-async-std")] use async_std::io::Write as AsyncWrite; #[cfg(feature = "with-tokio")] use tokio::io::AsyncWrite; -#[cfg(feature = "sync")] -use crate::request::blocking::AttoRequest as RequestImpl; use std::io::Read; #[cfg(feature = "with-tokio")] @@ -102,8 +91,6 @@ use sysinfo::{MemoryRefreshKind, System}; pub const CHUNK_SIZE: usize = 8_388_608; // 8 Mebibytes, min is 5 (5_242_880); -const DEFAULT_REQUEST_TIMEOUT: Option = Some(Duration::from_secs(60)); - #[derive(Debug, PartialEq, Eq)] pub struct Tag { key: String, @@ -141,13 +128,9 @@ pub struct Bucket { credentials: Arc>, pub extra_headers: HeaderMap, pub extra_query: Query, - pub request_timeout: Option, path_style: bool, listobjects_v2: bool, - #[cfg(feature = "with-tokio")] - http_client: reqwest::Client, - #[cfg(feature = "with-tokio")] - client_options: crate::request::tokio_backend::ClientOptions, + backend: DefaultBackend, } impl Bucket { @@ -166,9 +149,8 @@ impl Bucket { } } - #[cfg(feature = "with-tokio")] - pub fn http_client(&self) -> reqwest::Client { - self.http_client.clone() + pub fn backend(&self) -> &DefaultBackend { + &self.backend } #[maybe_async::maybe_async] @@ -176,10 +158,10 @@ impl Bucket { &self, path: &str, command: Command<'a>, - ) -> Result, S3Error> { + ) -> Result { self.credentials_refresh().await?; let http_request = build_request(self, path, command).await?; - RequestImpl::new(http_request, self) + Ok(self.backend.request(http_request)) } #[maybe_async::maybe_async] @@ -644,22 +626,15 @@ impl Bucket { region: Region, credentials: Credentials, ) -> Result, S3Error> { - #[cfg(feature = "with-tokio")] - let options = ClientOptions::default(); - Ok(Box::new(Bucket { name: name.into(), region, credentials: Arc::new(RwLock::new(credentials)), extra_headers: HeaderMap::new(), extra_query: HashMap::new(), - request_timeout: DEFAULT_REQUEST_TIMEOUT, path_style: false, listobjects_v2: true, - #[cfg(feature = 
"with-tokio")] - http_client: client(&options)?, - #[cfg(feature = "with-tokio")] - client_options: options, + backend: DefaultBackend::default(), })) } @@ -675,22 +650,15 @@ impl Bucket { /// let bucket = Bucket::new_public(bucket_name, region).unwrap(); /// ``` pub fn new_public(name: &str, region: Region) -> Result { - #[cfg(feature = "with-tokio")] - let options = ClientOptions::default(); - Ok(Bucket { name: name.into(), region, credentials: Arc::new(RwLock::new(Credentials::anonymous()?)), extra_headers: HeaderMap::new(), extra_query: HashMap::new(), - request_timeout: DEFAULT_REQUEST_TIMEOUT, path_style: false, listobjects_v2: true, - #[cfg(feature = "with-tokio")] - http_client: client(&options)?, - #[cfg(feature = "with-tokio")] - client_options: options, + backend: DefaultBackend::default(), }) } @@ -701,13 +669,9 @@ impl Bucket { credentials: self.credentials.clone(), extra_headers: self.extra_headers.clone(), extra_query: self.extra_query.clone(), - request_timeout: self.request_timeout, path_style: true, listobjects_v2: self.listobjects_v2, - #[cfg(feature = "with-tokio")] - http_client: self.http_client(), - #[cfg(feature = "with-tokio")] - client_options: self.client_options.clone(), + backend: self.backend.clone(), }) } @@ -718,13 +682,9 @@ impl Bucket { credentials: self.credentials.clone(), extra_headers, extra_query: self.extra_query.clone(), - request_timeout: self.request_timeout, path_style: self.path_style, listobjects_v2: self.listobjects_v2, - #[cfg(feature = "with-tokio")] - http_client: self.http_client(), - #[cfg(feature = "with-tokio")] - client_options: self.client_options.clone(), + backend: self.backend.clone(), }) } @@ -738,51 +698,23 @@ impl Bucket { credentials: self.credentials.clone(), extra_headers: self.extra_headers.clone(), extra_query, - request_timeout: self.request_timeout, path_style: self.path_style, listobjects_v2: self.listobjects_v2, - #[cfg(feature = "with-tokio")] - http_client: self.http_client(), - #[cfg(feature = "with-tokio")] - client_options: self.client_options.clone(), + backend: self.backend.clone(), }) } - #[cfg(not(feature = "with-tokio"))] - pub fn with_request_timeout(&self, request_timeout: Duration) -> Result, S3Error> { - Ok(Box::new(Bucket { - name: self.name.clone(), - region: self.region.clone(), - credentials: self.credentials.clone(), - extra_headers: self.extra_headers.clone(), - extra_query: self.extra_query.clone(), - request_timeout: Some(request_timeout), - path_style: self.path_style, - listobjects_v2: self.listobjects_v2, - })) - } - - #[cfg(feature = "with-tokio")] - pub fn with_request_timeout(&self, request_timeout: Duration) -> Result, S3Error> { - let options = ClientOptions { - request_timeout: Some(request_timeout), - ..Default::default() - }; - - Ok(Box::new(Bucket { + pub fn with_backend(&self, backend: DefaultBackend) -> Bucket { + Bucket { name: self.name.clone(), region: self.region.clone(), credentials: self.credentials.clone(), extra_headers: self.extra_headers.clone(), extra_query: self.extra_query.clone(), - request_timeout: Some(request_timeout), path_style: self.path_style, listobjects_v2: self.listobjects_v2, - #[cfg(feature = "with-tokio")] - http_client: client(&options)?, - #[cfg(feature = "with-tokio")] - client_options: options, - })) + backend, + } } pub fn with_listobjects_v1(&self) -> Bucket { @@ -792,91 +724,12 @@ impl Bucket { credentials: self.credentials.clone(), extra_headers: self.extra_headers.clone(), extra_query: self.extra_query.clone(), - request_timeout: 
self.request_timeout, path_style: self.path_style, listobjects_v2: false, - #[cfg(feature = "with-tokio")] - http_client: self.http_client(), - #[cfg(feature = "with-tokio")] - client_options: self.client_options.clone(), + backend: self.backend.clone(), } } - /// Configures a bucket to accept invalid SSL certificates and hostnames. - /// - /// This method is available only when either the `tokio-native-tls` or `tokio-rustls-tls` feature is enabled. - /// - /// # Parameters - /// - /// - `accept_invalid_certs`: A boolean flag that determines whether the client should accept invalid SSL certificates. - /// - `accept_invalid_hostnames`: A boolean flag that determines whether the client should accept invalid hostnames. - /// - /// # Returns - /// - /// Returns a `Result` containing the newly configured `Bucket` instance if successful, or an `S3Error` if an error occurs during client configuration. - /// - /// # Errors - /// - /// This function returns an `S3Error` if the HTTP client configuration fails. - /// - /// # Example - /// - /// ```rust - /// # use s3::bucket::Bucket; - /// # use s3::error::S3Error; - /// # use s3::creds::Credentials; - /// # use s3::Region; - /// # use std::str::FromStr; - /// - /// # fn example() -> Result<(), S3Error> { - /// let bucket = Bucket::new("my-bucket", Region::from_str("us-east-1")?, Credentials::default()?)? - /// .set_dangereous_config(true, true)?; - /// # Ok(()) - /// # } - /// - #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))] - pub fn set_dangereous_config( - &self, - accept_invalid_certs: bool, - accept_invalid_hostnames: bool, - ) -> Result { - let mut options = self.client_options.clone(); - options.accept_invalid_certs = accept_invalid_certs; - options.accept_invalid_hostnames = accept_invalid_hostnames; - - Ok(Bucket { - name: self.name.clone(), - region: self.region.clone(), - credentials: self.credentials.clone(), - extra_headers: self.extra_headers.clone(), - extra_query: self.extra_query.clone(), - request_timeout: self.request_timeout, - path_style: self.path_style, - listobjects_v2: self.listobjects_v2, - http_client: client(&options)?, - client_options: options, - }) - } - - #[cfg(feature = "with-tokio")] - pub fn set_proxy(&self, proxy: reqwest::Proxy) -> Result { - let mut options = self.client_options.clone(); - options.proxy = Some(proxy); - - Ok(Bucket { - name: self.name.clone(), - region: self.region.clone(), - credentials: self.credentials.clone(), - extra_headers: self.extra_headers.clone(), - extra_query: self.extra_query.clone(), - request_timeout: self.request_timeout, - path_style: self.path_style, - listobjects_v2: self.listobjects_v2, - http_client: client(&options)?, - client_options: options, - }) - } - /// Copy file from an S3 path, internally within the same bucket. /// /// # Example: @@ -2856,16 +2709,6 @@ impl Bucket { self.path_style = false; } - /// Configure bucket to apply this request timeout to all HTTP - /// requests, or no (infinity) timeout if `None`. Defaults to - /// 30 seconds. - /// - /// Only the [`attohttpc`] and the [`hyper`] backends obey this option; - /// async code may instead await with a timeout. 
- pub fn set_request_timeout(&mut self, timeout: Option) { - self.request_timeout = timeout; - } - /// Configure bucket to use the older ListObjects API /// /// If your provider doesn't support the ListObjectsV2 interface, set this to @@ -3016,10 +2859,6 @@ impl Bucket { pub fn extra_query_mut(&mut self) -> &mut Query { &mut self.extra_query } - - pub fn request_timeout(&self) -> Option { - self.request_timeout - } } #[cfg(test)] @@ -4036,6 +3875,7 @@ mod test { #[test] #[ignore] + #[cfg(any(feature = "with-tokio", feature = "sync"))] fn test_builder_composition() { use std::time::Duration; @@ -4044,11 +3884,18 @@ mod test { "eu-central-1".parse().unwrap(), test_aws_credentials(), ) - .unwrap() - .with_request_timeout(Duration::from_secs(10)) .unwrap(); + let bucket = bucket.with_backend( + bucket + .backend() + .with_request_timeout(Some(Duration::from_secs(10))) + .unwrap(), + ); - assert_eq!(bucket.request_timeout(), Some(Duration::from_secs(10))); + assert_eq!( + bucket.backend().request_timeout(), + Some(Duration::from_secs(10)) + ); } #[maybe_async::test( @@ -4118,7 +3965,8 @@ mod test { .with_path_style(); // Set dangerous config (allow invalid certs, allow invalid hostnames) - let bucket = bucket.set_dangereous_config(true, true).unwrap(); + let bucket = + bucket.with_backend(bucket.backend().with_dangereous_config(true, true).unwrap()); // Test that exists() works with the dangerous config // This should not panic or fail due to SSL certificate issues diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index 6de73aca4e..cf635b2877 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -4,7 +4,6 @@ use std::borrow::Cow; use std::pin::Pin; use std::task::{Context, Poll}; -use crate::bucket::Bucket; use crate::error::S3Error; use crate::request::Request; @@ -105,28 +104,21 @@ impl<'a> Request for SurfRequest<'a> { } } -impl<'a> SurfRequest<'a> { - pub fn new(request: http::Request>, _: &Bucket) -> Result { - Ok(Self { +impl SurfBackend { + pub(crate) fn request<'a>(&self, request: http::Request>) -> SurfRequest<'a> { + SurfRequest { request, sync: false, - }) + } } } +#[derive(Clone, Debug, Default)] +pub struct SurfBackend {} + #[cfg(test)] mod tests { use super::*; - use crate::Bucket; - use crate::creds::Credentials; - - // Fake keys - otherwise using Credentials::default will use actual user - // credentials if they exist. 
-    fn fake_credentials() -> Credentials {
-        let access_key = "AKIAIOSFODNN7EXAMPLE";
-        let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
-        Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap()
-    }

     #[async_std::test]
     async fn test_build() {
         let http_request = http::Request::builder()
             .method("POST")
             .uri("https://example.com/foo?bar=1")
             .header("h1", "v1")
             .header("h2", "v2")
             .body(b"sneaky".into())
             .unwrap();
-        let region = "custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();

-        let mut r = SurfRequest::new(http_request, &bucket)
-            .unwrap()
+        let mut r = SurfBackend::default()
+            .request(http_request)
             .build()
             .unwrap()
             .build();
diff --git a/s3/src/request/backend.rs b/s3/src/request/backend.rs
new file mode 100644
index 0000000000..4538c64120
--- /dev/null
+++ b/s3/src/request/backend.rs
@@ -0,0 +1,14 @@
+use std::time::Duration;
+
+#[cfg(feature = "with-async-std")]
+pub(crate) use crate::request::async_std_backend::SurfBackend as DefaultBackend;
+#[cfg(feature = "sync")]
+pub(crate) use crate::request::blocking::AttoBackend as DefaultBackend;
+#[cfg(feature = "with-tokio")]
+pub(crate) use crate::request::tokio_backend::ReqwestBackend as DefaultBackend;
+
+/// Default request timeout. Override with DefaultBackend::with_request_timeout.
+///
+/// For backward compatibility, only AttoBackend uses this. ReqwestBackend
+/// supports a timeout but none is set by default.
+pub const DEFAULT_REQUEST_TIMEOUT: Option = Some(Duration::from_secs(60));
diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs
index dd4e72b23f..5c68314e06 100644
--- a/s3/src/request/blocking.rs
+++ b/s3/src/request/blocking.rs
@@ -2,8 +2,8 @@ extern crate base64;
 extern crate md5;

 use attohttpc::header::HeaderName;
+use std::time::Duration;

-use crate::bucket::Bucket;
 use crate::error::S3Error;

 use std::borrow::Cow;
@@ -12,7 +12,7 @@ use crate::request::Request;
 // Temporary structure for making a request
 pub struct AttoRequest<'a> {
     request: http::Request>,
-    request_timeout: Option,
+    request_timeout: Option,
     pub sync: bool,
 }

@@ -67,29 +67,42 @@ impl<'a> Request for AttoRequest<'a> {
     }
 }

-impl<'a> AttoRequest<'a> {
-    pub fn new(request: http::Request>, bucket: &Bucket) -> Result {
-        Ok(Self {
+impl AttoBackend {
+    pub(crate) fn request<'a>(&self, request: http::Request>) -> AttoRequest<'a> {
+        AttoRequest {
             request,
-            request_timeout: bucket.request_timeout,
+            request_timeout: self.request_timeout,
             sync: false,
-        })
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct AttoBackend {
+    request_timeout: Option,
+}
+
+impl Default for AttoBackend {
+    fn default() -> Self {
+        Self {
+            request_timeout: crate::request::backend::DEFAULT_REQUEST_TIMEOUT,
+        }
+    }
+}
+
+impl AttoBackend {
+    pub fn with_request_timeout(&self, request_timeout: Option) -> Result {
+        Ok(Self { request_timeout })
+    }
+
+    pub fn request_timeout(&self) -> Option {
+        self.request_timeout
+    }
 }

 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::bucket::Bucket;
-    use awscreds::Credentials;
-
-    // Fake keys - otherwise using Credentials::default will use actual user
-    // credentials if they exist.
-    fn fake_credentials() -> Credentials {
-        let access_key = "AKIAIOSFODNN7EXAMPLE";
-        let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
-        Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap()
-    }

     #[test]
     fn test_build() {
@@ -100,10 +113,8 @@
             .body(b"sneaky".into())
             .unwrap();
-        let region = "custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();

-        let req = AttoRequest::new(http_request, &bucket).unwrap();
+        let req = AttoBackend::default().request(http_request);
         let mut r = req.build().unwrap();

         assert_eq!(r.inspect().method(), http::Method::POST);
diff --git a/s3/src/request/mod.rs b/s3/src/request/mod.rs
index 45b6e59c76..c11536af50 100644
--- a/s3/src/request/mod.rs
+++ b/s3/src/request/mod.rs
@@ -1,5 +1,6 @@
 #[cfg(feature = "with-async-std")]
 pub mod async_std_backend;
+pub mod backend;
 #[cfg(feature = "sync")]
 pub mod blocking;
 pub mod request_trait;
diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs
index 93e5d956ee..4df24af16f 100644
--- a/s3/src/request/tokio_backend.rs
+++ b/s3/src/request/tokio_backend.rs
@@ -3,9 +3,9 @@ extern crate md5;

 use maybe_async::maybe_async;
 use std::borrow::Cow;
+use std::time::Duration;

 use super::request_trait::Request;
-use crate::bucket::Bucket;
 use crate::error::S3Error;

 #[derive(Clone, Debug, Default)]
@@ -18,8 +18,7 @@ pub(crate) struct ClientOptions {
     pub accept_invalid_hostnames: bool,
 }

-#[cfg(feature = "with-tokio")]
-pub(crate) fn client(options: &ClientOptions) -> Result {
+fn client(options: &ClientOptions) -> Result {
     let client = reqwest::Client::builder();

     let client = if let Some(timeout) = options.request_timeout {
@@ -78,12 +77,95 @@ impl<'a> Request for ReqwestRequest<'a> {
     }
 }

-impl<'a> ReqwestRequest<'a> {
-    pub fn new(request: http::Request>, bucket: &Bucket) -> Result {
-        Ok(Self {
+impl ReqwestBackend {
+    pub(crate) fn request<'a>(&self, request: http::Request>) -> ReqwestRequest<'a> {
+        ReqwestRequest {
             request,
-            client: bucket.http_client(),
+            client: self.http_client.clone(),
             sync: false,
+        }
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct ReqwestBackend {
+    http_client: reqwest::Client,
+    client_options: ClientOptions,
+}
+
+impl ReqwestBackend {
+    pub fn with_request_timeout(&self, request_timeout: Option) -> Result {
+        let client_options = ClientOptions {
+            request_timeout,
+            ..self.client_options.clone()
+        };
+        Ok(Self {
+            http_client: client(&client_options)?,
+            client_options,
+        })
+    }
+
+    pub fn request_timeout(&self) -> Option {
+        self.client_options.request_timeout
+    }
+
+    /// Configures the backend to accept invalid SSL certificates and hostnames.
+    ///
+    /// This method is available only when either the `tokio-native-tls` or `tokio-rustls-tls` feature is enabled.
+    ///
+    /// # Parameters
+    ///
+    /// - `accept_invalid_certs`: A boolean flag that determines whether the client should accept invalid SSL certificates.
+    /// - `accept_invalid_hostnames`: A boolean flag that determines whether the client should accept invalid hostnames.
+    ///
+    /// # Returns
+    ///
+    /// Returns a `Result` containing the newly configured backend if successful, or an `S3Error` if an error occurs during client configuration.
+    ///
+    /// # Errors
+    ///
+    /// This function returns an `S3Error` if the HTTP client configuration fails.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use s3::bucket::Bucket;
+    /// # use s3::error::S3Error;
+    /// # use s3::creds::Credentials;
+    /// # use s3::Region;
+    /// # use std::str::FromStr;
+    ///
+    /// # fn example() -> Result<(), S3Error> {
+    /// let bucket = Bucket::new("my-bucket", Region::from_str("us-east-1")?, Credentials::default()?)?;
+    /// let bucket = bucket.with_backend(bucket.backend().with_dangereous_config(true, true)?);
+    /// # Ok(())
+    /// # }
+    /// ```
+    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
+    pub fn with_dangereous_config(
+        &self,
+        accept_invalid_certs: bool,
+        accept_invalid_hostnames: bool,
+    ) -> Result {
+        let client_options = ClientOptions {
+            accept_invalid_certs,
+            accept_invalid_hostnames,
+            ..self.client_options.clone()
+        };
+        Ok(Self {
+            http_client: client(&client_options)?,
+            client_options,
+        })
+    }
+
+    pub fn with_proxy(&self, proxy: reqwest::Proxy) -> Result {
+        let client_options = ClientOptions {
+            proxy: Some(proxy),
+            ..self.client_options.clone()
+        };
+        Ok(Self {
+            http_client: client(&client_options)?,
+            client_options,
         })
     }
 }
@@ -91,18 +173,8 @@ impl<'a> ReqwestRequest<'a> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::Bucket;
-    use crate::creds::Credentials;
     use http_body_util::BodyExt;

-    // Fake keys - otherwise using Credentials::default will use actual user
-    // credentials if they exist.
-    fn fake_credentials() -> Credentials {
-        let access_key = "AKIAIOSFODNN7EXAMPLE";
-        let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
-        Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap()
-    }
-
     #[tokio::test]
     async fn test_build() {
         let http_request = http::Request::builder()
@@ -112,11 +184,9 @@ mod tests {
             .header("h2", "v2")
             .body(b"sneaky".into())
             .unwrap();
-        let region = "custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();

-        let mut r = ReqwestRequest::new(http_request, &bucket)
-            .unwrap()
+        let mut r = ReqwestBackend::default()
+            .request(http_request)
+            .build()
             .unwrap();

From dbc7b5a3ece23474389a928cfed4dee6dcfa9963 Mon Sep 17 00:00:00 2001
From: Kim Vandry
Date: Sun, 23 Nov 2025 12:38:39 +0000
Subject: [PATCH 6/9] Refactor to aid upcoming change to HTTP request backend
 API (bis).

This hides the details of s3::request::Request inside Bucket::make_request()
in preparation for replacing the backend API with one better suited to
pluggable backends.
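To illustrate the resulting calling convention inside Bucket methods (a
sketch only; GetObject stands in for any command, and response_data is the
free helper this patch adds to request_trait.rs):

    let response = self.make_request(path.as_ref(), Command::GetObject).await?;
    let data = response_data(response, false).await?;
    let status = data.status_code();

make_request() now refreshes credentials, builds and executes the HTTP
request with retries, and hands back a plain http::Response; callers
post-process it with free functions instead of driving a backend-specific
Request value.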
--- s3/src/bucket.rs | 234 ++++++++++++++++++-------------- s3/src/put_object_request.rs | 6 +- s3/src/request/request_trait.rs | 167 ++++++++++------------- 3 files changed, 205 insertions(+), 202 deletions(-) diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index 67c4252545..74d6e98913 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -46,7 +46,11 @@ use crate::region::Region; #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] use crate::request::ResponseDataStream; use crate::request::backend::DefaultBackend; -use crate::request::{Request, ResponseData, build_presigned, build_request}; +use crate::request::{ + Request, ResponseBody, ResponseData, build_presigned, build_request, response_data, + response_data_to_writer, +}; +use crate::retry; use std::str::FromStr; use std::sync::Arc; @@ -62,9 +66,11 @@ use std::sync::RwLock; pub type Query = HashMap; #[cfg(feature = "with-async-std")] -use async_std::io::Write as AsyncWrite; +use async_std::{io::Write as AsyncWrite, stream::StreamExt as _}; #[cfg(feature = "with-tokio")] use tokio::io::AsyncWrite; +#[cfg(feature = "with-tokio")] +use tokio_stream::StreamExt as _; use std::io::Read; @@ -133,6 +139,18 @@ pub struct Bucket { backend: DefaultBackend, } +#[maybe_async::async_impl] +async fn exec_request( + request: R, +) -> Result, S3Error> { + retry! { request.response().await } +} + +#[maybe_async::sync_impl] +fn exec_request(request: R) -> Result, S3Error> { + retry! { request.response() } +} + impl Bucket { #[maybe_async::async_impl] /// Credential refreshing is done automatically, but can be manually triggered. @@ -154,14 +172,14 @@ impl Bucket { } #[maybe_async::maybe_async] - pub(crate) async fn make_request<'a>( + pub(crate) async fn make_request( &self, path: &str, - command: Command<'a>, - ) -> Result { + command: Command<'_>, + ) -> Result, S3Error> { self.credentials_refresh().await?; let http_request = build_request(self, path, command).await?; - Ok(self.backend.request(http_request)) + exec_request(self.backend.request(http_request)).await } #[maybe_async::maybe_async] @@ -396,13 +414,13 @@ impl Bucket { let command = Command::CreateBucket { config }; let bucket = Bucket::new(name, region, credentials)?; - let request = bucket.make_request("", command).await?; - let response_data = request.response_data(false).await?; - let response_text = response_data.as_str()?; + let response = bucket.make_request("", command).await?; + let data = response_data(response, false).await?; + let response_text = data.as_str()?; Ok(CreateBucketResponse { bucket, response_text: response_text.to_string(), - response_code: response_data.status_code(), + response_code: data.status_code(), }) } @@ -452,12 +470,12 @@ impl Bucket { /// Used by the public `list_buckets` method to retrieve the list of buckets for the configured client. #[maybe_async::maybe_async] async fn _list_buckets(&self) -> Result { - let request = self.make_request("", Command::ListBuckets).await?; - let response = request.response_data(false).await?; + let response = self.make_request("", Command::ListBuckets).await?; + let data = response_data(response, false).await?; Ok(quick_xml::de::from_str::< crate::bucket_ops::ListBucketsResponse, - >(response.as_str()?)?) + >(data.as_str()?)?) } /// Determine whether the instantiated bucket exists. 
@@ -558,8 +576,8 @@ impl Bucket { let command = Command::CreateBucket { config }; let bucket = Bucket::new(name, region, credentials)?.with_path_style(); - let request = bucket.make_request("", command).await?; - let response_data = request.response_data(false).await?; + let response = bucket.make_request("", command).await?; + let response_data = response_data(response, false).await?; let response_text = response_data.to_string()?; Ok(CreateBucketResponse { @@ -602,8 +620,8 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn delete(&self) -> Result { let command = Command::DeleteBucket; - let request = self.make_request("", command).await?; - let response_data = request.response_data(false).await?; + let response = self.make_request("", command).await?; + let response_data = response_data(response, false).await?; Ok(response_data.status_code()) } @@ -780,8 +798,8 @@ impl Bucket { let command = Command::CopyObject { from: from.as_ref(), }; - let request = self.make_request(to.as_ref(), command).await?; - let response_data = request.response_data(false).await?; + let response = self.make_request(to.as_ref(), command).await?; + let response_data = response_data(response, false).await?; Ok(response_data.status_code()) } @@ -819,8 +837,8 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn get_object>(&self, path: S) -> Result { let command = Command::GetObject; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(false).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, false).await } #[maybe_async::maybe_async] @@ -834,12 +852,12 @@ impl Bucket { expected_bucket_owner: expected_bucket_owner.to_string(), version_id, }; - let request = self.make_request(path.as_ref(), command).await?; + let response = self.make_request(path.as_ref(), command).await?; - let response = request.response_data(false).await?; + let response_data = response_data(response, false).await?; Ok(quick_xml::de::from_str::( - response.as_str()?, + response_data.as_str()?, )?) } @@ -888,8 +906,8 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn object_exists>(&self, path: S) -> Result { let command = Command::HeadObject; - let request = self.make_request(path.as_ref(), command).await?; - let response_data = match request.response_data(false).await { + let response = self.make_request(path.as_ref(), command).await?; + let response_data = match response_data(response, false).await { Ok(response_data) => response_data, Err(S3Error::HttpFailWithBody(status_code, error)) => { if status_code == 404 { @@ -912,8 +930,8 @@ impl Bucket { expected_bucket_owner: expected_bucket_owner.to_string(), configuration: cors_config.clone(), }; - let request = self.make_request("", command).await?; - request.response_data(false).await + let response = self.make_request("", command).await?; + response_data(response, false).await } #[maybe_async::maybe_async] @@ -924,10 +942,10 @@ impl Bucket { let command = Command::GetBucketCors { expected_bucket_owner: expected_bucket_owner.to_string(), }; - let request = self.make_request("", command).await?; - let response = request.response_data(false).await?; + let response = self.make_request("", command).await?; + let response_data = response_data(response, false).await?; Ok(quick_xml::de::from_str::( - response.as_str()?, + response_data.as_str()?, )?) 
} @@ -939,16 +957,16 @@ impl Bucket { let command = Command::DeleteBucketCors { expected_bucket_owner: expected_bucket_owner.to_string(), }; - let request = self.make_request("", command).await?; - request.response_data(false).await + let response = self.make_request("", command).await?; + response_data(response, false).await } #[maybe_async::maybe_async] pub async fn get_bucket_lifecycle(&self) -> Result { - let request = self.make_request("", Command::GetBucketLifecycle).await?; - let response = request.response_data(false).await?; + let response = self.make_request("", Command::GetBucketLifecycle).await?; + let response_data = response_data(response, false).await?; Ok(quick_xml::de::from_str::( - response.as_str()?, + response_data.as_str()?, )?) } @@ -960,16 +978,16 @@ impl Bucket { let command = Command::PutBucketLifecycle { configuration: lifecycle_config, }; - let request = self.make_request("", command).await?; - request.response_data(false).await + let response = self.make_request("", command).await?; + response_data(response, false).await } #[maybe_async::maybe_async] pub async fn delete_bucket_lifecycle(&self) -> Result { - let request = self + let response = self .make_request("", Command::DeleteBucketLifecycle) .await?; - request.response_data(false).await + response_data(response, false).await } /// Gets torrent from an S3 path. @@ -1009,8 +1027,8 @@ impl Bucket { path: S, ) -> Result { let command = Command::GetObjectTorrent; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(false).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, false).await } /// Gets specified inclusive byte range of file from an S3 path. @@ -1057,8 +1075,8 @@ impl Bucket { } let command = Command::GetObjectRange { start, end }; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(false).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, false).await } /// Stream range of bytes from S3 path to a local file, generic over T: Write. @@ -1118,8 +1136,8 @@ impl Bucket { } let command = Command::GetObjectRange { start, end }; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data_to_writer(writer).await + let response = self.make_request(path.as_ref(), command).await?; + response_data_to_writer(response, writer).await } #[maybe_async::sync_impl] @@ -1135,8 +1153,8 @@ impl Bucket { } let command = Command::GetObjectRange { start, end }; - let request = self.make_request(path.as_ref(), command)?; - request.response_data_to_writer(writer) + let response = self.make_request(path.as_ref(), command)?; + response_data_to_writer(response, writer) } /// Stream file from S3 path to a local file, generic over T: Write. 
@@ -1183,8 +1201,8 @@ impl Bucket { writer: &mut T, ) -> Result { let command = Command::GetObject; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data_to_writer(writer).await + let response = self.make_request(path.as_ref(), command).await?; + response_data_to_writer(response, writer).await } #[maybe_async::sync_impl] @@ -1194,8 +1212,8 @@ impl Bucket { writer: &mut T, ) -> Result { let command = Command::GetObject; - let request = self.make_request(path.as_ref(), command)?; - request.response_data_to_writer(writer) + let response = self.make_request(path.as_ref(), command)?; + response_data_to_writer(response, writer) } /// Stream file from S3 path to a local file using an async stream. @@ -1244,9 +1262,19 @@ impl Bucket { &self, path: S, ) -> Result { + use http_body_util::BodyExt; let command = Command::GetObject; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data_to_stream().await + let response = self.make_request(path.as_ref(), command).await?; + let status_code = response.status(); + let stream = response.into_body().into_data_stream().map(|i| match i { + Ok(data) => Ok(data.into()), + Err(e) => Err(e.into()), + }); + + Ok(ResponseDataStream { + bytes: Box::pin(stream), + status_code: status_code.as_u16(), + }) } /// Stream file from local path to s3, generic over T: Write. @@ -1448,8 +1476,8 @@ impl Bucket { custom_headers: None, content_type, }; - let request = self.make_request(path, command).await?; - request.response_data(true).await + let response = self.make_request(path, command).await?; + response_data(response, true).await } #[maybe_async::async_impl] @@ -1541,7 +1569,7 @@ impl Bucket { // Use FuturesUnordered for bounded parallelism use futures_util::FutureExt; - use futures_util::stream::{FuturesUnordered, StreamExt}; + use futures_util::stream::FuturesUnordered; let mut part_number: u32 = 0; let mut total_size = 0; @@ -1728,14 +1756,13 @@ impl Bucket { content_type: &str, ) -> Result { let command = Command::InitiateMultipartUpload { content_type }; - let request = self.make_request(s3_path, command).await?; - let response_data = request.response_data(false).await?; - if response_data.status_code() >= 300 { - return Err(error_from_response_data(response_data)?); + let response = self.make_request(s3_path, command).await?; + let data = response_data(response, false).await?; + if data.status_code() >= 300 { + return Err(error_from_response_data(data)?); } - let msg: InitiateMultipartUploadResponse = - quick_xml::de::from_str(response_data.as_str()?)?; + let msg: InitiateMultipartUploadResponse = quick_xml::de::from_str(data.as_str()?)?; Ok(msg) } @@ -1746,14 +1773,13 @@ impl Bucket { content_type: &str, ) -> Result { let command = Command::InitiateMultipartUpload { content_type }; - let request = self.make_request(s3_path, command)?; - let response_data = request.response_data(false)?; - if response_data.status_code() >= 300 { - return Err(error_from_response_data(response_data)?); + let response = self.make_request(s3_path, command)?; + let data = response_data(response, false)?; + if data.status_code() >= 300 { + return Err(error_from_response_data(data)?); } - let msg: InitiateMultipartUploadResponse = - quick_xml::de::from_str(response_data.as_str()?)?; + let msg: InitiateMultipartUploadResponse = quick_xml::de::from_str(data.as_str()?)?; Ok(msg) } @@ -1802,8 +1828,8 @@ impl Bucket { custom_headers: None, content_type, }; - let request = self.make_request(path, command).await?; - let 
response_data = request.response_data(true).await?; + let response = self.make_request(path, command).await?; + let response_data = response_data(response, true).await?; if !(200..300).contains(&response_data.status_code()) { // if chunk upload failed - abort the upload match self.abort_upload(path, upload_id).await { @@ -1838,20 +1864,20 @@ impl Bucket { custom_headers: None, content_type, }; - let request = self.make_request(path, command)?; - let response_data = request.response_data(true)?; - if !(200..300).contains(&response_data.status_code()) { + let response = self.make_request(path, command)?; + let data = response_data(response, true)?; + if !(200..300).contains(&data.status_code()) { // if chunk upload failed - abort the upload match self.abort_upload(path, upload_id) { Ok(_) => { - return Err(error_from_response_data(response_data)?); + return Err(error_from_response_data(data)?); } Err(error) => { return Err(error); } } } - let etag = response_data.as_str()?; + let etag = data.as_str()?; Ok(Part { etag: etag.to_string(), part_number, @@ -1868,8 +1894,8 @@ impl Bucket { ) -> Result { let data = CompleteMultipartUploadData { parts }; let complete = Command::CompleteMultipartUpload { upload_id, data }; - let complete_request = self.make_request(path, complete).await?; - complete_request.response_data(false).await + let complete_response = self.make_request(path, complete).await?; + response_data(complete_response, false).await } #[maybe_async::sync_impl] @@ -1881,8 +1907,8 @@ impl Bucket { ) -> Result { let data = CompleteMultipartUploadData { parts }; let complete = Command::CompleteMultipartUpload { upload_id, data }; - let complete_request = self.make_request(path, complete)?; - complete_request.response_data(false) + let complete_response = self.make_request(path, complete)?; + response_data(complete_response, false) } /// Get Bucket location. @@ -1919,10 +1945,10 @@ impl Bucket { /// ``` #[maybe_async::maybe_async] pub async fn location(&self) -> Result<(Region, u16), S3Error> { - let request = self + let response = self .make_request("?location", Command::GetBucketLocation) .await?; - let response_data = request.response_data(false).await?; + let response_data = response_data(response, false).await?; let region_string = String::from_utf8_lossy(response_data.as_slice()); let region = match quick_xml::de::from_reader(region_string.as_bytes()) { Ok(r) => { @@ -1981,8 +2007,8 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn delete_object>(&self, path: S) -> Result { let command = Command::DeleteObject; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(false).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, false).await } /// Head object from S3. @@ -2023,10 +2049,10 @@ impl Bucket { path: S, ) -> Result<(HeadObjectResult, u16), S3Error> { let command = Command::HeadObject; - let request = self.make_request(path.as_ref(), command).await?; - let (headers, status) = request.response_header().await?; - let header_object = HeadObjectResult::from(&headers); - Ok((header_object, status)) + let response = self.make_request(path.as_ref(), command).await?; + let (head, _) = response.into_parts(); + let header_object = HeadObjectResult::from(&head.headers); + Ok((header_object, head.status.as_u16())) } /// Put into an S3 bucket, with explicit content-type. 
@@ -2075,8 +2101,8 @@ impl Bucket { custom_headers: None, multipart: None, }; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(true).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, true).await } /// Put into an S3 bucket, with explicit content-type and custom headers for the request. @@ -2135,8 +2161,8 @@ impl Bucket { custom_headers, multipart: None, }; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(true).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, true).await } /// Put into an S3 bucket, with custom headers for the request. @@ -2329,8 +2355,8 @@ impl Bucket { ) -> Result { let content = self._tags_xml(tags); let command = Command::PutObjectTagging { tags: &content }; - let request = self.make_request(path, command).await?; - request.response_data(false).await + let response = self.make_request(path, command).await?; + response_data(response, false).await } /// Delete tags from an S3 object. @@ -2371,8 +2397,8 @@ impl Bucket { path: S, ) -> Result { let command = Command::DeleteObjectTagging; - let request = self.make_request(path.as_ref(), command).await?; - request.response_data(false).await + let response = self.make_request(path.as_ref(), command).await?; + response_data(response, false).await } /// Retrieve an S3 object list of tags. @@ -2414,8 +2440,8 @@ impl Bucket { path: S, ) -> Result<(Vec, u16), S3Error> { let command = Command::GetObjectTagging {}; - let request = self.make_request(path.as_ref(), command).await?; - let result = request.response_data(false).await?; + let response = self.make_request(path.as_ref(), command).await?; + let result = response_data(response, false).await?; let mut tags = Vec::new(); @@ -2487,8 +2513,8 @@ impl Bucket { max_keys, } }; - let request = self.make_request("/", command).await?; - let response_data = request.response_data(false).await?; + let response = self.make_request("/", command).await?; + let response_data = response_data(response, false).await?; let list_bucket_result = quick_xml::de::from_reader(response_data.as_slice())?; Ok((list_bucket_result, response_data.status_code())) @@ -2571,8 +2597,8 @@ impl Bucket { key_marker, max_uploads, }; - let request = self.make_request("/", command).await?; - let response_data = request.response_data(false).await?; + let response = self.make_request("/", command).await?; + let response_data = response_data(response, false).await?; let list_bucket_result = quick_xml::de::from_reader(response_data.as_slice())?; Ok((list_bucket_result, response_data.status_code())) @@ -2675,8 +2701,8 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn abort_upload(&self, key: &str, upload_id: &str) -> Result<(), S3Error> { let abort = Command::AbortMultipartUpload { upload_id }; - let abort_request = self.make_request(key, abort).await?; - let response_data = abort_request.response_data(false).await?; + let abort_response = self.make_request(key, abort).await?; + let response_data = response_data(abort_response, false).await?; if (200..300).contains(&response_data.status_code()) { Ok(()) diff --git a/s3/src/put_object_request.rs b/s3/src/put_object_request.rs index 014b90f808..353677bac5 100644 --- a/s3/src/put_object_request.rs +++ b/s3/src/put_object_request.rs @@ -4,7 +4,7 @@ //! various options including custom headers, content type, and other metadata. 
use crate::error::S3Error; -use crate::request::{Request as _, ResponseData}; +use crate::request::{ResponseData, response_data}; use crate::{Bucket, command::Command}; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -188,8 +188,8 @@ impl<'a> PutObjectRequest<'a> { multipart: None, }; - let request = self.bucket.make_request(&self.path, command).await?; - request.response_data(true).await + let response = self.bucket.make_request(&self.path, command).await?; + response_data(response, true).await } } diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index 6578da1bc6..a6af5772e5 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -17,8 +17,8 @@ use crate::LONG_DATETIME; use crate::bucket::Bucket; use crate::command::{Command, HttpMethod}; use crate::error::S3Error; +use crate::signing; use crate::utils::now_utc; -use crate::{retry, signing}; use bytes::Bytes; use http::HeaderMap; use http::header::{ @@ -255,108 +255,85 @@ pub trait Request { type ResponseBody: ResponseBody; async fn response(&self) -> Result, S3Error>; +} - #[maybe_async::async_impl] - async fn response_with_retry(&self) -> Result, S3Error> { - retry! { self.response().await } - } - - #[maybe_async::sync_impl] - fn response_with_retry(&self) -> Result, S3Error> { - retry! { self.response() } - } - - async fn response_data(&self, etag: bool) -> Result { - let response = self.response_with_retry().await?; - let (mut head, body) = response.into_parts(); - let status_code = head.status.as_u16(); - // When etag=true, we extract the ETag header and return it as the body. - // This is used for PUT operations (regular puts, multipart chunks) where: - // 1. S3 returns an empty or non-useful response body - // 2. The ETag header contains the essential information we need - // 3. The calling code expects to get the ETag via response_data.as_str() - // - // Note: This approach means we discard any actual response body when etag=true, - // but for the operations that use this (PUTs), the body is typically empty - // or contains redundant information already available in headers. - // - // TODO: Refactor this to properly return the response body and access ETag - // from headers instead of replacing the body. This would be a breaking change. - let body_vec = if etag { - if let Some(etag) = head.headers.remove("ETag") { - Bytes::from(etag.to_str()?.to_string()) - } else { - Bytes::from("") - } +#[maybe_async::maybe_async] +pub(crate) async fn response_data( + response: http::Response, + etag: bool, +) -> Result { + let (mut head, body) = response.into_parts(); + let status_code = head.status.as_u16(); + // When etag=true, we extract the ETag header and return it as the body. + // This is used for PUT operations (regular puts, multipart chunks) where: + // 1. S3 returns an empty or non-useful response body + // 2. The ETag header contains the essential information we need + // 3. The calling code expects to get the ETag via response_data.as_str() + // + // Note: This approach means we discard any actual response body when etag=true, + // but for the operations that use this (PUTs), the body is typically empty + // or contains redundant information already available in headers. + // + // TODO: Refactor this to properly return the response body and access ETag + // from headers instead of replacing the body. This would be a breaking change. 
+ let body_vec = if etag { + if let Some(etag) = head.headers.remove("ETag") { + Bytes::from(etag.to_str()?.to_string()) } else { - body.into_bytes().await.map_err(Into::::into)? - }; - let response_headers = head - .headers - .into_iter() - .filter_map(|(k, v)| { - k.map(|k| { - ( - k.to_string(), - v.to_str() - .unwrap_or("could-not-decode-header-value") - .to_string(), - ) - }) - }) - .collect::>(); - Ok(ResponseData::new(body_vec, status_code, response_headers)) - } - - #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] - async fn response_data_to_writer( - &self, - writer: &mut T, - ) -> Result { - let response = self.response_with_retry().await?; - - let status_code = response.status(); - let mut stream = pin!(response.into_body().into_data_stream()); - - while let Some(item) = stream.next().await { - writer.write_all(item.map_err(Into::into)?.as_ref()).await?; + Bytes::from("") } + } else { + body.into_bytes().await.map_err(Into::::into)? + }; + let response_headers = head + .headers + .into_iter() + .filter_map(|(k, v)| { + k.map(|k| { + ( + k.to_string(), + v.to_str() + .unwrap_or("could-not-decode-header-value") + .to_string(), + ) + }) + }) + .collect::>(); + Ok(ResponseData::new(body_vec, status_code, response_headers)) +} - Ok(status_code.as_u16()) - } +#[cfg(any(feature = "with-tokio", feature = "with-async-std"))] +pub(crate) async fn response_data_to_writer( + response: http::Response, + writer: &mut T, +) -> Result +where + R: ResponseBody, + T: AsyncWrite + Send + Unpin + ?Sized, +{ + let status_code = response.status(); + let mut stream = pin!(response.into_body().into_data_stream()); - #[cfg(feature = "sync")] - fn response_data_to_writer( - &self, - writer: &mut T, - ) -> Result { - let response = self.response_with_retry().await?; - let status_code = response.status(); - let mut body = response.into_body(); - std::io::copy(&mut body, writer)?; - Ok(status_code.as_u16()) + while let Some(item) = stream.next().await { + writer.write_all(item.map_err(Into::into)?.as_ref()).await?; } - #[cfg(any(feature = "with-async-std", feature = "with-tokio"))] - async fn response_data_to_stream(&self) -> Result { - let response = self.response_with_retry().await?; - let status_code = response.status(); - let stream = response.into_body().into_data_stream().map(|i| match i { - Ok(data) => Ok(data.into()), - Err(e) => Err(e.into()), - }); - - Ok(ResponseDataStream { - bytes: Box::pin(stream), - status_code: status_code.as_u16(), - }) - } + Ok(status_code.as_u16()) +} - async fn response_header(&self) -> Result<(http::HeaderMap, u16), S3Error> { - let response = self.response_with_retry().await?; - let (head, _) = response.into_parts(); - Ok((head.headers, head.status.as_u16())) - } +#[cfg(feature = "sync")] +pub(crate) fn response_data_to_writer( + response: http::Response, + writer: &mut T, +) -> Result +where + R: ResponseBody, + T: std::io::Write + Send + ?Sized, +{ + let status_code = response.status(); + let mut body = response.into_body(); + std::io::copy(&mut body, writer)?; + Ok(status_code.as_u16()) } struct BuildHelper<'temp, 'body> { From ea7c61a8408884dc4e62b2f7a45c723a29cc773e Mon Sep 17 00:00:00 2001 From: Kim Vandry Date: Sun, 23 Nov 2025 19:06:36 +0000 Subject: [PATCH 7/9] Use tower::Service (or a sync version of it) to interface with backends. 
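Under the new interface a backend is anything that implements
tower_service::Service over http::Request. As a minimal sketch, a
hypothetical wrapper backend that logs each request before delegating
(LoggingBackend is illustrative, not part of this change) could look like:

    use std::task::{Context, Poll};
    use tower_service::Service;

    #[derive(Clone)]
    struct LoggingBackend<S>(S);

    impl<S, B> Service<http::Request<B>> for LoggingBackend<S>
    where
        S: Service<http::Request<B>>,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            // Readiness is whatever the wrapped backend reports.
            self.0.poll_ready(cx)
        }

        fn call(&mut self, request: http::Request<B>) -> Self::Future {
            // Observe the request, then delegate to the inner service.
            eprintln!("{} {}", request.method(), request.uri());
            self.0.call(request)
        }
    }

The sync build expresses the same shape through the new SyncService trait,
whose call() returns a Result directly instead of a future.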
--- s3/Cargo.toml | 5 +- s3/src/bucket.rs | 37 +++++--- s3/src/request/async_std_backend.rs | 129 +++++++++++++--------------- s3/src/request/backend.rs | 12 +++ s3/src/request/blocking.rs | 77 +++++++---------- s3/src/request/request_trait.rs | 7 -- s3/src/request/tokio_backend.rs | 122 +++++++++++++++----------- s3/src/utils/mod.rs | 34 ++++++++ 8 files changed, 235 insertions(+), 188 deletions(-) diff --git a/s3/Cargo.toml b/s3/Cargo.toml index 5a9f70bbb1..ad199f5fc9 100644 --- a/s3/Cargo.toml +++ b/s3/Cargo.toml @@ -81,6 +81,7 @@ tokio = { version = "1", features = [ "io-util", ], optional = true, default-features = false } tokio-stream = { version = "0.1", optional = true } +tower-service = { version = "0.3.3", optional = true } url = "2" [features] @@ -88,8 +89,8 @@ default = ["fail-on-err", "tags", "tokio-native-tls"] sync = ["attohttpc", "maybe-async/is_sync"] with-async-std-hyper = ["with-async-std", "surf/hyper-client"] -with-async-std = ["dep:http-body", "dep:http-body-util", "async-std", "futures-util", "surf", "sysinfo"] -with-tokio = ["dep:http-body", "dep:http-body-util", "futures-util", "reqwest", "tokio", "tokio/fs", "tokio-stream", "sysinfo"] +with-async-std = ["dep:http-body", "dep:http-body-util", "dep:tower-service", "async-std", "futures-util", "surf", "sysinfo"] +with-tokio = ["dep:http-body", "dep:http-body-util", "dep:tower-service", "futures-util", "reqwest", "tokio", "tokio/fs", "tokio-stream", "sysinfo"] blocking = ["block_on_proc", "tokio/rt", "tokio/rt-multi-thread"] fail-on-err = [] diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index 74d6e98913..c97085c446 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -37,6 +37,7 @@ use block_on_proc::block_on; #[cfg(feature = "tags")] use minidom::Element; +use std::borrow::Cow; use std::collections::HashMap; use crate::bucket_ops::{BucketConfiguration, CreateBucketResponse}; @@ -47,7 +48,7 @@ use crate::region::Region; use crate::request::ResponseDataStream; use crate::request::backend::DefaultBackend; use crate::request::{ - Request, ResponseBody, ResponseData, build_presigned, build_request, response_data, + ResponseBody, ResponseData, build_presigned, build_request, response_data, response_data_to_writer, }; use crate::retry; @@ -139,18 +140,6 @@ pub struct Bucket { backend: DefaultBackend, } -#[maybe_async::async_impl] -async fn exec_request( - request: R, -) -> Result, S3Error> { - retry! { request.response().await } -} - -#[maybe_async::sync_impl] -fn exec_request(request: R) -> Result, S3Error> { - retry! { request.response() } -} - impl Bucket { #[maybe_async::async_impl] /// Credential refreshing is done automatically, but can be manually triggered. @@ -171,6 +160,26 @@ impl Bucket { &self.backend } + #[maybe_async::async_impl] + async fn exec_request( + &self, + request: http::Request>, + ) -> Result, S3Error> { + use tower_service::Service as _; + let mut backend = self.backend.clone(); + retry! { crate::utils::service_ready::Ready::new(&mut backend).await?.call(request.clone()).await } + } + + #[maybe_async::sync_impl] + fn exec_request( + &self, + request: http::Request>, + ) -> Result, S3Error> { + use crate::request::backend::SyncService as _; + let mut backend = self.backend.clone(); + retry! 
{ backend.call(request.clone()) } + } + #[maybe_async::maybe_async] pub(crate) async fn make_request( &self, @@ -179,7 +188,7 @@ impl Bucket { ) -> Result, S3Error> { self.credentials_refresh().await?; let http_request = build_request(self, path, command).await?; - exec_request(self.backend.request(http_request)).await + self.exec_request(http_request).await } #[maybe_async::maybe_async] diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index cf635b2877..ed25a94d08 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -1,53 +1,46 @@ use bytes::Bytes; use futures_util::AsyncBufRead as _; -use std::borrow::Cow; use std::pin::Pin; use std::task::{Context, Poll}; +use tower_service::Service; use crate::error::S3Error; -use crate::request::Request; +use crate::request::backend::BackendRequestBody; use http_body::Frame; -use maybe_async::maybe_async; use surf::http::Method; use surf::http::headers::{HeaderName, HeaderValue}; -// Temporary structure for making a request -pub struct SurfRequest<'a> { - request: http::Request>, - pub sync: bool, -} - -impl SurfRequest<'_> { - fn build(&self) -> Result { - let url = format!("{}", self.request.uri()).parse()?; - let mut request = match *self.request.method() { - http::Method::GET => surf::Request::builder(Method::Get, url), - http::Method::DELETE => surf::Request::builder(Method::Delete, url), - http::Method::PUT => surf::Request::builder(Method::Put, url), - http::Method::POST => surf::Request::builder(Method::Post, url), - http::Method::HEAD => surf::Request::builder(Method::Head, url), - ref m => surf::Request::builder( - m.as_str() - .parse() - .map_err(|e: surf::Error| S3Error::Surf(e.to_string()))?, - url, - ), - } - .body(self.request.body().clone().into_owned()); - - for (name, value) in self.request.headers().iter() { - request = request.header( - HeaderName::from_bytes(AsRef::<[u8]>::as_ref(&name).to_vec()) - .expect("Could not parse heaeder name"), - HeaderValue::from_bytes(AsRef::<[u8]>::as_ref(&value).to_vec()) - .expect("Could not parse header value"), - ); - } - - Ok(request) +fn http_request_to_surf_request( + request: http::Request>, +) -> Result { + let url = format!("{}", request.uri()).parse()?; + let mut builder = match *request.method() { + http::Method::GET => surf::Request::builder(Method::Get, url), + http::Method::DELETE => surf::Request::builder(Method::Delete, url), + http::Method::PUT => surf::Request::builder(Method::Put, url), + http::Method::POST => surf::Request::builder(Method::Post, url), + http::Method::HEAD => surf::Request::builder(Method::Head, url), + ref m => surf::Request::builder( + m.as_str() + .parse() + .map_err(|e: surf::Error| S3Error::Surf(e.to_string()))?, + url, + ), + } + .body(request.body().clone().into_owned()); + + for (name, value) in request.headers().iter() { + builder = builder.header( + HeaderName::from_bytes(AsRef::<[u8]>::as_ref(&name).to_vec()) + .expect("Could not parse heaeder name"), + HeaderValue::from_bytes(AsRef::<[u8]>::as_ref(&value).to_vec()) + .expect("Could not parse header value"), + ); } + + Ok(builder) } pub struct SurfBody(surf::Body); @@ -78,37 +71,37 @@ impl http_body::Body for SurfBody { } } -#[maybe_async] -impl<'a> Request for SurfRequest<'a> { - type ResponseBody = SurfBody; +impl Service>> for SurfBackend { + type Response = http::Response; + type Error = S3Error; + type Future = Pin> + Send>>; - async fn response(&self) -> Result, S3Error> { - let mut response = self - .build()? 
- .send() - .await - .map_err(|e| S3Error::Surf(e.to_string()))?; - - if cfg!(feature = "fail-on-err") && !response.status().is_success() { - return Err(S3Error::HttpFail); - } - - let mut builder = - http::Response::builder().status(http::StatusCode::from_u16(response.status().into())?); - for (name, values) in response.iter() { - for value in values { - builder = builder.header(name.as_str(), value.as_str()); - } - } - Ok(builder.body(SurfBody(response.take_body()))?) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } -} -impl SurfBackend { - pub(crate) fn request<'a>(&self, request: http::Request>) -> SurfRequest<'a> { - SurfRequest { - request, - sync: false, + fn call(&mut self, request: http::Request>) -> Self::Future { + match http_request_to_surf_request(request) { + Ok(request) => { + let fut = request.send(); + Box::pin(async move { + let mut response = fut.await.map_err(|e| S3Error::Surf(e.to_string()))?; + + if cfg!(feature = "fail-on-err") && !response.status().is_success() { + return Err(S3Error::HttpFail); + } + + let mut builder = http::Response::builder() + .status(http::StatusCode::from_u16(response.status().into())?); + for (name, values) in response.iter() { + for value in values { + builder = builder.header(name.as_str(), value.as_str()); + } + } + Ok(builder.body(SurfBody(response.take_body()))?) + }) + } + Err(e) => Box::pin(std::future::ready(Err(e))), } } } @@ -130,11 +123,7 @@ mod tests { .body(b"sneaky".into()) .unwrap(); - let mut r = SurfBackend::default() - .request(http_request) - .build() - .unwrap() - .build(); + let mut r = http_request_to_surf_request(http_request).unwrap().build(); assert_eq!(r.method(), Method::Post); assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1"); diff --git a/s3/src/request/backend.rs b/s3/src/request/backend.rs index 4538c64120..344be83c22 100644 --- a/s3/src/request/backend.rs +++ b/s3/src/request/backend.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::time::Duration; #[cfg(feature = "with-async-std")] @@ -12,3 +13,14 @@ pub(crate) use crate::request::tokio_backend::ReqwestBackend as DefaultBackend; /// For backward compatibility, only AttoBackend uses this. ReqwestBackend /// supports a timeout but none is set by default. 
pub const DEFAULT_REQUEST_TIMEOUT: Option = Some(Duration::from_secs(60)); + +pub type BackendRequestBody<'a> = Cow<'a, [u8]>; + +/// A simplified version of tower_service::Service without async +#[cfg(feature = "sync")] +pub trait SyncService { + type Response; + type Error; + + fn call(&mut self, _: R) -> Result; +} diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index 5c68314e06..b344acab52 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -7,50 +7,46 @@ use std::time::Duration; use crate::error::S3Error; use std::borrow::Cow; -use crate::request::Request; +use crate::request::backend::{BackendRequestBody, SyncService}; -// Temporary structure for making a request -pub struct AttoRequest<'a> { - request: http::Request>, +fn http_request_to_atto_request( + request: http::Request>, request_timeout: Option, - pub sync: bool, -} +) -> Result>>, S3Error> { + let mut session = attohttpc::Session::new(); -impl AttoRequest<'_> { - fn build( - &self, - ) -> Result>>, S3Error> { - let mut session = attohttpc::Session::new(); + for (name, value) in request.headers().iter() { + session.header(HeaderName::from_bytes(name.as_ref())?, value.to_str()?); + } - for (name, value) in self.request.headers().iter() { - session.header(HeaderName::from_bytes(name.as_ref())?, value.to_str()?); - } + if let Some(timeout) = request_timeout { + session.timeout(timeout) + } - if let Some(timeout) = self.request_timeout { - session.timeout(timeout) + let url = format!("{}", request.uri()); + let builder = match *request.method() { + http::Method::GET => session.get(url), + http::Method::DELETE => session.delete(url), + http::Method::PUT => session.put(url), + http::Method::POST => session.post(url), + http::Method::HEAD => session.head(url), + _ => { + return Err(S3Error::HttpFailWithBody(405, "".into())); } + }; - let url = format!("{}", self.request.uri()); - let request = match *self.request.method() { - http::Method::GET => session.get(url), - http::Method::DELETE => session.delete(url), - http::Method::PUT => session.put(url), - http::Method::POST => session.post(url), - http::Method::HEAD => session.head(url), - _ => { - return Err(S3Error::HttpFailWithBody(405, "".into())); - } - }; - - Ok(request.bytes(self.request.body().clone())) - } + Ok(builder.bytes(request.body().clone())) } -impl<'a> Request for AttoRequest<'a> { - type ResponseBody = attohttpc::ResponseReader; +impl SyncService>> for AttoBackend { + type Response = http::Response; + type Error = S3Error; - fn response(&self) -> Result, S3Error> { - let response = self.build()?.send()?; + fn call( + &mut self, + request: http::Request>, + ) -> Result, S3Error> { + let response = http_request_to_atto_request(request, self.request_timeout)?.send()?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { let status = response.status().as_u16(); @@ -67,16 +63,6 @@ impl<'a> Request for AttoRequest<'a> { } } -impl AttoBackend { - pub(crate) fn request<'a>(&self, request: http::Request>) -> AttoRequest<'a> { - AttoRequest { - request, - request_timeout: self.request_timeout, - sync: false, - } - } -} - #[derive(Clone, Debug)] pub struct AttoBackend { request_timeout: Option, @@ -114,8 +100,7 @@ mod tests { .body(b"sneaky".into()) .unwrap(); - let req = AttoBackend::default().request(http_request); - let mut r = req.build().unwrap(); + let mut r = http_request_to_atto_request(http_request, None).unwrap(); assert_eq!(r.inspect().method(), http::Method::POST); assert_eq!(r.inspect().url().as_str(), 
"https://example.com/foo?bar=1"); diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index a6af5772e5..f29ae9bd13 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -250,13 +250,6 @@ impl ResponseBody for T { } } -#[maybe_async::maybe_async] -pub trait Request { - type ResponseBody: ResponseBody; - - async fn response(&self) -> Result, S3Error>; -} - #[maybe_async::maybe_async] pub(crate) async fn response_data( response: http::Response, diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index 4df24af16f..9ab0b09d0c 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -1,11 +1,12 @@ extern crate base64; extern crate md5; -use maybe_async::maybe_async; -use std::borrow::Cow; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; +use tower_service::Service; -use super::request_trait::Request; +use super::backend::BackendRequestBody; use crate::error::S3Error; #[derive(Clone, Debug, Default)] @@ -47,53 +48,48 @@ fn client(options: &ClientOptions) -> Result { Ok(client.build()?) } -// Temporary structure for making a request -pub struct ReqwestRequest<'a> { - request: http::Request>, - client: reqwest::Client, - pub sync: bool, -} -impl ReqwestRequest<'_> { - fn build(&self) -> Result { - Ok(self.request.clone().map(|b| b.into_owned()).try_into()?) +impl Service>> for ReqwestBackend +where + T: Service + + Send + + 'static, + T::Future: Send, +{ + type Response = http::Response; + type Error = S3Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.http_client.poll_ready(cx).map_err(Into::into) } -} - -#[maybe_async] -impl<'a> Request for ReqwestRequest<'a> { - type ResponseBody = reqwest::Body; - - async fn response(&self) -> Result, S3Error> { - let response = self.client.execute(self.build()?).await?; - - if cfg!(feature = "fail-on-err") && !response.status().is_success() { - let status = response.status().as_u16(); - let text = response.text().await?; - return Err(S3Error::HttpFailWithBody(status, text)); - } - - Ok(response.into()) - } -} -impl ReqwestBackend { - pub(crate) fn request<'a>(&self, request: http::Request>) -> ReqwestRequest<'a> { - ReqwestRequest { - request, - client: self.http_client.clone(), - sync: false, + fn call(&mut self, request: http::Request>) -> Self::Future { + match request.map(|b| b.into_owned()).try_into() { + Ok::(request) => { + let fut = self.http_client.call(request); + Box::pin(async move { + let response = fut.await?; + if cfg!(feature = "fail-on-err") && !response.status().is_success() { + let status = response.status().as_u16(); + let text = response.text().await?; + return Err(S3Error::HttpFailWithBody(status, text)); + } + Ok(response.into()) + }) + } + Err(e) => Box::pin(std::future::ready(Err(e.into()))), } } } #[derive(Clone, Debug, Default)] -pub struct ReqwestBackend { - http_client: reqwest::Client, +pub struct ReqwestBackend { + http_client: T, client_options: ClientOptions, } -impl ReqwestBackend { +impl ReqwestBackend { pub fn with_request_timeout(&self, request_timeout: Option) -> Result { let client_options = ClientOptions { request_timeout, @@ -175,6 +171,35 @@ mod tests { use super::*; use http_body_util::BodyExt; + #[derive(Clone, Default)] + struct MockReqwestClient; + + impl Service for MockReqwestClient { + type Response = reqwest::Response; + type Error = reqwest::Error; + type Future = + Pin> + Send>>; + + fn poll_ready(&mut self, _: 
&mut Context<'_>) -> Poll> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, mut r: reqwest::Request) -> Self::Future {
            assert_eq!(r.method(), http::Method::POST);
            assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1");
            assert_eq!(r.headers().get("h1").unwrap(), "v1");
            assert_eq!(r.headers().get("h2").unwrap(), "v2");
            Box::pin(async move {
                let body = r.body_mut().take().unwrap().collect().await;
                assert_eq!(body.unwrap().to_bytes().as_ref(), b"sneaky");
                Ok(http::Response::builder()
                    .body(reqwest::Body::from(""))
                    .unwrap()
                    .into())
            })
        }
    }

    #[tokio::test]
    async fn test_build() {
        let http_request = http::Request::builder()
            .method("POST")
            .uri("https://example.com/foo?bar=1")
            .header("h1", "v1")
            .header("h2", "v2")
            .body(b"sneaky".into())
            .unwrap();

-        let mut r = ReqwestBackend::default()
-            .request(http_request)
-            .build()
+        let mut backend = ReqwestBackend {
+            http_client: MockReqwestClient,
+            ..Default::default()
+        };
+        crate::utils::service_ready::Ready::new(&mut backend)
+            .await
+            .unwrap()
+            .call(http_request)
+            .await
             .unwrap();
-
-        assert_eq!(r.method(), http::Method::POST);
-        assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1");
-        assert_eq!(r.headers().get("h1").unwrap(), "v1");
-        assert_eq!(r.headers().get("h2").unwrap(), "v2");
-        let body = r.body_mut().take().unwrap().collect().await;
-        assert_eq!(body.unwrap().to_bytes().as_ref(), b"sneaky");
     }
 }
diff --git a/s3/src/utils/mod.rs b/s3/src/utils/mod.rs
index 04918ccf19..cda56b4db8 100644
--- a/s3/src/utils/mod.rs
+++ b/s3/src/utils/mod.rs
@@ -427,6 +427,40 @@ macro_rules! retry {
     }};
 }

+/// Like tower::util::ServiceExt::ready but avoiding a dependency on tower.
+#[cfg(not(feature = "sync"))]
+pub(crate) mod service_ready {
+    use std::marker::PhantomData;
+    use std::pin::Pin;
+    use std::task::{Context, Poll};
+    use tower_service::Service;
+
+    pub struct Ready<'a, T, R>(pub Option<&'a mut T>, pub PhantomData R>);
+
+    impl<'a, T, R> Ready<'a, T, R> {
+        pub fn new(inner: &'a mut T) -> Self {
+            Self(Some(inner), PhantomData)
+        }
+    }
+
+    impl<'a, T: Service, R> Future for Ready<'a, T, R> {
+        type Output = Result<&'a mut T, T::Error>;
+
+        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+            match self
+                .0
+                .as_mut()
+                .expect("poll after Poll::Ready")
+                .poll_ready(cx)
+            {
+                Poll::Pending => Poll::Pending,
+                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
+                Poll::Ready(Ok(())) => Poll::Ready(Ok(self.0.take().unwrap())),
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::utils::etag_for_path;

From 0eb0f374a583c6cdc1708da622fc821fcbe5534c Mon Sep 17 00:00:00 2001
From: Kim Vandry
Date: Wed, 26 Nov 2025 22:24:21 +0000
Subject: [PATCH 8/9] Rearrange the order of some methods in impl Bucket. No
 functional change.

The next change needs to split the Bucket type into different impl blocks
with different generics. This textually moves some of the methods so that
they are grouped to fit into the impl blocks they will belong to. Nothing
is added or removed.

This change exists only to make the diff of the next change easier to
follow.
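As a rough illustration of where this is heading (a hypothetical shape; the
real generic bounds arrive in the next change), the regrouped methods are
meant to fall into blocks like:

    impl<B> Bucket<B> {
        // backend-agnostic methods: make_presigned, with_backend, ...
    }

    impl<B: Clone> Bucket<B> {
        // builder-style copies: with_path_style, with_extra_headers, ...
    }

    impl Bucket<DefaultBackend> {
        // request execution: exec_request, make_request
    }

The groups above correspond to the plain impl Bucket blocks this commit
creates in s3/src/bucket.rs.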
---
 s3/src/bucket.rs | 537 ++++++++++++++++++++++++-----------------------
 1 file changed, 274 insertions(+), 263 deletions(-)

diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs
index c97085c446..881e165e25 100644
--- a/s3/src/bucket.rs
+++ b/s3/src/bucket.rs
@@ -160,6 +160,84 @@ impl Bucket {
         &self.backend
     }
 
+    #[maybe_async::maybe_async]
+    async fn make_presigned(&self, path: &str, command: Command<'_>) -> Result<String, S3Error> {
+        self.credentials_refresh().await?;
+        build_presigned(self, path, command).await
+    }
+
+    pub fn with_backend(&self, backend: DefaultBackend) -> Bucket {
+        Bucket {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query: self.extra_query.clone(),
+            path_style: self.path_style,
+            listobjects_v2: self.listobjects_v2,
+            backend,
+        }
+    }
+}
+
+impl Bucket {
+    pub fn with_path_style(&self) -> Box<Bucket> {
+        Box::new(Bucket {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query: self.extra_query.clone(),
+            path_style: true,
+            listobjects_v2: self.listobjects_v2,
+            backend: self.backend.clone(),
+        })
+    }
+
+    pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Bucket, S3Error> {
+        Ok(Bucket {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers,
+            extra_query: self.extra_query.clone(),
+            path_style: self.path_style,
+            listobjects_v2: self.listobjects_v2,
+            backend: self.backend.clone(),
+        })
+    }
+
+    pub fn with_extra_query(
+        &self,
+        extra_query: HashMap<String, String>,
+    ) -> Result<Bucket, S3Error> {
+        Ok(Bucket {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query,
+            path_style: self.path_style,
+            listobjects_v2: self.listobjects_v2,
+            backend: self.backend.clone(),
+        })
+    }
+
+    pub fn with_listobjects_v1(&self) -> Bucket {
+        Bucket {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query: self.extra_query.clone(),
+            path_style: self.path_style,
+            listobjects_v2: false,
+            backend: self.backend.clone(),
+        }
+    }
+}
+
+impl Bucket {
     #[maybe_async::async_impl]
     async fn exec_request(
         &self,
         request: http::Request<Cow<'_, [u8]>>,
     ) -> Result<http::Response<impl ResponseBody>, S3Error> {
@@ -190,12 +268,6 @@ impl Bucket {
         let http_request = build_request(self, path, command).await?;
         self.exec_request(http_request).await
     }
-
-    #[maybe_async::maybe_async]
-    async fn make_presigned(&self, path: &str, command: Command<'_>) -> Result<String, S3Error> {
-        self.credentials_refresh().await?;
-        build_presigned(self, path, command).await
-    }
 }
 
 fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> {
@@ -211,164 +283,6 @@ fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> {
     block_on("async-std")
 )]
 impl Bucket {
-    /// Get a presigned url for getting object on a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use std::collections::HashMap;
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     // Add optional custom queries
-    ///     let mut custom_queries = HashMap::new();
-    ///     custom_queries.insert(
-    ///         "response-content-disposition".into(),
-    ///         "attachment; filename=\"test.png\"".into(),
-    ///     );
-    ///
-    ///     let url = bucket.presign_get("/test.file", 86400, Some(custom_queries)).await.unwrap();
-    ///     println!("Presigned url: {}", url);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    pub async fn presign_get<S: AsRef<str>>(
-        &self,
-        path: S,
-        expiry_secs: u32,
-        custom_queries: Option<HashMap<String, String>>,
-    ) -> Result<String, S3Error> {
-        validate_expiry(expiry_secs)?;
-        self.make_presigned(
-            path.as_ref(),
-            Command::PresignGet {
-                expiry_secs,
-                custom_queries,
-            },
-        )
-        .await
-    }
-
-    /// Get a presigned url for posting an object to a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    /// use s3::post_policy::*;
-    /// use std::borrow::Cow;
-    ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     let post_policy = PostPolicy::new(86400).condition(
-    ///         PostPolicyField::Key,
-    ///         PostPolicyValue::StartsWith(Cow::from("user/user1/"))
-    ///     ).unwrap();
-    ///
-    ///     let presigned_post = bucket.presign_post(post_policy).await.unwrap();
-    ///     println!("Presigned url: {}, fields: {:?}", presigned_post.url, presigned_post.fields);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    #[allow(clippy::needless_lifetimes)]
-    pub async fn presign_post<'a>(
-        &self,
-        post_policy: PostPolicy<'a>,
-    ) -> Result<PresignedPost, S3Error> {
-        post_policy.sign(Box::new(self.clone())).await
-    }
-
-    /// Get a presigned url for putting object to a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    /// use http::HeaderMap;
-    /// use http::header::HeaderName;
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     // Add optional custom headers
-    ///     let mut custom_headers = HeaderMap::new();
-    ///     custom_headers.insert(
-    ///         HeaderName::from_static("custom_header"),
-    ///         "custom_value".parse().unwrap(),
-    ///     );
-    ///
-    ///     let url = bucket.presign_put("/test.file", 86400, Some(custom_headers), None).await.unwrap();
-    ///     println!("Presigned url: {}", url);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    pub async fn presign_put<S: AsRef<str>>(
-        &self,
-        path: S,
-        expiry_secs: u32,
-        custom_headers: Option<HeaderMap>,
-        custom_queries: Option<HashMap<String, String>>,
-    ) -> Result<String, S3Error> {
-        validate_expiry(expiry_secs)?;
-        self.make_presigned(
-            path.as_ref(),
-            Command::PresignPut {
-                expiry_secs,
-                custom_headers,
-                custom_queries,
-            },
-        )
-        .await
-    }
-
-    /// Get a presigned url for deleting object on a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    ///
-    ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     let url = bucket.presign_delete("/test.file", 86400).await.unwrap();
-    ///     println!("Presigned url: {}", url);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    pub async fn presign_delete<S: AsRef<str>>(
-        &self,
-        path: S,
-        expiry_secs: u32,
-    ) -> Result<String, S3Error> {
-        validate_expiry(expiry_secs)?;
-        self.make_presigned(path.as_ref(), Command::PresignDelete { expiry_secs })
-            .await
-    }
-
     /// Create a new `Bucket` and instantiate it
     ///
     /// ```no_run
@@ -596,52 +510,14 @@ impl Bucket {
         })
     }
 
-    /// Delete existing `Bucket`
+    /// Instantiate an existing `Bucket`.
     ///
     /// # Example
-    /// ```rust,no_run
-    /// use s3::Bucket;
+    /// ```no_run
+    /// use s3::bucket::Bucket;
     /// use s3::creds::Credentials;
-    /// use anyhow::Result;
     ///
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<()> {
-    /// let bucket_name = "rust-s3-test";
-    /// let region = "us-east-1".parse().unwrap();
-    /// let credentials = Credentials::default().unwrap();
-    /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    /// // Async variant with `tokio` or `async-std` features
-    /// bucket.delete().await.unwrap();
-    /// // `sync` feature will produce an identical method
-    ///
-    /// #[cfg(feature = "sync")]
-    /// bucket.delete().unwrap();
-    /// // Blocking variant, generated with `blocking` feature in combination
-    /// // with `tokio` or `async-std` features.
-    ///
-    /// #[cfg(feature = "blocking")]
-    /// bucket.delete_blocking().unwrap();
-    ///
-    /// # Ok(())
-    /// # }
-    /// ```
-    #[maybe_async::maybe_async]
-    pub async fn delete(&self) -> Result<u16, S3Error> {
-        let command = Command::DeleteBucket;
-        let response = self.make_request("", command).await?;
-        let response_data = response_data(response, false).await?;
-        Ok(response_data.status_code())
-    }
-
-    /// Instantiate an existing `Bucket`.
-    ///
-    /// # Example
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    ///
-    /// // Fake credentials so we don't access user's real credentials in tests
+    /// // Fake credentials so we don't access user's real credentials in tests
     /// let bucket_name = "rust-s3-test";
     /// let region = "us-east-1".parse().unwrap();
     /// let credentials = Credentials::default().unwrap();
@@ -688,73 +564,208 @@ impl Bucket {
             backend: DefaultBackend::default(),
         })
     }
+}
 
+#[cfg_attr(all(feature = "with-tokio", feature = "blocking"), block_on("tokio"))]
+#[cfg_attr(
+    all(feature = "with-async-std", feature = "blocking"),
+    block_on("async-std")
+)]
+impl Bucket {
+    /// Get a presigned url for getting object on a given path
+    ///
+    /// # Example:
+    ///
+    /// ```no_run
+    /// use std::collections::HashMap;
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let bucket_name = "rust-s3-test";
+    ///     let region = "us-east-1".parse().unwrap();
+    ///     let credentials = Credentials::default().unwrap();
+    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    ///     // Add optional custom queries
+    ///     let mut custom_queries = HashMap::new();
+    ///     custom_queries.insert(
+    ///         "response-content-disposition".into(),
+    ///         "attachment; filename=\"test.png\"".into(),
+    ///     );
+    ///
+    ///     let url = bucket.presign_get("/test.file", 86400, Some(custom_queries)).await.unwrap();
+    ///     println!("Presigned url: {}", url);
+    /// }
+    /// ```
+    #[maybe_async::maybe_async]
+    pub async fn presign_get<S: AsRef<str>>(
+        &self,
+        path: S,
+        expiry_secs: u32,
+        custom_queries: Option<HashMap<String, String>>,
+    ) -> Result<String, S3Error> {
+        validate_expiry(expiry_secs)?;
+        self.make_presigned(
+            path.as_ref(),
+            Command::PresignGet {
+                expiry_secs,
+                custom_queries,
+            },
+        )
+        .await
     }
 
-    pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers,
-            extra_query: self.extra_query.clone(),
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            backend: self.backend.clone(),
-        })
+    /// Get a presigned url for posting an object to a given path
+    ///
+    /// # Example:
+    ///
+    /// ```no_run
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    /// use s3::post_policy::*;
+    /// use std::borrow::Cow;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let bucket_name = "rust-s3-test";
+    ///     let region = "us-east-1".parse().unwrap();
+    ///     let credentials = Credentials::default().unwrap();
+    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    ///     let post_policy = PostPolicy::new(86400).condition(
+    ///         PostPolicyField::Key,
+    ///         PostPolicyValue::StartsWith(Cow::from("user/user1/"))
+    ///     ).unwrap();
+    ///
+    ///     let presigned_post = bucket.presign_post(post_policy).await.unwrap();
+    ///     println!("Presigned url: {}, fields: {:?}", presigned_post.url, presigned_post.fields);
+    /// }
+    /// ```
+    #[maybe_async::maybe_async]
+    #[allow(clippy::needless_lifetimes)]
+    pub async fn presign_post<'a>(
+        &self,
+        post_policy: PostPolicy<'a>,
+    ) -> Result<PresignedPost, S3Error> {
+        post_policy.sign(Box::new(self.clone())).await
     }
 
-    pub fn with_extra_query(
+    /// Get a presigned url for putting object to a given path
+    ///
+    /// # Example:
+    ///
+    /// ```no_run
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    /// use http::HeaderMap;
+    /// use http::header::HeaderName;
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let bucket_name = "rust-s3-test";
+    ///     let region = "us-east-1".parse().unwrap();
+    ///     let credentials = Credentials::default().unwrap();
+    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    ///     // Add optional custom headers
+    ///     let mut custom_headers = HeaderMap::new();
+    ///     custom_headers.insert(
+    ///         HeaderName::from_static("custom_header"),
+    ///         "custom_value".parse().unwrap(),
+    ///     );
+    ///
+    ///     let url = bucket.presign_put("/test.file", 86400, Some(custom_headers), None).await.unwrap();
+    ///     println!("Presigned url: {}", url);
+    /// }
+    /// ```
+    #[maybe_async::maybe_async]
+    pub async fn presign_put<S: AsRef<str>>(
         &self,
-        extra_query: HashMap<String, String>,
-    ) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query,
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            backend: self.backend.clone(),
-        })
+        path: S,
+        expiry_secs: u32,
+        custom_headers: Option<HeaderMap>,
+        custom_queries: Option<HashMap<String, String>>,
+    ) -> Result<String, S3Error> {
+        validate_expiry(expiry_secs)?;
+        self.make_presigned(
+            path.as_ref(),
+            Command::PresignPut {
+                expiry_secs,
+                custom_headers,
+                custom_queries,
+            },
+        )
+        .await
     }
 
-    pub fn with_backend(&self, backend: DefaultBackend) -> Bucket {
-        Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            backend,
-        }
+    /// Get a presigned url for deleting object on a given path
+    ///
+    /// # Example:
+    ///
+    /// ```no_run
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    ///
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let bucket_name = "rust-s3-test";
+    ///     let region = "us-east-1".parse().unwrap();
+    ///     let credentials = Credentials::default().unwrap();
+    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    ///     let url = bucket.presign_delete("/test.file", 86400).await.unwrap();
+    ///     println!("Presigned url: {}", url);
+    /// }
+    /// ```
+    #[maybe_async::maybe_async]
+    pub async fn presign_delete<S: AsRef<str>>(
+        &self,
+        path: S,
+        expiry_secs: u32,
+    ) -> Result<String, S3Error> {
+        validate_expiry(expiry_secs)?;
+        self.make_presigned(path.as_ref(), Command::PresignDelete { expiry_secs })
+            .await
     }
 
-    pub fn with_listobjects_v1(&self) -> Bucket {
-        Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            path_style: self.path_style,
-            listobjects_v2: false,
-            backend: self.backend.clone(),
-        }
+    /// Delete existing `Bucket`
+    ///
+    /// # Example
+    /// ```rust,no_run
+    /// use s3::Bucket;
+    /// use s3::creds::Credentials;
+    /// use anyhow::Result;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let bucket_name = "rust-s3-test";
+    /// let region = "us-east-1".parse().unwrap();
+    /// let credentials = Credentials::default().unwrap();
+    /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    /// // Async variant with `tokio` or `async-std` features
+    /// bucket.delete().await.unwrap();
+    /// // `sync` feature will produce an identical method
+    ///
+    /// #[cfg(feature = "sync")]
+    /// bucket.delete().unwrap();
+    /// // Blocking variant, generated with `blocking` feature in combination
+    /// // with `tokio` or `async-std` features.
+    ///
+    /// #[cfg(feature = "blocking")]
+    /// bucket.delete_blocking().unwrap();
+    ///
+    /// # Ok(())
+    /// # }
+    /// ```
+    #[maybe_async::maybe_async]
+    pub async fn delete(&self) -> Result<u16, S3Error> {
+        let command = Command::DeleteBucket;
+        let response = self.make_request("", command).await?;
+        let response_data = response_data(response, false).await?;
+        Ok(response_data.status_code())
     }
 
     /// Copy file from an S3 path, internally within the same bucket.

From 5a48df0c5f4c91aea84064d0fb92d81053d17a89 Mon Sep 17 00:00:00 2001
From: Kim Vandry
Date: Wed, 26 Nov 2025 22:36:16 +0000
Subject: [PATCH 9/9] Make the request backend pluggable.

    let backend = /* something that impls
        tower::Service<http::Request<Cow<'_, [u8]>>> */;
    let bucket = Bucket::new(...)?.with_backend(backend);

There is still plenty to do; for example, there is no API yet to create
a bucket with a non-default backend, only to get an existing one. But
this is the basic feature.
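To illustrate the intended shape: any Clone + Send + Sync type that
implements tower_service::Service over the request type picks up the new
Backend trait through its blanket impl, so a middleware-style wrapper can
be plugged in with with_backend. A sketch (not part of this patch; the
Counting type and its hit counter are invented for illustration):

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll};
    use tower_service::Service;

    // Forwards every request to an inner backend, counting as it goes.
    #[derive(Clone)]
    struct Counting<T> {
        inner: T,
        hits: Arc<AtomicUsize>,
    }

    impl<T: Service<R>, R> Service<R> for Counting<T> {
        type Response = T::Response;
        type Error = T::Error;
        type Future = T::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, request: R) -> Self::Future {
            // Record the request, then hand it to the wrapped backend unchanged.
            self.hits.fetch_add(1, Ordering::Relaxed);
            self.inner.call(request)
        }
    }

    // let counted = Counting { inner: bucket.backend().clone(), hits: hits.clone() };
    // let bucket = bucket.with_backend(counted);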
---
 s3/src/bucket.rs                | 144 +++++++++++++++++++++-----------
 s3/src/bucket_ops.rs            |   6 +-
 s3/src/error.rs                 |   6 ++
 s3/src/post_policy.rs           |  15 ++--
 s3/src/put_object_request.rs    |  21 +++--
 s3/src/request/backend.rs       |  38 +++++++++
 s3/src/request/request_trait.rs |  14 ++--
 s3/src/utils/mod.rs             |   3 +
 s3/src/utils/testing.rs         |  36 ++++++++
 9 files changed, 207 insertions(+), 76 deletions(-)
 create mode 100644 s3/src/utils/testing.rs

diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs
index 881e165e25..0f4dc46806 100644
--- a/s3/src/bucket.rs
+++ b/s3/src/bucket.rs
@@ -46,7 +46,7 @@ use crate::creds::Credentials;
 use crate::region::Region;
 #[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
 use crate::request::ResponseDataStream;
-use crate::request::backend::DefaultBackend;
+use crate::request::backend::{Backend, DefaultBackend};
 use crate::request::{
     ResponseBody, ResponseData, build_presigned, build_request, response_data,
     response_data_to_writer,
@@ -129,7 +129,7 @@ impl Tag {
 /// let bucket = Bucket::new(bucket_name, region, credentials);
 /// ```
 #[derive(Clone, Debug)]
-pub struct Bucket {
+pub struct Bucket<B = DefaultBackend> {
     pub name: String,
     pub region: Region,
     credentials: Arc<RwLock<Credentials>>,
@@ -137,10 +137,10 @@ pub struct Bucket {
     pub extra_query: Query,
     path_style: bool,
     listobjects_v2: bool,
-    backend: DefaultBackend,
+    backend: B,
 }
 
-impl Bucket {
+impl<B> Bucket<B> {
     #[maybe_async::async_impl]
     /// Credential refreshing is done automatically, but can be manually triggered.
     pub async fn credentials_refresh(&self) -> Result<(), S3Error> {
@@ -156,7 +156,7 @@ impl Bucket {
         }
     }
 
-    pub fn backend(&self) -> &DefaultBackend {
+    pub fn backend(&self) -> &B {
         &self.backend
     }
 
@@ -166,7 +166,7 @@ impl Bucket {
         build_presigned(self, path, command).await
     }
 
-    pub fn with_backend(&self, backend: DefaultBackend) -> Bucket {
+    pub fn with_backend<T>(&self, backend: T) -> Bucket<T> {
         Bucket {
             name: self.name.clone(),
             region: self.region.clone(),
@@ -180,9 +180,9 @@ impl Bucket {
     }
 }
 
-impl Bucket {
-    pub fn with_path_style(&self) -> Box<Bucket> {
-        Box::new(Bucket {
+impl<B: Clone> Bucket<B> {
+    pub fn with_path_style(&self) -> Box<Self> {
+        Box::new(Self {
             name: self.name.clone(),
             region: self.region.clone(),
             credentials: self.credentials.clone(),
@@ -194,8 +194,8 @@ impl Bucket {
         })
     }
 
-    pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
+    pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Self, S3Error> {
+        Ok(Self {
             name: self.name.clone(),
             region: self.region.clone(),
             credentials: self.credentials.clone(),
@@ -207,11 +207,8 @@ impl Bucket {
-    pub fn with_extra_query(
-        &self,
-        extra_query: HashMap<String, String>,
-    ) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
+    pub fn with_extra_query(&self, extra_query: HashMap<String, String>) -> Result<Self, S3Error> {
+        Ok(Self {
             name: self.name.clone(),
             region: self.region.clone(),
             credentials: self.credentials.clone(),
@@ -223,8 +220,8 @@ impl Bucket {
-    pub fn with_listobjects_v1(&self) -> Bucket {
-        Bucket {
+    pub fn with_listobjects_v1(&self) -> Self {
+        Self {
             name: self.name.clone(),
             region: self.region.clone(),
             credentials: self.credentials.clone(),
             extra_headers: self.extra_headers.clone(),
             extra_query: self.extra_query.clone(),
             path_style: self.path_style,
             listobjects_v2: false,
             backend: self.backend.clone(),
         }
     }
 }
 
-impl Bucket {
+impl<B: Backend<Response = http::Response<RB>>, RB: ResponseBody> Bucket<B> {
     #[maybe_async::async_impl]
     async fn exec_request(
         &self,
         request: http::Request<Cow<'_, [u8]>>,
     ) -> Result<http::Response<impl ResponseBody>, S3Error> {
-        use tower_service::Service as _;
         let mut backend = self.backend.clone();
-        retry! { crate::utils::service_ready::Ready::new(&mut backend).await?.call(request.clone()).await }
+        retry! {
+            crate::utils::service_ready::Ready::new(&mut backend)
+                .await
+                .map_err(Into::into)?
+                .call(request.clone())
+                .await
+                .map_err(Into::into)
+        }
     }
 
     #[maybe_async::sync_impl]
     fn exec_request(
         &self,
         request: http::Request<Cow<'_, [u8]>>,
     ) -> Result<http::Response<impl ResponseBody>, S3Error> {
-        use crate::request::backend::SyncService as _;
         let mut backend = self.backend.clone();
-        retry! { backend.call(request.clone()) }
+        retry! { backend.call(request.clone()).map_err(Into::into) }
     }
 
     #[maybe_async::maybe_async]
@@ -282,7 +284,7 @@ fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> {
     all(feature = "with-async-std", feature = "blocking"),
     block_on("async-std")
 )]
-impl Bucket {
+impl Bucket<DefaultBackend> {
     /// Create a new `Bucket` and instantiate it
     ///
     /// ```no_run
@@ -321,7 +323,7 @@ impl Bucket {
         region: Region,
         credentials: Credentials,
         config: BucketConfiguration,
-    ) -> Result<CreateBucketResponse, S3Error> {
+    ) -> Result<CreateBucketResponse<DefaultBackend>, S3Error> {
         let mut config = config;
 
         // Check if we should skip location constraint for LocalStack/Minio compatibility
@@ -336,7 +338,7 @@ impl Bucket {
         }
 
         let command = Command::CreateBucket { config };
-        let bucket = Bucket::new(name, region, credentials)?;
+        let bucket = Self::new(name, region, credentials)?;
         let response = bucket.make_request("", command).await?;
         let data = response_data(response, false).await?;
         let response_text = data.as_str()?;
@@ -385,7 +387,7 @@ impl Bucket {
         region: Region,
         credentials: Credentials,
     ) -> Result<ListBucketsResponse, S3Error> {
-        let dummy_bucket = Bucket::new("", region, credentials)?.with_path_style();
+        let dummy_bucket = Self::new("", region, credentials)?.with_path_style();
         dummy_bucket._list_buckets().await
     }
@@ -483,7 +485,7 @@ impl Bucket {
         region: Region,
         credentials: Credentials,
         config: BucketConfiguration,
-    ) -> Result<CreateBucketResponse, S3Error> {
+    ) -> Result<CreateBucketResponse<DefaultBackend>, S3Error> {
         let mut config = config;
 
         // Check if we should skip location constraint for LocalStack/Minio compatibility
@@ -498,7 +500,7 @@ impl Bucket {
         }
 
         let command = Command::CreateBucket { config };
-        let bucket = Bucket::new(name, region, credentials)?.with_path_style();
+        let bucket = Self::new(name, region, credentials)?.with_path_style();
         let response = bucket.make_request("", command).await?;
         let response_data = response_data(response, false).await?;
         let response_text = response_data.to_string()?;
@@ -524,12 +526,8 @@ impl Bucket {
     ///
     /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
     /// ```
-    pub fn new(
-        name: &str,
-        region: Region,
-        credentials: Credentials,
-    ) -> Result<Box<Bucket>, S3Error> {
-        Ok(Box::new(Bucket {
+    pub fn new(name: &str, region: Region, credentials: Credentials) -> Result<Box<Self>, S3Error> {
+        Ok(Box::new(Self {
             name: name.into(),
             region,
             credentials: Arc::new(RwLock::new(credentials)),
@@ -552,8 +550,8 @@ impl Bucket {
     ///
     /// let bucket = Bucket::new_public(bucket_name, region).unwrap();
     /// ```
-    pub fn new_public(name: &str, region: Region) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
+    pub fn new_public(name: &str, region: Region) -> Result<Self, S3Error> {
+        Ok(Self {
             name: name.into(),
             region,
             credentials: Arc::new(RwLock::new(Credentials::anonymous()?)),
@@ -571,7 +569,7 @@ impl Bucket {
     all(feature = "with-async-std", feature = "blocking"),
     block_on("async-std")
 )]
-impl Bucket {
+impl<B: Backend<Response = http::Response<RB>>, RB: ResponseBody> Bucket<B> {
     /// Get a presigned url for getting object on a given path
     ///
     /// # Example:
@@ -1393,7 +1391,7 @@ impl Bucket {
     pub fn put_object_stream_builder<S: AsRef<str>>(
         &self,
         path: S,
-    ) -> crate::put_object_request::PutObjectStreamRequest<'_> {
+    ) -> crate::put_object_request::PutObjectStreamRequest<'_, B> {
         crate::put_object_request::PutObjectStreamRequest::new(self, path)
     }
 
@@ -2310,7 +2308,7 @@ impl Bucket {
         &self,
         path: S,
         content: &[u8],
-    ) -> crate::put_object_request::PutObjectRequest<'_> {
+    ) -> crate::put_object_request::PutObjectRequest<'_, B> {
         crate::put_object_request::PutObjectRequest::new(self, path, content)
     }
 
@@ -2734,7 +2732,14 @@ impl Bucket {
         ))
     }
 }
+
+#[cfg_attr(all(feature = "with-tokio", feature = "blocking"), block_on("tokio"))]
+#[cfg_attr(
+    all(feature = "with-async-std", feature = "blocking"),
+    block_on("async-std")
+)]
+impl<B> Bucket<B> {
     /// Get path_style field of the Bucket struct
     pub fn is_path_style(&self) -> bool {
         self.path_style
@@ -2909,14 +2914,17 @@ impl Bucket {
 
 #[cfg(test)]
 mod test {
-
+    use super::DefaultBackend;
     use crate::BucketConfiguration;
     use crate::Tag;
     use crate::creds::Credentials;
     use crate::post_policy::{PostPolicyField, PostPolicyValue};
     use crate::region::Region;
+    use crate::request::ResponseBody;
+    use crate::request::backend::Backend;
     use crate::serde_types::CorsConfiguration;
     use crate::serde_types::CorsRule;
+    use crate::utils::testing::AlwaysFailBackend;
     use crate::{Bucket, PostPolicy};
     use http::header::{HeaderMap, HeaderName, HeaderValue, CACHE_CONTROL};
     use std::env;
@@ -2991,7 +2999,7 @@ mod test {
         .unwrap()
     }
 
-    fn test_aws_bucket() -> Box<Bucket> {
+    fn test_aws_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3-test",
             "eu-central-1".parse().unwrap(),
@@ -3000,7 +3008,7 @@ mod test {
         .unwrap()
     }
 
-    fn test_wasabi_bucket() -> Box<Bucket> {
+    fn test_wasabi_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3",
             "wa-eu-central-1".parse().unwrap(),
@@ -3009,7 +3017,7 @@ mod test {
         .unwrap()
     }
 
-    fn test_gc_bucket() -> Box<Bucket> {
+    fn test_gc_bucket() -> Box<Bucket<DefaultBackend>> {
         let mut bucket = Bucket::new(
             "rust-s3",
             Region::Custom {
@@ -3023,7 +3031,7 @@ mod test {
         bucket
     }
 
-    fn test_minio_bucket() -> Box<Bucket> {
+    fn test_minio_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3",
             Region::Custom {
@@ -3037,11 +3045,11 @@ mod test {
     }
 
     #[allow(dead_code)]
-    fn test_digital_ocean_bucket() -> Box<Bucket> {
+    fn test_digital_ocean_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new("rust-s3", Region::DoFra1, test_digital_ocean_credentials()).unwrap()
     }
 
-    fn test_r2_bucket() -> Box<Bucket> {
+    fn test_r2_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3",
             Region::R2 {
@@ -3057,7 +3065,11 @@ mod test {
     }
 
     #[maybe_async::maybe_async]
-    async fn put_head_get_delete_object(bucket: Bucket, head: bool) {
+    async fn put_head_get_delete_object<B, RB>(bucket: Bucket<B>, head: bool)
+    where
+        B: Backend<Response = http::Response<RB>>,
+        RB: ResponseBody,
+    {
         let s3_path = "/+test.file";
         let non_existant_path = "/+non_existant.file";
         let test: Vec<u8> = object(3072);
@@ -3098,7 +3110,11 @@ mod test {
     }
 
     #[maybe_async::maybe_async]
-    async fn put_head_delete_object_with_headers(bucket: Bucket) {
+    async fn put_head_delete_object_with_headers<B, RB>(bucket: Bucket<B>)
+    where
+        B: Backend<Response = http::Response<RB>>,
+        RB: ResponseBody,
+    {
         let s3_path = "/+test.file";
         let non_existant_path = "/+non_existant.file";
         let test: Vec<u8> = object(3072);
@@ -3283,7 +3299,11 @@ mod test {
 
     // Test multi-part upload
     #[maybe_async::maybe_async]
-    async fn streaming_test_put_get_delete_big_object(bucket: Bucket) {
+    async fn streaming_test_put_get_delete_big_object<B, RB>(bucket: Bucket<B>)
+    where
+        B: Backend<Response = http::Response<RB>>,
+        RB: ResponseBody,
+    {
         #[cfg(feature = "with-async-std")]
         use async_std::fs::File;
         #[cfg(feature = "with-async-std")]
@@ -3407,7 +3427,7 @@ mod test {
     }
 
     #[maybe_async::maybe_async]
-    async fn streaming_test_put_get_delete_small_object(bucket: Box<Bucket>) {
+    async fn streaming_test_put_get_delete_small_object(bucket: Box<Bucket<DefaultBackend>>) {
         init();
         let remote_path = "+stream_test_small";
         let content: Vec<u8> = object(1000);
@@ -4066,4 +4086,26 @@ mod test {
         let exists = exists_result.unwrap();
         assert!(exists, "Test bucket should exist");
exist"); } + + #[maybe_async::test( + feature = "sync", + async(all(not(feature = "sync"), feature = "with-tokio"), tokio::test), + async( + all(not(feature = "sync"), feature = "with-async-std"), + async_std::test + ) + )] + async fn test_always_fail_backend() { + init(); + + let credentials = Credentials::anonymous().unwrap(); + let region = "eu-central-1".parse().unwrap(); + let bucket_name = "rust-s3-test"; + + let bucket = Bucket::new(bucket_name, region, credentials) + .unwrap() + .with_backend(AlwaysFailBackend); + let response_data = bucket.get_object("foo").await.unwrap(); + assert_eq!(response_data.status_code(), 418); + } } diff --git a/s3/src/bucket_ops.rs b/s3/src/bucket_ops.rs index 4ea0f76d9c..06131c6798 100644 --- a/s3/src/bucket_ops.rs +++ b/s3/src/bucket_ops.rs @@ -231,13 +231,13 @@ impl BucketConfiguration { } #[allow(dead_code)] -pub struct CreateBucketResponse { - pub bucket: Box, +pub struct CreateBucketResponse { + pub bucket: Box>, pub response_text: String, pub response_code: u16, } -impl CreateBucketResponse { +impl CreateBucketResponse { pub fn success(&self) -> bool { self.response_code == 200 } diff --git a/s3/src/error.rs b/s3/src/error.rs index 5746867154..6e9f8ff507 100644 --- a/s3/src/error.rs +++ b/s3/src/error.rs @@ -74,3 +74,9 @@ pub enum S3Error { #[error("xml serialization error: {0}")] XmlSeError(#[from] quick_xml::SeError), } + +impl From for S3Error { + fn from(_: std::convert::Infallible) -> Self { + unreachable!(); + } +} diff --git a/s3/src/post_policy.rs b/s3/src/post_policy.rs index dd3ef0bc50..616fd304fe 100644 --- a/s3/src/post_policy.rs +++ b/s3/src/post_policy.rs @@ -58,10 +58,10 @@ impl<'a> PostPolicy<'a> { /// Build a finalized post policy with credentials #[maybe_async::maybe_async] - async fn build( + async fn build( &self, now: &OffsetDateTime, - bucket: &Bucket, + bucket: &Bucket, ) -> Result, S3Error> { let access_key = bucket.access_key().await?.ok_or(S3Error::Credentials( CredentialsError::ConfigMissingAccessKeyId, @@ -110,7 +110,7 @@ impl<'a> PostPolicy<'a> { } #[maybe_async::maybe_async] - pub async fn sign(&self, bucket: Box) -> Result { + pub async fn sign(&self, bucket: Box>) -> Result { use hmac::Mac; bucket.credentials_refresh().await?; @@ -457,11 +457,12 @@ mod test { use crate::creds::Credentials; use crate::region::Region; + use crate::utils::testing::AlwaysFailBackend; use crate::utils::with_timestamp; use serde_json::json; - fn test_bucket() -> Box { + fn test_bucket() -> Bucket { Bucket::new( "rust-s3", Region::UsEast1, @@ -475,9 +476,10 @@ mod test { .unwrap(), ) .unwrap() + .with_backend(AlwaysFailBackend) } - fn test_bucket_with_security_token() -> Box { + fn test_bucket_with_security_token() -> Bucket { Bucket::new( "rust-s3", Region::UsEast1, @@ -491,6 +493,7 @@ mod test { .unwrap(), ) .unwrap() + .with_backend(AlwaysFailBackend) } mod conditions { @@ -769,7 +772,7 @@ mod test { let bucket = test_bucket(); let _ts = with_timestamp(1_451_347_200); - let post = policy.sign(bucket).await.unwrap(); + let post = policy.sign(Box::new(bucket)).await.unwrap(); assert_eq!(post.url, "https://rust-s3.s3.amazonaws.com"); assert_eq!( diff --git a/s3/src/put_object_request.rs b/s3/src/put_object_request.rs index 353677bac5..ac1411da63 100644 --- a/s3/src/put_object_request.rs +++ b/s3/src/put_object_request.rs @@ -4,7 +4,8 @@ //! various options including custom headers, content type, and other metadata. 
 use crate::error::S3Error;
-use crate::request::{ResponseData, response_data};
+use crate::request::backend::Backend;
+use crate::request::{ResponseBody, ResponseData, response_data};
 use crate::{Bucket, command::Command};
 use http::{HeaderMap, HeaderName, HeaderValue};
@@ -37,17 +38,17 @@ use async_std::io::Read as AsyncRead;
 /// # }
 /// ```
 #[derive(Debug, Clone)]
-pub struct PutObjectRequest<'a> {
-    bucket: &'a Bucket,
+pub struct PutObjectRequest<'a, B> {
+    bucket: &'a Bucket<B>,
     path: String,
     content: Vec<u8>,
     content_type: String,
     custom_headers: HeaderMap,
 }
 
-impl<'a> PutObjectRequest<'a> {
+impl<'a, B: Backend<Response = http::Response<RB>>, RB: ResponseBody> PutObjectRequest<'a, B> {
     /// Create a new PUT object request builder
-    pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket, path: S, content: &[u8]) -> Self {
+    pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket<B>, path: S, content: &[u8]) -> Self {
         Self {
             bucket,
             path: path.as_ref().to_string(),
@@ -196,17 +197,19 @@ impl<'a> PutObjectRequest<'a> {
 /// Builder for streaming PUT operations
 #[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
 #[derive(Debug, Clone)]
-pub struct PutObjectStreamRequest<'a> {
-    bucket: &'a Bucket,
+pub struct PutObjectStreamRequest<'a, B> {
+    bucket: &'a Bucket<B>,
     path: String,
     content_type: String,
     custom_headers: HeaderMap,
 }
 
 #[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
-impl<'a> PutObjectStreamRequest<'a> {
+impl<'a, B: Backend<Response = http::Response<RB>>, RB: ResponseBody>
+    PutObjectStreamRequest<'a, B>
+{
     /// Create a new streaming PUT request builder
-    pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket, path: S) -> Self {
+    pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket<B>, path: S) -> Self {
         Self {
             bucket,
             path: path.as_ref().to_string(),
diff --git a/s3/src/request/backend.rs b/s3/src/request/backend.rs
index 344be83c22..feefdfd8f0 100644
--- a/s3/src/request/backend.rs
+++ b/s3/src/request/backend.rs
@@ -1,6 +1,8 @@
 use std::borrow::Cow;
 use std::time::Duration;
 
+use crate::error::S3Error;
+
 #[cfg(feature = "with-async-std")]
 pub(crate) use crate::request::async_std_backend::SurfBackend as DefaultBackend;
 #[cfg(feature = "sync")]
@@ -24,3 +26,39 @@ pub trait SyncService<R> {
 
     fn call(&mut self, _: R) -> Result<Self::Response, Self::Error>;
 }
+
+#[cfg(not(feature = "sync"))]
+pub trait Backend:
+    for<'a> tower_service::Service<
+        http::Request<Cow<'a, [u8]>>,
+        Error: Into<S3Error>,
+        Future: Send,
+    > + Clone
+    + Send
+    + Sync
+{
+}
+
+#[cfg(not(feature = "sync"))]
+impl<T> Backend for T where
+    for<'a> T: tower_service::Service<
+        http::Request<Cow<'a, [u8]>>,
+        Error: Into<S3Error>,
+        Future: Send,
+    > + Clone
+    + Send
+    + Sync
+{
+}
+
+#[cfg(feature = "sync")]
+pub trait Backend:
+    for<'a> SyncService<http::Request<Cow<'a, [u8]>>, Error: Into<S3Error>> + Clone
+{
+}
+
+#[cfg(feature = "sync")]
+impl<T> Backend for T where
+    for<'a> T: SyncService<http::Request<Cow<'a, [u8]>>, Error: Into<S3Error>> + Clone
+{
+}
diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs
index f29ae9bd13..6592aa0dfc 100644
--- a/s3/src/request/request_trait.rs
+++ b/s3/src/request/request_trait.rs
@@ -329,15 +329,15 @@ where
     Ok(status_code.as_u16())
 }
 
-struct BuildHelper<'temp, 'body> {
-    bucket: &'temp Bucket,
+struct BuildHelper<'temp, 'body, B> {
+    bucket: &'temp Bucket<B>,
     path: &'temp str,
     command: Command<'body>,
     datetime: OffsetDateTime,
 }
 
 #[maybe_async::maybe_async]
-impl<'temp, 'body> BuildHelper<'temp, 'body> {
+impl<'temp, 'body, B> BuildHelper<'temp, 'body, B> {
     async fn signing_key(&self) -> Result<Vec<u8>, S3Error> {
         signing::signing_key(
             &self.datetime,
@@ -963,8 +963,8 @@ fn make_body(command: Command<'_>) -> Result<Cow<'_, [u8]>, S3Error> {
 }
 
 #[maybe_async::maybe_async]
-pub(crate) async fn build_request<'body>(
-    bucket: &Bucket,
+pub(crate) async fn build_request<'body, B>(
+    bucket: &Bucket<B>,
     path: &str,
     command: Command<'body>,
 ) -> Result<http::Request<Cow<'body, [u8]>>, S3Error> {
@@ -1001,8 +1001,8 @@ pub(crate) async fn build_request<'body, B>(
 }
 
 #[maybe_async::maybe_async]
-pub(crate) async fn build_presigned(
-    bucket: &Bucket,
+pub(crate) async fn build_presigned<B>(
+    bucket: &Bucket<B>,
     path: &str,
     command: Command<'_>,
 ) -> Result<String, S3Error> {
diff --git a/s3/src/utils/mod.rs b/s3/src/utils/mod.rs
index cda56b4db8..8b07a35b69 100644
--- a/s3/src/utils/mod.rs
+++ b/s3/src/utils/mod.rs
@@ -1,3 +1,6 @@
+#[cfg(test)]
+pub mod testing;
+
 mod time_utils;
 pub use time_utils::*;
 
diff --git a/s3/src/utils/testing.rs b/s3/src/utils/testing.rs
new file mode 100644
index 0000000000..791c070ba8
--- /dev/null
+++ b/s3/src/utils/testing.rs
@@ -0,0 +1,36 @@
+#[derive(Clone)]
+pub struct AlwaysFailBackend;
+
+#[cfg(not(feature = "sync"))]
+impl<R> tower_service::Service<R> for AlwaysFailBackend {
+    type Response = http::Response<http_body_util::Empty<bytes::Bytes>>;
+    type Error = http::Error;
+    type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(
+        &mut self,
+        _: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>> {
+        std::task::Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, _: R) -> Self::Future {
+        std::future::ready(
+            http::Response::builder()
+                .status(http::StatusCode::IM_A_TEAPOT)
+                .body(http_body_util::Empty::new()),
+        )
+    }
+}
+
+#[cfg(feature = "sync")]
+impl<R> crate::request::backend::SyncService<R> for AlwaysFailBackend {
+    type Response = http::Response<&'static [u8]>;
+    type Error = http::Error;
+
+    fn call(&mut self, _: R) -> Result<Self::Response, Self::Error> {
+        http::Response::builder()
+            .status(http::StatusCode::IM_A_TEAPOT)
+            .body(b"")
+    }
+}