diff --git a/s3/Cargo.toml b/s3/Cargo.toml
index cba235af67..ad199f5fc9 100644
--- a/s3/Cargo.toml
+++ b/s3/Cargo.toml
@@ -55,6 +55,8 @@ futures-util = { version = "0.3", optional = true, default-features = false }
 hex = "0.4"
 hmac = "0.12"
 http = "1"
+http-body = { version = "1.0.1", optional = true }
+http-body-util = { version = "0.1.3", optional = true }
 log = "0.4"
 maybe-async = { version = "0.2" }
 md5 = "0.8"
@@ -79,6 +81,7 @@ tokio = { version = "1", features = [
     "io-util",
 ], optional = true, default-features = false }
 tokio-stream = { version = "0.1", optional = true }
+tower-service = { version = "0.3.3", optional = true }
 url = "2"
 
 [features]
@@ -86,8 +89,8 @@ default = ["fail-on-err", "tags", "tokio-native-tls"]
 sync = ["attohttpc", "maybe-async/is_sync"]
 with-async-std-hyper = ["with-async-std", "surf/hyper-client"]
 
-with-async-std = ["async-std", "futures-util", "surf", "sysinfo"]
-with-tokio = ["futures-util", "reqwest", "tokio", "tokio/fs", "tokio-stream", "sysinfo"]
+with-async-std = ["dep:http-body", "dep:http-body-util", "dep:tower-service", "async-std", "futures-util", "surf", "sysinfo"]
+with-tokio = ["dep:http-body", "dep:http-body-util", "dep:tower-service", "futures-util", "reqwest", "tokio", "tokio/fs", "tokio-stream", "sysinfo"]
 
 blocking = ["block_on_proc", "tokio/rt", "tokio/rt-multi-thread"]
 fail-on-err = []
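The new `http-body`, `http-body-util`, and `tower-service` dependencies exist because this refactor models HTTP backends as plain `tower_service::Service`s over `http` types. As a minimal sketch of that shape (`EchoBackend` is hypothetical; the crate's real `Backend` trait in `request::backend` adds its own bounds on top of `Service`):

```rust
use std::task::{Context, Poll};

#[derive(Clone)]
struct EchoBackend;

impl tower_service::Service<http::Request<Vec<u8>>> for EchoBackend {
    type Response = http::Response<Vec<u8>>;
    type Error = http::Error;
    type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: http::Request<Vec<u8>>) -> Self::Future {
        // No I/O: answer every request by echoing its body back with a 200.
        std::future::ready(http::Response::builder().status(200).body(request.into_body()))
    }
}
```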
diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs
index 9d240f43ce..0f4dc46806 100644
--- a/s3/src/bucket.rs
+++ b/s3/src/bucket.rs
@@ -37,8 +37,8 @@ use block_on_proc::block_on;
 #[cfg(feature = "tags")]
 use minidom::Element;
 
+use std::borrow::Cow;
 use std::collections::HashMap;
-use std::time::Duration;
 
 use crate::bucket_ops::{BucketConfiguration, CreateBucketResponse};
 use crate::command::{Command, Multipart};
@@ -46,11 +46,12 @@ use crate::creds::Credentials;
 use crate::region::Region;
 #[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
 use crate::request::ResponseDataStream;
-#[cfg(feature = "with-tokio")]
-use crate::request::tokio_backend::ClientOptions;
-#[cfg(feature = "with-tokio")]
-use crate::request::tokio_backend::client;
-use crate::request::{Request as _, ResponseData};
+use crate::request::backend::{Backend, DefaultBackend};
+use crate::request::{
+    ResponseBody, ResponseData, build_presigned, build_request, response_data,
+    response_data_to_writer,
+};
+use crate::retry;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -66,17 +67,12 @@ use std::sync::RwLock;
 pub type Query = HashMap<String, String>;
 
 #[cfg(feature = "with-async-std")]
-use crate::request::async_std_backend::SurfRequest as RequestImpl;
-#[cfg(feature = "with-tokio")]
-use crate::request::tokio_backend::ReqwestRequest as RequestImpl;
-
-#[cfg(feature = "with-async-std")]
-use async_std::io::Write as AsyncWrite;
+use async_std::{io::Write as AsyncWrite, stream::StreamExt as _};
 #[cfg(feature = "with-tokio")]
 use tokio::io::AsyncWrite;
+#[cfg(feature = "with-tokio")]
+use tokio_stream::StreamExt as _;
 
-#[cfg(feature = "sync")]
-use crate::request::blocking::AttoRequest as RequestImpl;
 use std::io::Read;
 
 #[cfg(feature = "with-tokio")]
@@ -102,8 +98,6 @@ use sysinfo::{MemoryRefreshKind, System};
 pub const CHUNK_SIZE: usize = 8_388_608; // 8 Mebibytes, min is 5 (5_242_880);
 
-const DEFAULT_REQUEST_TIMEOUT: Option<Duration> = Some(Duration::from_secs(60));
-
 #[derive(Debug, PartialEq, Eq)]
 pub struct Tag {
     key: String,
@@ -135,22 +129,18 @@ impl Tag {
 /// let bucket = Bucket::new(bucket_name, region, credentials);
 /// ```
 #[derive(Clone, Debug)]
-pub struct Bucket {
+pub struct Bucket<B> {
     pub name: String,
     pub region: Region,
     credentials: Arc<RwLock<Credentials>>,
     pub extra_headers: HeaderMap,
     pub extra_query: Query,
-    pub request_timeout: Option<Duration>,
     path_style: bool,
     listobjects_v2: bool,
-    #[cfg(feature = "with-tokio")]
-    http_client: reqwest::Client,
-    #[cfg(feature = "with-tokio")]
-    client_options: crate::request::tokio_backend::ClientOptions,
+    backend: B,
 }
 
-impl Bucket {
+impl<B> Bucket<B> {
     #[maybe_async::async_impl]
     /// Credential refreshing is done automatically, but can be manually triggered.
     pub async fn credentials_refresh(&self) -> Result<(), S3Error> {
@@ -166,188 +156,135 @@ impl Bucket {
     }
 
-    #[cfg(feature = "with-tokio")]
-    pub fn http_client(&self) -> reqwest::Client {
-        self.http_client.clone()
+    pub fn backend(&self) -> &B {
+        &self.backend
+    }
+
+    #[maybe_async::maybe_async]
+    async fn make_presigned(&self, path: &str, command: Command<'_>) -> Result<String, S3Error> {
+        self.credentials_refresh().await?;
+        build_presigned(self, path, command).await
+    }
+
+    pub fn with_backend<T>(&self, backend: T) -> Bucket<T> {
+        Bucket {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query: self.extra_query.clone(),
+            path_style: self.path_style,
+            listobjects_v2: self.listobjects_v2,
+            backend,
+        }
     }
 }
 
-fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> {
-    if 604800 < expiry_secs {
-        return Err(S3Error::MaxExpiry(expiry_secs));
+impl<B: Clone> Bucket<B> {
+    pub fn with_path_style(&self) -> Box<Self> {
+        Box::new(Self {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query: self.extra_query.clone(),
+            path_style: true,
+            listobjects_v2: self.listobjects_v2,
+            backend: self.backend.clone(),
+        })
+    }
+
+    pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Self, S3Error> {
+        Ok(Self {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers,
+            extra_query: self.extra_query.clone(),
+            path_style: self.path_style,
+            listobjects_v2: self.listobjects_v2,
+            backend: self.backend.clone(),
+        })
+    }
+
+    pub fn with_extra_query(&self, extra_query: HashMap<String, String>) -> Result<Self, S3Error> {
+        Ok(Self {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query,
+            path_style: self.path_style,
+            listobjects_v2: self.listobjects_v2,
+            backend: self.backend.clone(),
+        })
+    }
+
+    pub fn with_listobjects_v1(&self) -> Self {
+        Self {
+            name: self.name.clone(),
+            region: self.region.clone(),
+            credentials: self.credentials.clone(),
+            extra_headers: self.extra_headers.clone(),
+            extra_query: self.extra_query.clone(),
+            path_style: self.path_style,
+            listobjects_v2: false,
+            backend: self.backend.clone(),
+        }
     }
-    Ok(())
 }
 
-#[cfg_attr(all(feature = "with-tokio", feature = "blocking"), block_on("tokio"))]
-#[cfg_attr(
-    all(feature = "with-async-std", feature = "blocking"),
-    block_on("async-std")
-)]
-impl Bucket {
-    /// Get a presigned url for getting object on a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use std::collections::HashMap;
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     // Add optional custom queries
-    ///     let mut custom_queries = HashMap::new();
-    ///     custom_queries.insert(
-    ///        "response-content-disposition".into(),
-    ///        "attachment; filename=\"test.png\"".into(),
-    ///     );
-    ///
-    ///     let url = bucket.presign_get("/test.file", 86400, Some(custom_queries)).await.unwrap();
-    ///     println!("Presigned url: {}", url);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    pub async fn presign_get<S: AsRef<str>>(
+impl<B: Backend<Response = http::Response<RB>>, RB: ResponseBody> Bucket<B> {
+    #[maybe_async::async_impl]
+    async fn exec_request(
         &self,
-        path: S,
-        expiry_secs: u32,
-        custom_queries: Option<HashMap<String, String>>,
-    ) -> Result<String, S3Error> {
-        validate_expiry(expiry_secs)?;
-        let request = RequestImpl::new(
-            self,
-            path.as_ref(),
-            Command::PresignGet {
-                expiry_secs,
-                custom_queries,
-            },
-        )
-        .await?;
-        request.presigned().await
+        request: http::Request<Cow<'static, [u8]>>,
+    ) -> Result<http::Response<RB>, S3Error> {
+        let mut backend = self.backend.clone();
+        retry! {
+            crate::utils::service_ready::Ready::new(&mut backend)
+                .await
+                .map_err(Into::into)?
+                .call(request.clone())
+                .await
+                .map_err(Into::into)
+        }
     }
 
-    /// Get a presigned url for posting an object to a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    /// use s3::post_policy::*;
-    /// use std::borrow::Cow;
-    ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     let post_policy = PostPolicy::new(86400).condition(
-    ///         PostPolicyField::Key,
-    ///         PostPolicyValue::StartsWith(Cow::from("user/user1/"))
-    ///     ).unwrap();
-    ///
-    ///     let presigned_post = bucket.presign_post(post_policy).await.unwrap();
-    ///     println!("Presigned url: {}, fields: {:?}", presigned_post.url, presigned_post.fields);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    #[allow(clippy::needless_lifetimes)]
-    pub async fn presign_post<'a>(
+    #[maybe_async::sync_impl]
+    fn exec_request(
         &self,
-        post_policy: PostPolicy<'a>,
-    ) -> Result<PresignedPost, S3Error> {
-        post_policy.sign(Box::new(self.clone())).await
+        request: http::Request<Cow<'static, [u8]>>,
+    ) -> Result<http::Response<RB>, S3Error> {
+        let mut backend = self.backend.clone();
+        retry! { backend.call(request.clone()).map_err(Into::into) }
     }
 
-    /// Get a presigned url for putting object to a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    /// use http::HeaderMap;
-    /// use http::header::HeaderName;
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     // Add optional custom headers
-    ///     let mut custom_headers = HeaderMap::new();
-    ///     custom_headers.insert(
-    ///        HeaderName::from_static("custom_header"),
-    ///        "custom_value".parse().unwrap(),
-    ///     );
-    ///
-    ///     let url = bucket.presign_put("/test.file", 86400, Some(custom_headers), None).await.unwrap();
-    ///     println!("Presigned url: {}", url);
-    /// }
-    /// ```
     #[maybe_async::maybe_async]
-    pub async fn presign_put<S: AsRef<str>>(
+    pub(crate) async fn make_request(
         &self,
-        path: S,
-        expiry_secs: u32,
-        custom_headers: Option<HeaderMap>,
-        custom_queries: Option<HashMap<String, String>>,
-    ) -> Result<String, S3Error> {
-        validate_expiry(expiry_secs)?;
-        let request = RequestImpl::new(
-            self,
-            path.as_ref(),
-            Command::PresignPut {
-                expiry_secs,
-                custom_headers,
-                custom_queries,
-            },
-        )
-        .await?;
-        request.presigned().await
+        path: &str,
+        command: Command<'_>,
+    ) -> Result<http::Response<RB>, S3Error> {
+        self.credentials_refresh().await?;
+        let http_request = build_request(self, path, command).await?;
+        self.exec_request(http_request).await
     }
+}
 
-    /// Get a presigned url for deleting object on a given path
-    ///
-    /// # Example:
-    ///
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    /// use s3::creds::Credentials;
-    ///
-    ///
-    /// #[tokio::main]
-    /// async fn main() {
-    ///     let bucket_name = "rust-s3-test";
-    ///     let region = "us-east-1".parse().unwrap();
-    ///     let credentials = Credentials::default().unwrap();
-    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
-    ///
-    ///     let url = bucket.presign_delete("/test.file", 86400).await.unwrap();
-    ///     println!("Presigned url: {}", url);
-    /// }
-    /// ```
-    #[maybe_async::maybe_async]
-    pub async fn presign_delete<S: AsRef<str>>(
-        &self,
-        path: S,
-        expiry_secs: u32,
-    ) -> Result<String, S3Error> {
-        validate_expiry(expiry_secs)?;
-        let request =
-            RequestImpl::new(self, path.as_ref(), Command::PresignDelete { expiry_secs }).await?;
-        request.presigned().await
+fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> {
+    if 604800 < expiry_secs {
+        return Err(S3Error::MaxExpiry(expiry_secs));
     }
+    Ok(())
+}
 
+#[cfg_attr(all(feature = "with-tokio", feature = "blocking"), block_on("tokio"))]
+#[cfg_attr(
+    all(feature = "with-async-std", feature = "blocking"),
+    block_on("async-std")
+)]
+impl Bucket<DefaultBackend> {
     /// Create a new `Bucket` and instantiate it
     ///
     /// ```no_run
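`exec_request` drives the standard tower readiness protocol inside the crate's `retry!` macro, which is also why the request body is a cheaply cloneable `Cow<'static, [u8]>`. A rough standalone equivalent of the internal `service_ready::Ready` helper, assuming it only awaits `poll_ready` before issuing the call:

```rust
use std::future::poll_fn;
use tower_service::Service;

// Sketch, not the crate's code: wait until the service is ready, then call it.
async fn call_when_ready<S, R>(svc: &mut S, req: R) -> Result<S::Response, S::Error>
where
    S: Service<R>,
{
    poll_fn(|cx| svc.poll_ready(cx)).await?;
    svc.call(req).await
}
```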
@@ -386,7 +323,7 @@ impl Bucket {
         region: Region,
         credentials: Credentials,
         config: BucketConfiguration,
-    ) -> Result<CreateBucketResponse, S3Error> {
+    ) -> Result<CreateBucketResponse<DefaultBackend>, S3Error> {
         let mut config = config;
 
         // Check if we should skip location constraint for LocalStack/Minio compatibility
@@ -401,14 +338,14 @@ impl Bucket {
         }
 
         let command = Command::CreateBucket { config };
-        let bucket = Bucket::new(name, region, credentials)?;
-        let request = RequestImpl::new(&bucket, "", command).await?;
-        let response_data = request.response_data(false).await?;
-        let response_text = response_data.as_str()?;
+        let bucket = Self::new(name, region, credentials)?;
+        let response = bucket.make_request("", command).await?;
+        let data = response_data(response, false).await?;
+        let response_text = data.as_str()?;
 
         Ok(CreateBucketResponse {
             bucket,
             response_text: response_text.to_string(),
-            response_code: response_data.status_code(),
+            response_code: data.status_code(),
         })
     }
 
@@ -450,7 +387,7 @@ impl Bucket {
         region: Region,
         credentials: Credentials,
     ) -> Result<ListBucketsResponse, S3Error> {
-        let dummy_bucket = Bucket::new("", region, credentials)?.with_path_style();
+        let dummy_bucket = Self::new("", region, credentials)?.with_path_style();
         dummy_bucket._list_buckets().await
     }
 
@@ -458,12 +395,12 @@ impl Bucket {
     /// Used by the public `list_buckets` method to retrieve the list of buckets for the configured client.
     #[maybe_async::maybe_async]
    async fn _list_buckets(&self) -> Result<ListBucketsResponse, S3Error> {
-        let request = RequestImpl::new(self, "", Command::ListBuckets).await?;
-        let response = request.response_data(false).await?;
+        let response = self.make_request("", Command::ListBuckets).await?;
+        let data = response_data(response, false).await?;
 
         Ok(quick_xml::de::from_str::<
             crate::bucket_ops::ListBucketsResponse,
-        >(response.as_str()?)?)
+        >(data.as_str()?)?)
     }
 
     /// Determine whether the instantiated bucket exists.
@@ -525,54 +462,270 @@ impl Bucket {
     /// let credentials = Credentials::default()?;
     /// let config = BucketConfiguration::default();
     ///
-    /// // Async variant with `tokio` or `async-std` features
-    /// let create_bucket_response = Bucket::create_with_path_style(bucket_name, region, credentials, config).await?;
+    /// // Async variant with `tokio` or `async-std` features
+    /// let create_bucket_response = Bucket::create_with_path_style(bucket_name, region, credentials, config).await?;
+    ///
+    /// // `sync` feature will produce an identical method
+    /// #[cfg(feature = "sync")]
+    /// let create_bucket_response = Bucket::create_with_path_style(bucket_name, region, credentials, config)?;
+    ///
+    /// # let region: Region = "us-east-1".parse()?;
+    /// # let credentials = Credentials::default()?;
+    /// # let config = BucketConfiguration::default();
+    /// // Blocking variant, generated with `blocking` feature in combination
+    /// // with `tokio` or `async-std` features.
+    /// #[cfg(feature = "blocking")]
+    /// let create_bucket_response = Bucket::create_with_path_style_blocking(bucket_name, region, credentials, config)?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    #[maybe_async::maybe_async]
+    pub async fn create_with_path_style(
+        name: &str,
+        region: Region,
+        credentials: Credentials,
+        config: BucketConfiguration,
+    ) -> Result<CreateBucketResponse<DefaultBackend>, S3Error> {
+        let mut config = config;
+
+        // Check if we should skip location constraint for LocalStack/Minio compatibility
+        // This env var allows users to create buckets on S3-compatible services that
+        // don't support or require location constraints in the request body
+        let skip_constraint = std::env::var("RUST_S3_SKIP_LOCATION_CONSTRAINT")
+            .unwrap_or_default()
+            .to_lowercase();
+
+        if skip_constraint != "true" && skip_constraint != "1" {
+            config.set_region(region.clone());
+        }
+
+        let command = Command::CreateBucket { config };
+        let bucket = Self::new(name, region, credentials)?.with_path_style();
+        let response = bucket.make_request("", command).await?;
+        let response_data = response_data(response, false).await?;
+        let response_text = response_data.to_string()?;
+
+        Ok(CreateBucketResponse {
+            bucket,
+            response_text,
+            response_code: response_data.status_code(),
+        })
+    }
+
+    /// Instantiate an existing `Bucket`.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    ///
+    /// // Fake credentials so we don't access user's real credentials in tests
+    /// let bucket_name = "rust-s3-test";
+    /// let region = "us-east-1".parse().unwrap();
+    /// let credentials = Credentials::default().unwrap();
+    ///
+    /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    /// ```
+    pub fn new(name: &str, region: Region, credentials: Credentials) -> Result<Box<Self>, S3Error> {
+        Ok(Box::new(Self {
+            name: name.into(),
+            region,
+            credentials: Arc::new(RwLock::new(credentials)),
+            extra_headers: HeaderMap::new(),
+            extra_query: HashMap::new(),
+            path_style: false,
+            listobjects_v2: true,
+            backend: DefaultBackend::default(),
+        }))
+    }
+
+    /// Instantiate a public existing `Bucket`.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use s3::bucket::Bucket;
+    ///
+    /// let bucket_name = "rust-s3-test";
+    /// let region = "us-east-1".parse().unwrap();
+    ///
+    /// let bucket = Bucket::new_public(bucket_name, region).unwrap();
+    /// ```
+    pub fn new_public(name: &str, region: Region) -> Result<Self, S3Error> {
+        Ok(Self {
+            name: name.into(),
+            region,
+            credentials: Arc::new(RwLock::new(Credentials::anonymous()?)),
+            extra_headers: HeaderMap::new(),
+            extra_query: HashMap::new(),
+            path_style: false,
+            listobjects_v2: true,
+            backend: DefaultBackend::default(),
+        })
+    }
+}
+
+#[cfg_attr(all(feature = "with-tokio", feature = "blocking"), block_on("tokio"))]
+#[cfg_attr(
+    all(feature = "with-async-std", feature = "blocking"),
+    block_on("async-std")
+)]
+impl<B: Backend<Response = http::Response<RB>>, RB: ResponseBody> Bucket<B> {
+    /// Get a presigned url for getting object on a given path
+    ///
+    /// # Example:
+    ///
+    /// ```no_run
+    /// use std::collections::HashMap;
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let bucket_name = "rust-s3-test";
+    ///     let region = "us-east-1".parse().unwrap();
+    ///     let credentials = Credentials::default().unwrap();
+    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    ///     // Add optional custom queries
+    ///     let mut custom_queries = HashMap::new();
+    ///     custom_queries.insert(
+    ///        "response-content-disposition".into(),
+    ///        "attachment; filename=\"test.png\"".into(),
+    ///     );
+    ///
+    ///     let url = bucket.presign_get("/test.file", 86400, Some(custom_queries)).await.unwrap();
+    ///     println!("Presigned url: {}", url);
+    /// }
+    /// ```
+    #[maybe_async::maybe_async]
+    pub async fn presign_get<S: AsRef<str>>(
+        &self,
+        path: S,
+        expiry_secs: u32,
+        custom_queries: Option<HashMap<String, String>>,
+    ) -> Result<String, S3Error> {
+        validate_expiry(expiry_secs)?;
+        self.make_presigned(
+            path.as_ref(),
+            Command::PresignGet {
+                expiry_secs,
+                custom_queries,
+            },
+        )
+        .await
+    }
+
+    /// Get a presigned url for posting an object to a given path
+    ///
+    /// # Example:
+    ///
+    /// ```no_run
+    /// use s3::bucket::Bucket;
+    /// use s3::creds::Credentials;
+    /// use s3::post_policy::*;
+    /// use std::borrow::Cow;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let bucket_name = "rust-s3-test";
+    ///     let region = "us-east-1".parse().unwrap();
+    ///     let credentials = Credentials::default().unwrap();
+    ///     let bucket = Bucket::new(bucket_name, region, credentials).unwrap();
+    ///
+    ///     let post_policy = PostPolicy::new(86400).condition(
+    ///         PostPolicyField::Key,
+    ///         PostPolicyValue::StartsWith(Cow::from("user/user1/"))
+    ///     ).unwrap();
+    ///
+    ///     let presigned_post = bucket.presign_post(post_policy).await.unwrap();
+    ///
println!("Presigned url: {}, fields: {:?}", presigned_post.url, presigned_post.fields); + /// } + /// ``` + #[maybe_async::maybe_async] + #[allow(clippy::needless_lifetimes)] + pub async fn presign_post<'a>( + &self, + post_policy: PostPolicy<'a>, + ) -> Result { + post_policy.sign(Box::new(self.clone())).await + } + + /// Get a presigned url for putting object to a given path + /// + /// # Example: + /// + /// ```no_run + /// use s3::bucket::Bucket; + /// use s3::creds::Credentials; + /// use http::HeaderMap; + /// use http::header::HeaderName; + /// #[tokio::main] + /// async fn main() { + /// let bucket_name = "rust-s3-test"; + /// let region = "us-east-1".parse().unwrap(); + /// let credentials = Credentials::default().unwrap(); + /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap(); + /// + /// // Add optional custom headers + /// let mut custom_headers = HeaderMap::new(); + /// custom_headers.insert( + /// HeaderName::from_static("custom_header"), + /// "custom_value".parse().unwrap(), + /// ); + /// + /// let url = bucket.presign_put("/test.file", 86400, Some(custom_headers), None).await.unwrap(); + /// println!("Presigned url: {}", url); + /// } + /// ``` + #[maybe_async::maybe_async] + pub async fn presign_put>( + &self, + path: S, + expiry_secs: u32, + custom_headers: Option, + custom_queries: Option>, + ) -> Result { + validate_expiry(expiry_secs)?; + self.make_presigned( + path.as_ref(), + Command::PresignPut { + expiry_secs, + custom_headers, + custom_queries, + }, + ) + .await + } + + /// Get a presigned url for deleting object on a given path + /// + /// # Example: + /// + /// ```no_run + /// use s3::bucket::Bucket; + /// use s3::creds::Credentials; + /// /// - /// // `sync` fature will produce an identical method - /// #[cfg(feature = "sync")] - /// let create_bucket_response = Bucket::create_with_path_style(bucket_name, region, credentials, config)?; + /// #[tokio::main] + /// async fn main() { + /// let bucket_name = "rust-s3-test"; + /// let region = "us-east-1".parse().unwrap(); + /// let credentials = Credentials::default().unwrap(); + /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap(); /// - /// # let region: Region = "us-east-1".parse()?; - /// # let credentials = Credentials::default()?; - /// # let config = BucketConfiguration::default(); - /// // Blocking variant, generated with `blocking` feature in combination - /// // with `tokio` or `async-std` features. 
- /// #[cfg(feature = "blocking")] - /// let create_bucket_response = Bucket::create_with_path_style_blocking(bucket_name, region, credentials, config)?; - /// # Ok(()) - /// # } + /// let url = bucket.presign_delete("/test.file", 86400).await.unwrap(); + /// println!("Presigned url: {}", url); + /// } /// ``` #[maybe_async::maybe_async] - pub async fn create_with_path_style( - name: &str, - region: Region, - credentials: Credentials, - config: BucketConfiguration, - ) -> Result { - let mut config = config; - - // Check if we should skip location constraint for LocalStack/Minio compatibility - // This env var allows users to create buckets on S3-compatible services that - // don't support or require location constraints in the request body - let skip_constraint = std::env::var("RUST_S3_SKIP_LOCATION_CONSTRAINT") - .unwrap_or_default() - .to_lowercase(); - - if skip_constraint != "true" && skip_constraint != "1" { - config.set_region(region.clone()); - } - - let command = Command::CreateBucket { config }; - let bucket = Bucket::new(name, region, credentials)?.with_path_style(); - let request = RequestImpl::new(&bucket, "", command).await?; - let response_data = request.response_data(false).await?; - let response_text = response_data.to_string()?; - - Ok(CreateBucketResponse { - bucket, - response_text, - response_code: response_data.status_code(), - }) + pub async fn presign_delete>( + &self, + path: S, + expiry_secs: u32, + ) -> Result { + validate_expiry(expiry_secs)?; + self.make_presigned(path.as_ref(), Command::PresignDelete { expiry_secs }) + .await } /// Delete existing `Bucket` @@ -608,263 +761,11 @@ impl Bucket { #[maybe_async::maybe_async] pub async fn delete(&self) -> Result { let command = Command::DeleteBucket; - let request = RequestImpl::new(self, "", command).await?; - let response_data = request.response_data(false).await?; + let response = self.make_request("", command).await?; + let response_data = response_data(response, false).await?; Ok(response_data.status_code()) } - /// Instantiate an existing `Bucket`. - /// - /// # Example - /// ```no_run - /// use s3::bucket::Bucket; - /// use s3::creds::Credentials; - /// - /// // Fake credentials so we don't access user's real credentials in tests - /// let bucket_name = "rust-s3-test"; - /// let region = "us-east-1".parse().unwrap(); - /// let credentials = Credentials::default().unwrap(); - /// - /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap(); - /// ``` - pub fn new( - name: &str, - region: Region, - credentials: Credentials, - ) -> Result, S3Error> { - #[cfg(feature = "with-tokio")] - let options = ClientOptions::default(); - - Ok(Box::new(Bucket { - name: name.into(), - region, - credentials: Arc::new(RwLock::new(credentials)), - extra_headers: HeaderMap::new(), - extra_query: HashMap::new(), - request_timeout: DEFAULT_REQUEST_TIMEOUT, - path_style: false, - listobjects_v2: true, - #[cfg(feature = "with-tokio")] - http_client: client(&options)?, - #[cfg(feature = "with-tokio")] - client_options: options, - })) - } - - /// Instantiate a public existing `Bucket`. 
-    ///
-    /// # Example
-    /// ```no_run
-    /// use s3::bucket::Bucket;
-    ///
-    /// let bucket_name = "rust-s3-test";
-    /// let region = "us-east-1".parse().unwrap();
-    ///
-    /// let bucket = Bucket::new_public(bucket_name, region).unwrap();
-    /// ```
-    pub fn new_public(name: &str, region: Region) -> Result<Bucket, S3Error> {
-        #[cfg(feature = "with-tokio")]
-        let options = ClientOptions::default();
-
-        Ok(Bucket {
-            name: name.into(),
-            region,
-            credentials: Arc::new(RwLock::new(Credentials::anonymous()?)),
-            extra_headers: HeaderMap::new(),
-            extra_query: HashMap::new(),
-            request_timeout: DEFAULT_REQUEST_TIMEOUT,
-            path_style: false,
-            listobjects_v2: true,
-            #[cfg(feature = "with-tokio")]
-            http_client: client(&options)?,
-            #[cfg(feature = "with-tokio")]
-            client_options: options,
-        })
-    }
-
-    pub fn with_path_style(&self) -> Box<Bucket> {
-        Box::new(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            request_timeout: self.request_timeout,
-            path_style: true,
-            listobjects_v2: self.listobjects_v2,
-            #[cfg(feature = "with-tokio")]
-            http_client: self.http_client(),
-            #[cfg(feature = "with-tokio")]
-            client_options: self.client_options.clone(),
-        })
-    }
-
-    pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers,
-            extra_query: self.extra_query.clone(),
-            request_timeout: self.request_timeout,
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            #[cfg(feature = "with-tokio")]
-            http_client: self.http_client(),
-            #[cfg(feature = "with-tokio")]
-            client_options: self.client_options.clone(),
-        })
-    }
-
-    pub fn with_extra_query(
-        &self,
-        extra_query: HashMap<String, String>,
-    ) -> Result<Bucket, S3Error> {
-        Ok(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query,
-            request_timeout: self.request_timeout,
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            #[cfg(feature = "with-tokio")]
-            http_client: self.http_client(),
-            #[cfg(feature = "with-tokio")]
-            client_options: self.client_options.clone(),
-        })
-    }
-
-    #[cfg(not(feature = "with-tokio"))]
-    pub fn with_request_timeout(&self, request_timeout: Duration) -> Result<Box<Bucket>, S3Error> {
-        Ok(Box::new(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            request_timeout: Some(request_timeout),
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-        }))
-    }
-
-    #[cfg(feature = "with-tokio")]
-    pub fn with_request_timeout(&self, request_timeout: Duration) -> Result<Box<Bucket>, S3Error> {
-        let options = ClientOptions {
-            request_timeout: Some(request_timeout),
-            ..Default::default()
-        };
-
-        Ok(Box::new(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            request_timeout: Some(request_timeout),
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            #[cfg(feature = "with-tokio")]
-            http_client: client(&options)?,
-            #[cfg(feature = "with-tokio")]
-            client_options: options,
-        }))
-    }
-
-    pub fn with_listobjects_v1(&self) -> Bucket {
-        Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            request_timeout: self.request_timeout,
-            path_style: self.path_style,
-            listobjects_v2: false,
-            #[cfg(feature = "with-tokio")]
-            http_client: self.http_client(),
-            #[cfg(feature = "with-tokio")]
-            client_options: self.client_options.clone(),
-        }
-    }
-
-    /// Configures a bucket to accept invalid SSL certificates and hostnames.
-    ///
-    /// This method is available only when either the `tokio-native-tls` or `tokio-rustls-tls` feature is enabled.
-    ///
-    /// # Parameters
-    ///
-    /// - `accept_invalid_certs`: A boolean flag that determines whether the client should accept invalid SSL certificates.
-    /// - `accept_invalid_hostnames`: A boolean flag that determines whether the client should accept invalid hostnames.
-    ///
-    /// # Returns
-    ///
-    /// Returns a `Result` containing the newly configured `Bucket` instance if successful, or an `S3Error` if an error occurs during client configuration.
-    ///
-    /// # Errors
-    ///
-    /// This function returns an `S3Error` if the HTTP client configuration fails.
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// # use s3::bucket::Bucket;
-    /// # use s3::error::S3Error;
-    /// # use s3::creds::Credentials;
-    /// # use s3::Region;
-    /// # use std::str::FromStr;
-    ///
-    /// # fn example() -> Result<(), S3Error> {
-    /// let bucket = Bucket::new("my-bucket", Region::from_str("us-east-1")?, Credentials::default()?)?
-    ///     .set_dangereous_config(true, true)?;
-    /// # Ok(())
-    /// # }
-    ///
-    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
-    pub fn set_dangereous_config(
-        &self,
-        accept_invalid_certs: bool,
-        accept_invalid_hostnames: bool,
-    ) -> Result<Bucket, S3Error> {
-        let mut options = self.client_options.clone();
-        options.accept_invalid_certs = accept_invalid_certs;
-        options.accept_invalid_hostnames = accept_invalid_hostnames;
-
-        Ok(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            request_timeout: self.request_timeout,
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            http_client: client(&options)?,
-            client_options: options,
-        })
-    }
-
-    #[cfg(feature = "with-tokio")]
-    pub fn set_proxy(&self, proxy: reqwest::Proxy) -> Result<Bucket, S3Error> {
-        let mut options = self.client_options.clone();
-        options.proxy = Some(proxy);
-
-        Ok(Bucket {
-            name: self.name.clone(),
-            region: self.region.clone(),
-            credentials: self.credentials.clone(),
-            extra_headers: self.extra_headers.clone(),
-            extra_query: self.extra_query.clone(),
-            request_timeout: self.request_timeout,
-            path_style: self.path_style,
-            listobjects_v2: self.listobjects_v2,
-            http_client: client(&options)?,
-            client_options: options,
-        })
-    }
-
     /// Copy file from an S3 path, internally within the same bucket.
     ///
     /// # Example:
@@ -915,8 +816,8 @@ impl Bucket {
         let command = Command::CopyObject {
             from: from.as_ref(),
         };
-        let request = RequestImpl::new(self, to.as_ref(), command).await?;
-        let response_data = request.response_data(false).await?;
+        let response = self.make_request(to.as_ref(), command).await?;
+        let response_data = response_data(response, false).await?;
         Ok(response_data.status_code())
     }
 
@@ -954,8 +855,8 @@ impl Bucket {
     #[maybe_async::maybe_async]
     pub async fn get_object<S: AsRef<str>>(&self, path: S) -> Result<ResponseData, S3Error> {
         let command = Command::GetObject;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(false).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, false).await
     }
 
     #[maybe_async::maybe_async]
@@ -969,12 +870,12 @@ impl Bucket {
             expected_bucket_owner: expected_bucket_owner.to_string(),
             version_id,
         };
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
+        let response = self.make_request(path.as_ref(), command).await?;
 
-        let response = request.response_data(false).await?;
+        let response_data = response_data(response, false).await?;
 
        Ok(quick_xml::de::from_str::<GetObjectAttributesOutput>(
-            response.as_str()?,
+            response_data.as_str()?,
        )?)
     }
 
@@ -1023,8 +924,8 @@ impl Bucket {
     #[maybe_async::maybe_async]
     pub async fn object_exists<S: AsRef<str>>(&self, path: S) -> Result<bool, S3Error> {
         let command = Command::HeadObject;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        let response_data = match request.response_data(false).await {
+        let response = self.make_request(path.as_ref(), command).await?;
+        let response_data = match response_data(response, false).await {
             Ok(response_data) => response_data,
             Err(S3Error::HttpFailWithBody(status_code, error)) => {
                 if status_code == 404 {
@@ -1047,8 +948,8 @@ impl Bucket {
             expected_bucket_owner: expected_bucket_owner.to_string(),
             configuration: cors_config.clone(),
         };
-        let request = RequestImpl::new(self, "", command).await?;
-        request.response_data(false).await
+        let response = self.make_request("", command).await?;
+        response_data(response, false).await
     }
 
     #[maybe_async::maybe_async]
@@ -1059,10 +960,10 @@ impl Bucket {
         let command = Command::GetBucketCors {
             expected_bucket_owner: expected_bucket_owner.to_string(),
         };
-        let request = RequestImpl::new(self, "", command).await?;
-        let response = request.response_data(false).await?;
+        let response = self.make_request("", command).await?;
+        let response_data = response_data(response, false).await?;
        Ok(quick_xml::de::from_str::<CorsConfiguration>(
-            response.as_str()?,
+            response_data.as_str()?,
        )?)
     }
 
@@ -1074,16 +975,16 @@ impl Bucket {
         let command = Command::DeleteBucketCors {
             expected_bucket_owner: expected_bucket_owner.to_string(),
         };
-        let request = RequestImpl::new(self, "", command).await?;
-        request.response_data(false).await
+        let response = self.make_request("", command).await?;
+        response_data(response, false).await
     }
 
     #[maybe_async::maybe_async]
     pub async fn get_bucket_lifecycle(&self) -> Result<BucketLifecycleConfiguration, S3Error> {
-        let request = RequestImpl::new(self, "", Command::GetBucketLifecycle).await?;
-        let response = request.response_data(false).await?;
+        let response = self.make_request("", Command::GetBucketLifecycle).await?;
+        let response_data = response_data(response, false).await?;
        Ok(quick_xml::de::from_str::<BucketLifecycleConfiguration>(
-            response.as_str()?,
+            response_data.as_str()?,
        )?)
     }
 
@@ -1095,14 +996,16 @@ impl Bucket {
         let command = Command::PutBucketLifecycle {
             configuration: lifecycle_config,
         };
-        let request = RequestImpl::new(self, "", command).await?;
-        request.response_data(false).await
+        let response = self.make_request("", command).await?;
+        response_data(response, false).await
     }
 
     #[maybe_async::maybe_async]
     pub async fn delete_bucket_lifecycle(&self) -> Result<ResponseData, S3Error> {
-        let request = RequestImpl::new(self, "", Command::DeleteBucketLifecycle).await?;
-        request.response_data(false).await
+        let response = self
+            .make_request("", Command::DeleteBucketLifecycle)
+            .await?;
+        response_data(response, false).await
     }
 
     /// Gets torrent from an S3 path.
@@ -1142,8 +1045,8 @@ impl Bucket {
         path: S,
     ) -> Result<ResponseData, S3Error> {
         let command = Command::GetObjectTorrent;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(false).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, false).await
     }
 
     /// Gets specified inclusive byte range of file from an S3 path.
@@ -1190,8 +1093,8 @@ impl Bucket {
         }
 
         let command = Command::GetObjectRange { start, end };
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(false).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, false).await
     }
 
     /// Stream range of bytes from S3 path to a local file, generic over T: Write.
@@ -1251,8 +1154,8 @@ impl Bucket {
         }
 
         let command = Command::GetObjectRange { start, end };
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data_to_writer(writer).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data_to_writer(response, writer).await
     }
 
     #[maybe_async::sync_impl]
@@ -1268,8 +1171,8 @@ impl Bucket {
         }
 
         let command = Command::GetObjectRange { start, end };
-        let request = RequestImpl::new(self, path.as_ref(), command)?;
-        request.response_data_to_writer(writer)
+        let response = self.make_request(path.as_ref(), command)?;
+        response_data_to_writer(response, writer)
     }
 
     /// Stream file from S3 path to a local file, generic over T: Write.
@@ -1316,8 +1219,8 @@ impl Bucket {
         writer: &mut T,
     ) -> Result<u16, S3Error> {
         let command = Command::GetObject;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data_to_writer(writer).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data_to_writer(response, writer).await
     }
 
     #[maybe_async::sync_impl]
@@ -1327,8 +1230,8 @@ impl Bucket {
         writer: &mut T,
     ) -> Result<u16, S3Error> {
         let command = Command::GetObject;
-        let request = RequestImpl::new(self, path.as_ref(), command)?;
-        request.response_data_to_writer(writer)
+        let response = self.make_request(path.as_ref(), command)?;
+        response_data_to_writer(response, writer)
     }
 
     /// Stream file from S3 path to a local file using an async stream.
@@ -1377,9 +1280,19 @@ impl Bucket {
         &self,
         path: S,
     ) -> Result<ResponseDataStream, S3Error> {
+        use http_body_util::BodyExt;
         let command = Command::GetObject;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data_to_stream().await
+        let response = self.make_request(path.as_ref(), command).await?;
+        let status_code = response.status();
+        let stream = response.into_body().into_data_stream().map(|i| match i {
+            Ok(data) => Ok(data.into()),
+            Err(e) => Err(e.into()),
+        });
+
+        Ok(ResponseDataStream {
+            bytes: Box::pin(stream),
+            status_code: status_code.as_u16(),
+        })
     }
 
     /// Stream file from local path to s3, generic over T: Write.
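The stream handed back by `get_object_stream` is now built straight from `http_body_util`'s `into_data_stream`. A sketch of draining it to a file under the `with-tokio` feature — the `s3::request::ResponseDataStream` path, public fields, and `Result<Bytes, S3Error>` item type are assumed from the construction shown above:

```rust
use futures_util::StreamExt;
use tokio::io::AsyncWriteExt;

async fn drain(
    mut body: s3::request::ResponseDataStream,
    out: &mut tokio::fs::File,
) -> Result<(), s3::error::S3Error> {
    // Each chunk is a Result<Bytes, S3Error>, matching the map above.
    while let Some(chunk) = body.bytes.next().await {
        out.write_all(&chunk?).await?;
    }
    Ok(())
}
```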
@@ -1478,7 +1391,7 @@ impl Bucket {
     pub fn put_object_stream_builder<S: AsRef<str>>(
         &self,
         path: S,
-    ) -> crate::put_object_request::PutObjectStreamRequest<'_> {
+    ) -> crate::put_object_request::PutObjectStreamRequest<'_, B> {
         crate::put_object_request::PutObjectStreamRequest::new(self, path)
     }
 
@@ -1581,8 +1494,8 @@ impl Bucket {
             custom_headers: None,
             content_type,
         };
-        let request = RequestImpl::new(self, path, command).await?;
-        request.response_data(true).await
+        let response = self.make_request(path, command).await?;
+        response_data(response, true).await
     }
 
     #[maybe_async::async_impl]
@@ -1674,7 +1587,7 @@ impl Bucket {
 
         // Use FuturesUnordered for bounded parallelism
         use futures_util::FutureExt;
-        use futures_util::stream::{FuturesUnordered, StreamExt};
+        use futures_util::stream::FuturesUnordered;
 
         let mut part_number: u32 = 0;
         let mut total_size = 0;
@@ -1861,14 +1774,13 @@ impl Bucket {
         content_type: &str,
     ) -> Result<InitiateMultipartUploadResponse, S3Error> {
         let command = Command::InitiateMultipartUpload { content_type };
-        let request = RequestImpl::new(self, s3_path, command).await?;
-        let response_data = request.response_data(false).await?;
-        if response_data.status_code() >= 300 {
-            return Err(error_from_response_data(response_data)?);
+        let response = self.make_request(s3_path, command).await?;
+        let data = response_data(response, false).await?;
+        if data.status_code() >= 300 {
+            return Err(error_from_response_data(data)?);
         }
 
-        let msg: InitiateMultipartUploadResponse =
-            quick_xml::de::from_str(response_data.as_str()?)?;
+        let msg: InitiateMultipartUploadResponse = quick_xml::de::from_str(data.as_str()?)?;
         Ok(msg)
     }
 
@@ -1879,14 +1791,13 @@ impl Bucket {
         content_type: &str,
     ) -> Result<InitiateMultipartUploadResponse, S3Error> {
         let command = Command::InitiateMultipartUpload { content_type };
-        let request = RequestImpl::new(self, s3_path, command)?;
-        let response_data = request.response_data(false)?;
-        if response_data.status_code() >= 300 {
-            return Err(error_from_response_data(response_data)?);
+        let response = self.make_request(s3_path, command)?;
+        let data = response_data(response, false)?;
+        if data.status_code() >= 300 {
+            return Err(error_from_response_data(data)?);
         }
 
-        let msg: InitiateMultipartUploadResponse =
-            quick_xml::de::from_str(response_data.as_str()?)?;
+        let msg: InitiateMultipartUploadResponse = quick_xml::de::from_str(data.as_str()?)?;
         Ok(msg)
     }
 
@@ -1935,8 +1846,8 @@ impl Bucket {
             custom_headers: None,
             content_type,
         };
-        let request = RequestImpl::new(self, path, command).await?;
-        let response_data = request.response_data(true).await?;
+        let response = self.make_request(path, command).await?;
+        let response_data = response_data(response, true).await?;
         if !(200..300).contains(&response_data.status_code()) {
             // if chunk upload failed - abort the upload
             match self.abort_upload(path, upload_id).await {
@@ -1971,20 +1882,20 @@ impl Bucket {
             custom_headers: None,
             content_type,
         };
-        let request = RequestImpl::new(self, path, command)?;
-        let response_data = request.response_data(true)?;
-        if !(200..300).contains(&response_data.status_code()) {
+        let response = self.make_request(path, command)?;
+        let data = response_data(response, true)?;
+        if !(200..300).contains(&data.status_code()) {
             // if chunk upload failed - abort the upload
             match self.abort_upload(path, upload_id) {
                 Ok(_) => {
-                    return Err(error_from_response_data(response_data)?);
+                    return Err(error_from_response_data(data)?);
                 }
                 Err(error) => {
                     return Err(error);
                 }
             }
         }
-        let etag = response_data.as_str()?;
+        let etag = data.as_str()?;
         Ok(Part {
             etag: etag.to_string(),
             part_number,
@@ -2001,8 +1912,8 @@ impl Bucket {
     ) -> Result<ResponseData, S3Error> {
         let data = CompleteMultipartUploadData { parts };
         let complete = Command::CompleteMultipartUpload { upload_id, data };
-        let complete_request = RequestImpl::new(self, path, complete).await?;
-        complete_request.response_data(false).await
+        let complete_response = self.make_request(path, complete).await?;
+        response_data(complete_response, false).await
     }
 
     #[maybe_async::sync_impl]
@@ -2014,8 +1925,8 @@ impl Bucket {
     ) -> Result<ResponseData, S3Error> {
         let data = CompleteMultipartUploadData { parts };
         let complete = Command::CompleteMultipartUpload { upload_id, data };
-        let complete_request = RequestImpl::new(self, path, complete)?;
-        complete_request.response_data(false)
+        let complete_response = self.make_request(path, complete)?;
+        response_data(complete_response, false)
     }
 
     /// Get Bucket location.
@@ -2052,8 +1963,10 @@ impl Bucket {
     /// ```
     #[maybe_async::maybe_async]
     pub async fn location(&self) -> Result<(Region, u16), S3Error> {
-        let request = RequestImpl::new(self, "?location", Command::GetBucketLocation).await?;
-        let response_data = request.response_data(false).await?;
+        let response = self
+            .make_request("?location", Command::GetBucketLocation)
+            .await?;
+        let response_data = response_data(response, false).await?;
         let region_string = String::from_utf8_lossy(response_data.as_slice());
         let region = match quick_xml::de::from_reader(region_string.as_bytes()) {
             Ok(r) => {
@@ -2112,8 +2025,8 @@ impl Bucket {
     #[maybe_async::maybe_async]
     pub async fn delete_object<S: AsRef<str>>(&self, path: S) -> Result<ResponseData, S3Error> {
         let command = Command::DeleteObject;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(false).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, false).await
     }
 
     /// Head object from S3.
@@ -2154,10 +2067,10 @@ impl Bucket {
         path: S,
     ) -> Result<(HeadObjectResult, u16), S3Error> {
         let command = Command::HeadObject;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        let (headers, status) = request.response_header().await?;
-        let header_object = HeadObjectResult::from(&headers);
-        Ok((header_object, status))
+        let response = self.make_request(path.as_ref(), command).await?;
+        let (head, _) = response.into_parts();
+        let header_object = HeadObjectResult::from(&head.headers);
+        Ok((header_object, head.status.as_u16()))
     }
 
     /// Put into an S3 bucket, with explicit content-type.
@@ -2206,8 +2119,8 @@ impl Bucket {
             custom_headers: None,
             multipart: None,
         };
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(true).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, true).await
     }
 
     /// Put into an S3 bucket, with explicit content-type and custom headers for the request.
@@ -2266,8 +2179,8 @@ impl Bucket {
             custom_headers,
             multipart: None,
         };
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(true).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, true).await
     }
 
     /// Put into an S3 bucket, with custom headers for the request.
@@ -2395,7 +2308,7 @@ impl Bucket {
         &self,
         path: S,
         content: &[u8],
-    ) -> crate::put_object_request::PutObjectRequest<'_> {
+    ) -> crate::put_object_request::PutObjectRequest<'_, B> {
         crate::put_object_request::PutObjectRequest::new(self, path, content)
     }
 
@@ -2460,8 +2373,8 @@ impl Bucket {
     ) -> Result<ResponseData, S3Error> {
         let content = self._tags_xml(tags);
         let command = Command::PutObjectTagging { tags: &content };
-        let request = RequestImpl::new(self, path, command).await?;
-        request.response_data(false).await
+        let response = self.make_request(path, command).await?;
+        response_data(response, false).await
     }
 
     /// Delete tags from an S3 object.
@@ -2502,8 +2415,8 @@ impl Bucket {
         path: S,
     ) -> Result<ResponseData, S3Error> {
         let command = Command::DeleteObjectTagging;
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        request.response_data(false).await
+        let response = self.make_request(path.as_ref(), command).await?;
+        response_data(response, false).await
     }
 
     /// Retrieve an S3 object list of tags.
@@ -2545,8 +2458,8 @@ impl Bucket {
         path: S,
     ) -> Result<(Vec<Tag>, u16), S3Error> {
         let command = Command::GetObjectTagging {};
-        let request = RequestImpl::new(self, path.as_ref(), command).await?;
-        let result = request.response_data(false).await?;
+        let response = self.make_request(path.as_ref(), command).await?;
+        let result = response_data(response, false).await?;
 
         let mut tags = Vec::new();
 
@@ -2618,8 +2531,8 @@ impl Bucket {
                 max_keys,
             }
         };
-        let request = RequestImpl::new(self, "/", command).await?;
-        let response_data = request.response_data(false).await?;
+        let response = self.make_request("/", command).await?;
+        let response_data = response_data(response, false).await?;
         let list_bucket_result = quick_xml::de::from_reader(response_data.as_slice())?;
 
         Ok((list_bucket_result, response_data.status_code()))
@@ -2702,8 +2615,8 @@ impl Bucket {
             key_marker,
             max_uploads,
         };
-        let request = RequestImpl::new(self, "/", command).await?;
-        let response_data = request.response_data(false).await?;
+        let response = self.make_request("/", command).await?;
+        let response_data = response_data(response, false).await?;
         let list_bucket_result = quick_xml::de::from_reader(response_data.as_slice())?;
 
         Ok((list_bucket_result, response_data.status_code()))
@@ -2806,8 +2719,8 @@ impl Bucket {
     #[maybe_async::maybe_async]
     pub async fn abort_upload(&self, key: &str, upload_id: &str) -> Result<(), S3Error> {
         let abort = Command::AbortMultipartUpload { upload_id };
-        let abort_request = RequestImpl::new(self, key, abort).await?;
-        let response_data = abort_request.response_data(false).await?;
+        let abort_response = self.make_request(key, abort).await?;
+        let response_data = response_data(abort_response, false).await?;
 
         if (200..300).contains(&response_data.status_code()) {
             Ok(())
@@ -2819,7 +2732,14 @@ impl Bucket {
             ))
         }
     }
+}
 
+#[cfg_attr(all(feature = "with-tokio", feature = "blocking"), block_on("tokio"))]
+#[cfg_attr(
+    all(feature = "with-async-std", feature = "blocking"),
+    block_on("async-std")
+)]
+impl<B> Bucket<B> {
     /// Get path_style field of the Bucket struct
     pub fn is_path_style(&self) -> bool {
         self.path_style
@@ -2840,16 +2760,6 @@ impl Bucket {
         self.path_style = false;
     }
 
-    /// Configure bucket to apply this request timeout to all HTTP
-    /// requests, or no (infinity) timeout if `None`. Defaults to
-    /// 30 seconds.
-    ///
-    /// Only the [`attohttpc`] and the [`hyper`] backends obey this option;
-    /// async code may instead await with a timeout.
-    pub fn set_request_timeout(&mut self, timeout: Option<Duration>) {
-        self.request_timeout = timeout;
-    }
-
     /// Configure bucket to use the older ListObjects API
     ///
     /// If your provider doesn't support the ListObjectsV2 interface, set this to
@@ -3000,22 +2910,21 @@ impl Bucket {
     pub fn extra_query_mut(&mut self) -> &mut Query {
         &mut self.extra_query
     }
-
-    pub fn request_timeout(&self) -> Option<Duration> {
-        self.request_timeout
-    }
 }
 
 #[cfg(test)]
 mod test {
-
+    use super::DefaultBackend;
     use crate::BucketConfiguration;
     use crate::Tag;
     use crate::creds::Credentials;
     use crate::post_policy::{PostPolicyField, PostPolicyValue};
     use crate::region::Region;
+    use crate::request::ResponseBody;
+    use crate::request::backend::Backend;
     use crate::serde_types::CorsConfiguration;
     use crate::serde_types::CorsRule;
+    use crate::utils::testing::AlwaysFailBackend;
     use crate::{Bucket, PostPolicy};
     use http::header::{HeaderMap, HeaderName, HeaderValue, CACHE_CONTROL};
     use std::env;
@@ -3090,7 +2999,7 @@ mod test {
         .unwrap()
     }
 
-    fn test_aws_bucket() -> Box<Bucket> {
+    fn test_aws_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3-test",
             "eu-central-1".parse().unwrap(),
@@ -3099,7 +3008,7 @@ mod test {
         .unwrap()
     }
 
-    fn test_wasabi_bucket() -> Box<Bucket> {
+    fn test_wasabi_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3",
             "wa-eu-central-1".parse().unwrap(),
@@ -3108,7 +3017,7 @@ mod test {
         .unwrap()
     }
 
-    fn test_gc_bucket() -> Box<Bucket> {
+    fn test_gc_bucket() -> Box<Bucket<DefaultBackend>> {
         let mut bucket = Bucket::new(
             "rust-s3",
             Region::Custom {
@@ -3122,7 +3031,7 @@ mod test {
         bucket
     }
 
-    fn test_minio_bucket() -> Box<Bucket> {
+    fn test_minio_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3",
             Region::Custom {
@@ -3136,11 +3045,11 @@ mod test {
     }
 
     #[allow(dead_code)]
-    fn test_digital_ocean_bucket() -> Box<Bucket> {
+    fn test_digital_ocean_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new("rust-s3", Region::DoFra1, test_digital_ocean_credentials()).unwrap()
     }
 
-    fn test_r2_bucket() -> Box<Bucket> {
+    fn test_r2_bucket() -> Box<Bucket<DefaultBackend>> {
         Bucket::new(
             "rust-s3",
             Region::R2 {
@@ -3156,7 +3065,11 @@ mod test {
     }
 
     #[maybe_async::maybe_async]
-    async fn put_head_get_delete_object(bucket: Bucket, head: bool) {
+    async fn put_head_get_delete_object<B, RB>(bucket: Bucket<B>, head: bool)
+    where
+        B: Backend<Response = http::Response<RB>>,
+        RB: ResponseBody,
+    {
         let s3_path = "/+test.file";
         let non_existant_path = "/+non_existant.file";
         let test: Vec<u8> = object(3072);
@@ -3197,7 +3110,11 @@ mod test {
     }
 
     #[maybe_async::maybe_async]
-    async fn put_head_delete_object_with_headers(bucket: Bucket) {
+    async fn put_head_delete_object_with_headers<B, RB>(bucket: Bucket<B>)
+    where
+        B: Backend<Response = http::Response<RB>>,
+        RB: ResponseBody,
+    {
         let s3_path = "/+test.file";
         let non_existant_path = "/+non_existant.file";
         let test: Vec<u8> = object(3072);
@@ -3382,7 +3299,11 @@ mod test {
     // Test multi-part upload
     #[maybe_async::maybe_async]
-    async fn streaming_test_put_get_delete_big_object(bucket: Bucket) {
+    async fn streaming_test_put_get_delete_big_object<B, RB>(bucket: Bucket<B>)
+    where
+        B: Backend<Response = http::Response<RB>>,
+        RB: ResponseBody,
+    {
         #[cfg(feature = "with-async-std")]
         use async_std::fs::File;
         #[cfg(feature = "with-async-std")]
@@ -3506,7 +3427,7 @@ mod test {
     }
 
     #[maybe_async::maybe_async]
-    async fn streaming_test_put_get_delete_small_object(bucket: Box<Bucket>) {
+    async fn streaming_test_put_get_delete_small_object(bucket: Box<Bucket<DefaultBackend>>) {
         init();
         let remote_path = "+stream_test_small";
        let content: Vec<u8> = object(1000);
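With `set_request_timeout`/`request_timeout` gone from `Bucket`, timeouts become backend state. A sketch of the replacement pattern, mirroring `test_builder_composition` below (the `s3::request::backend::DefaultBackend` path and the exact `with_request_timeout` signature are assumptions):

```rust
use std::time::Duration;

fn with_timeout(
    bucket: &s3::Bucket<s3::request::backend::DefaultBackend>,
) -> Result<s3::Bucket<s3::request::backend::DefaultBackend>, s3::error::S3Error> {
    // Reconfigure the backend, then swap it into a new bucket value.
    let backend = bucket
        .backend()
        .with_request_timeout(Some(Duration::from_secs(10)))?;
    Ok(bucket.with_backend(backend))
}
```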
 
@@ -4020,6 +3941,7 @@ mod test {
     #[test]
     #[ignore]
+    #[cfg(any(feature = "with-tokio", feature = "sync"))]
     fn test_builder_composition() {
         use std::time::Duration;
 
@@ -4028,11 +3950,18 @@ mod test {
             "eu-central-1".parse().unwrap(),
             test_aws_credentials(),
         )
-        .unwrap()
-        .with_request_timeout(Duration::from_secs(10))
         .unwrap();
+        let bucket = bucket.with_backend(
+            bucket
+                .backend()
+                .with_request_timeout(Some(Duration::from_secs(10)))
+                .unwrap(),
+        );
 
-        assert_eq!(bucket.request_timeout(), Some(Duration::from_secs(10)));
+        assert_eq!(
+            bucket.backend().request_timeout(),
+            Some(Duration::from_secs(10))
+        );
     }
 
     #[maybe_async::test(
@@ -4102,7 +4031,8 @@ mod test {
             .with_path_style();
 
         // Set dangerous config (allow invalid certs, allow invalid hostnames)
-        let bucket = bucket.set_dangereous_config(true, true).unwrap();
+        let bucket =
+            bucket.with_backend(bucket.backend().with_dangereous_config(true, true).unwrap());
 
         // Test that exists() works with the dangerous config
         // This should not panic or fail due to SSL certificate issues
@@ -4156,4 +4086,26 @@ mod test {
         let exists = exists_result.unwrap();
         assert!(exists, "Test bucket should exist");
     }
+
+    #[maybe_async::test(
+        feature = "sync",
+        async(all(not(feature = "sync"), feature = "with-tokio"), tokio::test),
+        async(
+            all(not(feature = "sync"), feature = "with-async-std"),
+            async_std::test
+        )
+    )]
+    async fn test_always_fail_backend() {
+        init();
+
+        let credentials = Credentials::anonymous().unwrap();
+        let region = "eu-central-1".parse().unwrap();
+        let bucket_name = "rust-s3-test";
+
+        let bucket = Bucket::new(bucket_name, region, credentials)
+            .unwrap()
+            .with_backend(AlwaysFailBackend);
+        let response_data = bucket.get_object("foo").await.unwrap();
+        assert_eq!(response_data.status_code(), 418);
+    }
 }
diff --git a/s3/src/bucket_ops.rs b/s3/src/bucket_ops.rs
index 4ea0f76d9c..06131c6798 100644
--- a/s3/src/bucket_ops.rs
+++ b/s3/src/bucket_ops.rs
@@ -231,13 +231,13 @@ impl BucketConfiguration {
 }
 
 #[allow(dead_code)]
-pub struct CreateBucketResponse {
-    pub bucket: Box<Bucket>,
+pub struct CreateBucketResponse<B> {
+    pub bucket: Box<Bucket<B>>,
     pub response_text: String,
     pub response_code: u16,
 }
 
-impl CreateBucketResponse {
+impl<B> CreateBucketResponse<B> {
     pub fn success(&self) -> bool {
         self.response_code == 200
     }
diff --git a/s3/src/error.rs b/s3/src/error.rs
index a3191579a7..6e9f8ff507 100644
--- a/s3/src/error.rs
+++ b/s3/src/error.rs
@@ -21,7 +21,6 @@ pub enum S3Error {
     UrlParse(#[from] url::ParseError),
     #[error("io: {0}")]
     Io(#[from] std::io::Error),
-    #[cfg(feature = "with-tokio")]
     #[error("http: {0}")]
     Http(#[from] http::Error),
     #[cfg(feature = "with-tokio")]
@@ -45,6 +44,8 @@ pub enum S3Error {
     #[cfg(feature = "with-async-std")]
     #[error("surf: {0}")]
     Surf(String),
+    #[error("{0}")]
+    InvalidStatusCode(#[from] http::status::InvalidStatusCode),
     #[cfg(feature = "sync")]
     #[error("attohttpc: {0}")]
     Atto(#[from] attohttpc::Error),
@@ -73,3 +74,9 @@ pub enum S3Error {
     #[error("xml serialization error: {0}")]
     XmlSeError(#[from] quick_xml::SeError),
 }
+
+impl From<std::convert::Infallible> for S3Error {
+    fn from(_: std::convert::Infallible) -> Self {
+        unreachable!();
+    }
+}
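The `From<Infallible>` impl lets generic plumbing funnel any backend error into `S3Error` with `map_err(Into::into)`, including backends whose error type can never actually occur:

```rust
// Compiles only because `S3Error: From<Infallible>` is now provided.
fn demo(res: Result<u16, std::convert::Infallible>) -> Result<u16, s3::error::S3Error> {
    res.map_err(Into::into)
}
```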
diff --git a/s3/src/post_policy.rs b/s3/src/post_policy.rs
index dd3ef0bc50..616fd304fe 100644
--- a/s3/src/post_policy.rs
+++ b/s3/src/post_policy.rs
@@ -58,10 +58,10 @@ impl<'a> PostPolicy<'a> {
 
     /// Build a finalized post policy with credentials
     #[maybe_async::maybe_async]
-    async fn build(
+    async fn build<B>(
         &self,
         now: &OffsetDateTime,
-        bucket: &Bucket,
+        bucket: &Bucket<B>,
     ) -> Result<PostPolicy<'a>, S3Error> {
         let access_key = bucket.access_key().await?.ok_or(S3Error::Credentials(
             CredentialsError::ConfigMissingAccessKeyId,
         ))?;
@@ -110,7 +110,7 @@ impl<'a> PostPolicy<'a> {
     }
 
     #[maybe_async::maybe_async]
-    pub async fn sign(&self, bucket: Box<Bucket>) -> Result<PresignedPost, S3Error> {
+    pub async fn sign<B>(&self, bucket: Box<Bucket<B>>) -> Result<PresignedPost, S3Error> {
         use hmac::Mac;
 
         bucket.credentials_refresh().await?;
@@ -457,11 +457,12 @@ mod test {
 
     use crate::creds::Credentials;
     use crate::region::Region;
+    use crate::utils::testing::AlwaysFailBackend;
     use crate::utils::with_timestamp;
 
     use serde_json::json;
 
-    fn test_bucket() -> Box<Bucket> {
+    fn test_bucket() -> Bucket<AlwaysFailBackend> {
         Bucket::new(
             "rust-s3",
             Region::UsEast1,
             Credentials::new(
                 Some("AKIAIOSFODNN7EXAMPLE"),
                 Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"),
                 None,
                 None,
                 None,
             )
             .unwrap(),
         )
         .unwrap()
+        .with_backend(AlwaysFailBackend)
     }
 
-    fn test_bucket_with_security_token() -> Box<Bucket> {
+    fn test_bucket_with_security_token() -> Bucket<AlwaysFailBackend> {
         Bucket::new(
             "rust-s3",
             Region::UsEast1,
             Credentials::new(
                 Some("AKIAIOSFODNN7EXAMPLE"),
                 Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"),
                 None,
                 Some("security_token"),
                 None,
             )
             .unwrap(),
         )
         .unwrap()
+        .with_backend(AlwaysFailBackend)
     }
 
     mod conditions {
@@ -769,7 +772,7 @@ mod test {
         let bucket = test_bucket();
 
         let _ts = with_timestamp(1_451_347_200);
-        let post = policy.sign(bucket).await.unwrap();
+        let post = policy.sign(Box::new(bucket)).await.unwrap();
 
         assert_eq!(post.url, "https://rust-s3.s3.amazonaws.com");
         assert_eq!(
diff --git a/s3/src/put_object_request.rs b/s3/src/put_object_request.rs index 1b48227661..ac1411da63 100644 --- a/s3/src/put_object_request.rs +++ b/s3/src/put_object_request.rs @@ -4,7 +4,8 @@ //! various options including custom headers, content type, and other metadata. use crate::error::S3Error; -use crate::request::{Request as _, ResponseData}; +use crate::request::backend::Backend; +use crate::request::{ResponseBody, ResponseData, response_data}; use crate::{Bucket, command::Command}; use http::{HeaderMap, HeaderName, HeaderValue}; @@ -14,13 +15,6 @@ use tokio::io::AsyncRead; #[cfg(feature = "with-async-std")] use async_std::io::Read as AsyncRead; -#[cfg(feature = "with-async-std")] -use crate::request::async_std_backend::SurfRequest as RequestImpl; -#[cfg(feature = "sync")] -use crate::request::blocking::AttoRequest as RequestImpl; -#[cfg(feature = "with-tokio")] -use crate::request::tokio_backend::ReqwestRequest as RequestImpl; - /// Builder for constructing S3 PUT object requests with custom options /// /// # Example /// @@ -44,17 +38,17 @@ use crate::request::tokio_backend::ReqwestRequest as RequestImpl; /// # } /// ``` #[derive(Debug, Clone)] -pub struct PutObjectRequest<'a> { - bucket: &'a Bucket, +pub struct PutObjectRequest<'a, B> { + bucket: &'a Bucket<B>, path: String, content: Vec<u8>, content_type: String, custom_headers: HeaderMap, } -impl<'a> PutObjectRequest<'a> { +impl<'a, B: Backend<http::Response<RB>>, RB: ResponseBody> PutObjectRequest<'a, B> { /// Create a new PUT object request builder - pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket, path: S, content: &[u8]) -> Self { + pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket<B>, path: S, content: &[u8]) -> Self { Self { bucket, path: path.as_ref().to_string(), @@ -195,25 +189,27 @@ impl<'a> PutObjectRequest<'a> { multipart: None, }; - let request = RequestImpl::new(self.bucket, &self.path, command).await?; - request.response_data(true).await + let response = self.bucket.make_request(&self.path, command).await?; + response_data(response, true).await } } /// Builder for streaming PUT operations #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] #[derive(Debug, Clone)] -pub struct PutObjectStreamRequest<'a> { - bucket: &'a Bucket, +pub struct PutObjectStreamRequest<'a, B> { + bucket: &'a Bucket<B>, path: String, content_type: String, custom_headers: HeaderMap, } #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] -impl<'a> PutObjectStreamRequest<'a> { +impl<'a, B: Backend<http::Response<RB>>, RB: ResponseBody> + PutObjectStreamRequest<'a, B> +{ /// Create a new streaming PUT request builder - pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket, path: S) -> Self { + pub(crate) fn new<S: AsRef<str>>(bucket: &'a Bucket<B>, path: S) -> Self { Self { bucket, path: path.as_ref().to_string(), diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index c0a345d93c..ed25a94d08 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -1,290 +1,135 @@ -use async_std::io::Write as AsyncWrite; -use async_std::io::{ReadExt, WriteExt}; -use async_std::stream::StreamExt; use bytes::Bytes; -use futures_util::FutureExt; -use std::collections::HashMap; +use futures_util::AsyncBufRead as _; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tower_service::Service; -use crate::bucket::Bucket; -use crate::command::Command; use crate::error::S3Error; -use crate::utils::now_utc; -use time::OffsetDateTime; -use crate::command::HttpMethod; -use crate::request::{Request, ResponseData, ResponseDataStream}; +use crate::request::backend::BackendRequestBody; -use http::HeaderMap; -use maybe_async::maybe_async; +use http_body::Frame; use surf::http::Method; use surf::http::headers::{HeaderName, HeaderValue}; -// Temporary structure for making a request -pub struct SurfRequest<'a> { - pub bucket: &'a Bucket, - pub path: &'a str, - pub command: Command<'a>, - pub datetime: OffsetDateTime, - pub sync: bool, -} - -#[maybe_async] -impl<'a> Request for SurfRequest<'a> { - type Response = surf::Response; - type HeaderMap = HeaderMap; - - fn datetime(&self) -> OffsetDateTime { - self.datetime - } - - fn bucket(&self) -> Bucket { - self.bucket.clone() - } - - fn command(&self) -> Command<'_> { - self.command.clone() +fn http_request_to_surf_request( + request: http::Request<BackendRequestBody<'_>>, +) -> Result<surf::RequestBuilder, S3Error> { + let url = format!("{}", request.uri()).parse()?; + let mut builder = match *request.method() { + http::Method::GET => surf::Request::builder(Method::Get, url), + http::Method::DELETE => surf::Request::builder(Method::Delete, url), + http::Method::PUT => surf::Request::builder(Method::Put, url), + http::Method::POST => surf::Request::builder(Method::Post, url), + http::Method::HEAD => surf::Request::builder(Method::Head, url), + ref m => surf::Request::builder( + m.as_str() + .parse() + .map_err(|e: surf::Error| S3Error::Surf(e.to_string()))?, + url, + ), } - - fn path(&self) -> String { - self.path.to_string() + .body(request.body().clone().into_owned()); + + for (name, value) in request.headers().iter() { + builder = builder.header( + HeaderName::from_bytes(AsRef::<[u8]>::as_ref(&name).to_vec()) + .expect("Could not parse header name"), + HeaderValue::from_bytes(AsRef::<[u8]>::as_ref(&value).to_vec()) + .expect("Could not parse header value"), + ); } - async fn response(&self) -> Result<Self::Response, S3Error> { - // Build headers - let headers = self.headers().await?; - - let request = match self.command.http_verb() { - HttpMethod::Get => surf::Request::builder(Method::Get, self.url()?), - HttpMethod::Delete => surf::Request::builder(Method::Delete, self.url()?), - HttpMethod::Put => surf::Request::builder(Method::Put, self.url()?), - HttpMethod::Post => surf::Request::builder(Method::Post, self.url()?), - HttpMethod::Head => surf::Request::builder(Method::Head, self.url()?), - }; - - let mut request = 
request.body(self.request_body()?); - - for (name, value) in headers.iter() { - request = request.header( - HeaderName::from_bytes(AsRef::<[u8]>::as_ref(&name).to_vec()) - .expect("Could not parse heaeder name"), - HeaderValue::from_bytes(AsRef::<[u8]>::as_ref(&value).to_vec()) - .expect("Could not parse header value"), - ); - } - - let response = request - .send() - .await - .map_err(|e| S3Error::Surf(e.to_string()))?; - - if cfg!(feature = "fail-on-err") && !response.status().is_success() { - return Err(S3Error::HttpFail); - } - - Ok(response) - } - - async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error> { - let mut response = crate::retry! {self.response().await}?; - let status_code = response.status(); - - let response_headers = response - .header_names() - .zip(response.header_values()) - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect::<HashMap<String, String>>(); + Ok(builder) +} - // When etag=true, we extract the ETag header and return it as the body. - // This is used for PUT operations (regular puts, multipart chunks) where: - // 1. S3 returns an empty or non-useful response body - // 2. The ETag header contains the essential information we need - // 3. The calling code expects to get the ETag via response_data.as_str() - // - // Note: This approach means we discard any actual response body when etag=true, - // but for the operations that use this (PUTs), the body is typically empty - // or contains redundant information already available in headers. - // - // TODO: Refactor this to properly return the response body and access ETag - // from headers instead of replacing the body. This would be a breaking change. - let body_vec = if etag { - if let Some(etag) = response.header("ETag") { - Bytes::from(etag.as_str().to_string()) - } else { - Bytes::from("") +pub struct SurfBody(surf::Body); + +impl http_body::Body for SurfBody { + type Data = Bytes; + type Error = std::io::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Frame<Bytes>, std::io::Error>>> { + let mut inner = Pin::new(&mut self.0); + match inner.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(sliceu8)) => { + if sliceu8.is_empty() { + Poll::Ready(None) + } else { + let len = sliceu8.len(); + let frame = Frame::data(Bytes::copy_from_slice(sliceu8)); + inner.as_mut().consume(len); + Poll::Ready(Some(Ok(frame))) + } } - } else { - let body = match response.body_bytes().await { - Ok(bytes) => Ok(Bytes::from(bytes)), - Err(e) => Err(S3Error::Surf(e.to_string())), - }; - body? - }; - Ok(ResponseData::new( - body_vec, - status_code.into(), - response_headers, - )) + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } } +} - async fn response_data_to_writer<T: AsyncWrite + Send + Unpin + ?Sized>( - &self, - writer: &mut T, - ) -> Result<u16, S3Error> { - let mut buffer = Vec::new(); - - let response = crate::retry! {self.response().await}?; - - let status_code = response.status(); - - let mut stream = surf::http::Body::from_reader(response, None); - - stream.read_to_end(&mut buffer).await?; - - writer.write_all(&buffer).await?; +impl Service<http::Request<BackendRequestBody<'_>>> for SurfBackend { + type Response = http::Response<SurfBody>; + type Error = S3Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; - Ok(status_code.into()) + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) } - async fn response_header(&self) -> Result<(HeaderMap, u16), S3Error> { - let mut header_map = HeaderMap::new(); - let response = crate::retry! {self.response().await}?; - let status_code = response.status(); - - for (name, value) in response.iter() { - header_map.insert( - http::header::HeaderName::from_lowercase( - name.to_string().to_ascii_lowercase().as_ref(), - )?, - value.as_str().parse()?, - ); + fn call(&mut self, request: http::Request<BackendRequestBody<'_>>) -> Self::Future { + match http_request_to_surf_request(request) { + Ok(request) => { + let fut = request.send(); + Box::pin(async move { + let mut response = fut.await.map_err(|e| S3Error::Surf(e.to_string()))?; + + if cfg!(feature = "fail-on-err") && !response.status().is_success() { + return Err(S3Error::HttpFail); + } + + let mut builder = http::Response::builder() + .status(http::StatusCode::from_u16(response.status().into())?); + for (name, values) in response.iter() { + for value in values { + builder = builder.header(name.as_str(), value.as_str()); + } + } + Ok(builder.body(SurfBody(response.take_body()))?) + }) + } + Err(e) => Box::pin(std::future::ready(Err(e))), } - Ok((header_map, status_code.into())) - } - - async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> { - let mut response = crate::retry! {self.response().await}?; - let status_code = response.status(); - - let body = response - .take_body() - .bytes() - .filter_map(|n| n.ok()) - .fold(vec![], |mut b, n| { - b.push(n); - b - }) - .then(|b| async move { Ok(Bytes::from(b)) }) - .into_stream(); - - Ok(ResponseDataStream { - bytes: Box::pin(body), - status_code: status_code.into(), - }) } } -impl<'a> SurfRequest<'a> { - pub async fn new<'b>( - bucket: &'b Bucket, - path: &'b str, - command: Command<'b>, - ) -> Result<SurfRequest<'b>, S3Error> { - bucket.credentials_refresh().await?; - Ok(SurfRequest { - bucket, - path, - command, - datetime: now_utc(), - sync: false, - }) - } -} +#[derive(Clone, Debug, Default)] +pub struct SurfBackend {}
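Because a backend is only a tower_service::Service, cross-cutting behavior can be layered as a wrapper around any inner backend. A sketch of a logging middleware (illustrative, assuming the log crate; not part of this file):

#[derive(Clone, Debug)]
pub struct LoggedBackend<B>(pub B);

impl<'a, B> Service<http::Request<BackendRequestBody<'a>>> for LoggedBackend<B>
where
    B: Service<http::Request<BackendRequestBody<'a>>>,
{
    type Response = B::Response;
    type Error = B::Error;
    type Future = B::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready(cx)
    }

    fn call(&mut self, request: http::Request<BackendRequestBody<'a>>) -> Self::Future {
        // Log the outgoing request, then delegate untouched.
        log::debug!("{} {}", request.method(), request.uri());
        self.0.call(request)
    }
}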
#[cfg(test)] mod tests { - use crate::bucket::Bucket; - use crate::command::Command; - use crate::request::Request; - use crate::request::async_std_backend::SurfRequest; - use anyhow::Result; - use awscreds::Credentials; - - // Fake keys - otherwise using Credentials::default will use actual user - // credentials if they exist. - fn fake_credentials() -> Credentials { - let access_key = "AKIAIOSFODNN7EXAMPLE"; - let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; - Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap() - } - - #[async_std::test] - async fn url_uses_https_by_default() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; - let path = "/my-first/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.url()?.scheme(), "https"); - - let headers = request.headers().await.unwrap(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "my-first-bucket.custom-region".to_string()); - Ok(()) - } + use super::*; #[async_std::test] - async fn url_uses_https_by_default_path_style() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?.with_path_style(); - let path = "/my-first/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await + async fn test_build() { + let http_request = http::Request::builder() + .uri("https://example.com/foo?bar=1") + .method(http::Method::POST) + .header("h1", "v1") + .header("h2", "v2") + .body(b"sneaky".into()) .unwrap(); - assert_eq!(request.url().unwrap().scheme(), "https"); - - let headers = request.headers().await.unwrap(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "custom-region".to_string()); - Ok(()) - } - - #[async_std::test] - async fn url_uses_scheme_from_custom_region_if_defined() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; - let path = "/my-second/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.url().unwrap().scheme(), "http"); - - let headers = request.headers().await.unwrap(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "my-second-bucket.custom-region".to_string()); - Ok(()) - } - - #[async_std::test] - async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?.with_path_style(); - let path = "/my-second/path"; - let request = SurfRequest::new(&bucket, path, Command::GetObject) - .await - .unwrap(); - - assert_eq!(request.url().unwrap().scheme(), "http"); - - let headers = request.headers().await.unwrap(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "custom-region".to_string()); + let mut r = http_request_to_surf_request(http_request).unwrap().build(); - Ok(()) + assert_eq!(r.method(), Method::Post); + assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1"); + assert_eq!(r.header("h1").unwrap(), "v1"); + assert_eq!(r.header("h2").unwrap(), "v2"); + let body = r.take_body().into_bytes().await.unwrap(); + assert_eq!(body.as_slice(), b"sneaky"); } }
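For callers, nothing changes under the default wiring: with the with-async-std feature, SurfBackend is what the new backend.rs below re-exports as DefaultBackend. A sketch of the explicit equivalent (the helper name is illustrative):

fn surf_bucket(
    name: &str,
    region: awsregion::Region,
    credentials: awscreds::Credentials,
) -> Result<crate::Bucket<SurfBackend>, crate::error::S3Error> {
    // Equivalent to plain Bucket::new under with-async-std; spelled out
    // to show where a custom backend would plug in instead.
    Ok(crate::Bucket::new(name, region, credentials)?.with_backend(SurfBackend::default()))
}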
diff --git a/s3/src/request/backend.rs b/s3/src/request/backend.rs new file mode 100644 index 0000000000..feefdfd8f0 --- /dev/null +++ b/s3/src/request/backend.rs @@ -0,0 +1,64 @@ +use std::borrow::Cow; +use std::time::Duration; + +use crate::error::S3Error; + +#[cfg(feature = "with-async-std")] +pub(crate) use crate::request::async_std_backend::SurfBackend as DefaultBackend; +#[cfg(feature = "sync")] +pub(crate) use crate::request::blocking::AttoBackend as DefaultBackend; +#[cfg(feature = "with-tokio")] +pub(crate) use crate::request::tokio_backend::ReqwestBackend as DefaultBackend; + +/// Default request timeout. Override with the backend's with_request_timeout, +/// e.g. AttoBackend::with_request_timeout or ReqwestBackend::with_request_timeout. +/// +/// For backward compatibility, only AttoBackend uses this. ReqwestBackend +/// supports a timeout but none is set by default. +pub const DEFAULT_REQUEST_TIMEOUT: Option<Duration> = Some(Duration::from_secs(60)); + +pub type BackendRequestBody<'a> = Cow<'a, [u8]>; + +/// A simplified version of tower_service::Service without async +#[cfg(feature = "sync")] +pub trait SyncService<R> { + type Response; + type Error; + + fn call(&mut self, _: R) -> Result<Self::Response, Self::Error>; +}
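A hand-rolled blocking backend is then just a struct with a call method. Sketch of a canned test double against SyncService (the type and its fixed 200 response are illustrative, not part of the patch):

#[derive(Clone, Debug)]
struct CannedBackend;

#[cfg(feature = "sync")]
impl<'a> SyncService<http::Request<BackendRequestBody<'a>>> for CannedBackend {
    type Response = http::Response<std::io::Cursor<Vec<u8>>>;
    type Error = S3Error;

    fn call(
        &mut self,
        _: http::Request<BackendRequestBody<'a>>,
    ) -> Result<Self::Response, Self::Error> {
        // The sync response-body bound only asks for std::io::Read,
        // which Cursor<Vec<u8>> provides.
        Ok(http::Response::builder()
            .status(200)
            .body(std::io::Cursor::new(b"canned".to_vec()))?)
    }
}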
+#[cfg(not(feature = "sync"))] +pub trait Backend<Response>: + for<'a> tower_service::Service< + http::Request<BackendRequestBody<'a>>, + Response = Response, + Error: Into<S3Error>, + Future: Send, + > + Clone + + Send + + Sync +{ +} + +#[cfg(not(feature = "sync"))] +impl<Response, T> Backend<Response> for T where + for<'a> T: tower_service::Service< + http::Request<BackendRequestBody<'a>>, + Response = Response, + Error: Into<S3Error>, + Future: Send, + > + Clone + + Send + + Sync +{ +} + +#[cfg(feature = "sync")] +pub trait Backend<Response>: + for<'a> SyncService<http::Request<BackendRequestBody<'a>>, Response = Response, Error: Into<S3Error>> + Clone +{ +} + +#[cfg(feature = "sync")] +impl<Response, T> Backend<Response> for T where + for<'a> T: SyncService<http::Request<BackendRequestBody<'a>>, Response = Response, Error: Into<S3Error>> + Clone +{ +} diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index ac065fd898..b344acab52 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -1,74 +1,52 @@ extern crate base64; extern crate md5; -use std::io; -use std::io::Write; - use attohttpc::header::HeaderName; +use std::time::Duration; -use crate::bucket::Bucket; -use crate::command::Command; use crate::error::S3Error; -use crate::utils::now_utc; -use bytes::Bytes; -use std::collections::HashMap; -use time::OffsetDateTime; - -use crate::command::HttpMethod; -use crate::request::{Request, ResponseData}; - -// Temporary structure for making a request -pub struct AttoRequest<'a> { - pub bucket: &'a Bucket, - pub path: &'a str, - pub command: Command<'a>, - pub datetime: OffsetDateTime, - pub sync: bool, -} +use std::borrow::Cow; -impl<'a> Request for AttoRequest<'a> { - type Response = attohttpc::Response; - type HeaderMap = attohttpc::header::HeaderMap; +use crate::request::backend::{BackendRequestBody, SyncService}; - fn datetime(&self) -> OffsetDateTime { - self.datetime - } - - fn bucket(&self) -> Bucket { - self.bucket.clone() - } +fn http_request_to_atto_request( + request: http::Request<BackendRequestBody<'_>>, + request_timeout: Option<Duration>, +) -> Result<attohttpc::RequestBuilder<attohttpc::body::Bytes<Cow<'_, [u8]>>>, S3Error> { + let mut session = attohttpc::Session::new(); - fn command(&self) -> Command<'_> { - self.command.clone() + for (name, value) in request.headers().iter() { + session.header(HeaderName::from_bytes(name.as_ref())?, value.to_str()?); } - fn path(&self) -> String { - self.path.to_string() + if let Some(timeout) = request_timeout { + session.timeout(timeout) } - fn response(&self) -> Result<Self::Response, S3Error> { - // Build headers - let headers = self.headers()?; - - let mut session = attohttpc::Session::new(); - - for (name, value) in headers.iter() { - session.header(HeaderName::from_bytes(name.as_ref())?, value.to_str()?); + let url = format!("{}", request.uri()); + let builder = match *request.method() { + http::Method::GET => session.get(url), + http::Method::DELETE => session.delete(url), + http::Method::PUT => session.put(url), + http::Method::POST => session.post(url), + http::Method::HEAD => session.head(url), + _ => { + return Err(S3Error::HttpFailWithBody(405, "".into())); } + }; - if let Some(timeout) = self.bucket.request_timeout { 
- session.timeout(timeout) - } + Ok(builder.bytes(request.body().clone())) +} - let request = match self.command.http_verb() { - HttpMethod::Get => session.get(self.url()?), - HttpMethod::Delete => session.delete(self.url()?), - HttpMethod::Put => session.put(self.url()?), - HttpMethod::Post => session.post(self.url()?), - HttpMethod::Head => session.head(self.url()?), - }; +impl SyncService>> for AttoBackend { + type Response = http::Response; + type Error = S3Error; - let response = request.bytes(&self.request_body()?).send()?; + fn call( + &mut self, + request: http::Request>, + ) -> Result, S3Error> { + let response = http_request_to_atto_request(request, self.request_timeout)?.send()?; if cfg!(feature = "fail-on-err") && !response.status().is_success() { let status = response.status().as_u16(); @@ -76,168 +54,60 @@ impl<'a> Request for AttoRequest<'a> { return Err(S3Error::HttpFailWithBody(status, text)); } - Ok(response) - } + let (status, headers, body) = response.split(); + let mut builder = + http::Response::builder().status(http::StatusCode::from_u16(status.into())?); + *builder.headers_mut().unwrap() = headers; - fn response_data(&self, etag: bool) -> Result { - let response = crate::retry! {self.response()}?; - let status_code = response.status().as_u16(); - - let response_headers = response - .headers() - .iter() - .map(|(k, v)| { - ( - k.to_string(), - v.to_str() - .unwrap_or("could-not-decode-header-value") - .to_string(), - ) - }) - .collect::>(); - - // When etag=true, we extract the ETag header and return it as the body. - // This is used for PUT operations (regular puts, multipart chunks) where: - // 1. S3 returns an empty or non-useful response body - // 2. The ETag header contains the essential information we need - // 3. The calling code expects to get the ETag via response_data.as_str() - // - // Note: This approach means we discard any actual response body when etag=true, - // but for the operations that use this (PUTs), the body is typically empty - // or contains redundant information already available in headers. - // - // TODO: Refactor this to properly return the response body and access ETag - // from headers instead of replacing the body. This would be a breaking change. - let body_vec = if etag { - if let Some(etag) = response.headers().get("ETag") { - Bytes::from(etag.to_str()?.to_string()) - } else { - Bytes::from("") - } - } else { - // HEAD requests don't have a response body - if self.command.http_verb() == HttpMethod::Head { - Bytes::from("") - } else { - Bytes::from(response.bytes()?) - } - }; - Ok(ResponseData::new(body_vec, status_code, response_headers)) + Ok(builder.body(body)?) } +} - fn response_data_to_writer(&self, writer: &mut T) -> Result { - let mut response = crate::retry! {self.response()}?; - - let status_code = response.status(); - io::copy(&mut response, writer)?; +#[derive(Clone, Debug)] +pub struct AttoBackend { + request_timeout: Option, +} - Ok(status_code.as_u16()) +impl Default for AttoBackend { + fn default() -> Self { + Self { + request_timeout: crate::request::backend::DEFAULT_REQUEST_TIMEOUT, + } } +} - fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> { - let response = crate::retry! 
{self.response()}?; - let status_code = response.status().as_u16(); - let headers = response.headers().clone(); - Ok((headers, status_code)) +impl AttoBackend { + pub fn with_request_timeout(&self, request_timeout: Option) -> Result { + Ok(Self { request_timeout }) } -} -impl<'a> AttoRequest<'a> { - pub fn new<'b>( - bucket: &'b Bucket, - path: &'b str, - command: Command<'b>, - ) -> Result, S3Error> { - bucket.credentials_refresh()?; - Ok(AttoRequest { - bucket, - path, - command, - datetime: now_utc(), - sync: false, - }) + pub fn request_timeout(&self) -> Option { + self.request_timeout } } #[cfg(test)] mod tests { - use crate::bucket::Bucket; - use crate::command::Command; - use crate::request::Request; - use crate::request::blocking::AttoRequest; - use anyhow::Result; - use awscreds::Credentials; - - // Fake keys - otherwise using Credentials::default will use actual user - // credentials if they exist. - fn fake_credentials() -> Credentials { - let access_key = "AKIAIOSFODNN7EXAMPLE"; - let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; - Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap() - } + use super::*; #[test] - fn url_uses_https_by_default() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; - let path = "/my-first/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.url()?.scheme(), "https"); - - let headers = request.headers().unwrap(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "my-first-bucket.custom-region".to_string()); - Ok(()) - } - - #[test] - fn url_uses_https_by_default_path_style() -> Result<()> { - let region = "custom-region".parse()?; - let bucket = Bucket::new("my-first-bucket", region, fake_credentials())?; - bucket.with_path_style(); - let path = "/my-first/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.url()?.scheme(), "https"); - - let headers = request.headers().unwrap(); - let host = headers.get("Host").unwrap(); - - assert_eq!(*host, "custom-region".to_string()); - Ok(()) - } - - #[test] - fn url_uses_scheme_from_custom_region_if_defined() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; - let path = "/my-second/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.url()?.scheme(), "http"); - - let headers = request.headers().unwrap(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "my-second-bucket.custom-region".to_string()); - Ok(()) - } - - #[test] - fn url_uses_scheme_from_custom_region_if_defined_with_path_style() -> Result<()> { - let region = "http://custom-region".parse()?; - let bucket = Bucket::new("my-second-bucket", region, fake_credentials())?; - bucket.with_path_style(); - let path = "/my-second/path"; - let request = AttoRequest::new(&bucket, path, Command::GetObject).unwrap(); - - assert_eq!(request.url()?.scheme(), "http"); - - let headers = request.headers().unwrap(); - let host = headers.get("Host").unwrap(); - assert_eq!(*host, "custom-region".to_string()); - - Ok(()) + fn test_build() { + let http_request = http::Request::builder() + .uri("https://example.com/foo?bar=1") + .method(http::Method::POST) + .header("h1", "v1") + .header("h2", "v2") + .body(b"sneaky".into()) + .unwrap(); + + let mut r = 
http_request_to_atto_request(http_request, None).unwrap(); + + assert_eq!(r.inspect().method(), http::Method::POST); + assert_eq!(r.inspect().url().as_str(), "https://example.com/foo?bar=1"); + assert_eq!(r.headers().get("h1").unwrap(), "v1"); + assert_eq!(r.headers().get("h2").unwrap(), "v2"); + let mut i = r.inspect(); + let body = &i.body().0; + assert_eq!(&**body, b"sneaky"); } } diff --git a/s3/src/request/mod.rs b/s3/src/request/mod.rs index 45b6e59c76..c11536af50 100644 --- a/s3/src/request/mod.rs +++ b/s3/src/request/mod.rs @@ -1,5 +1,6 @@ #[cfg(feature = "with-async-std")] pub mod async_std_backend; +pub mod backend; #[cfg(feature = "sync")] pub mod blocking; pub mod request_trait; diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index f5c542b5e6..6592aa0dfc 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -1,19 +1,24 @@ use base64::Engine; use base64::engine::general_purpose; use hmac::Mac; +use http::Method; +#[cfg(any(feature = "with-tokio", feature = "with-async-std"))] +use http_body_util::BodyExt; use quick_xml::se::to_string; +use std::borrow::Cow; use std::collections::HashMap; #[cfg(any(feature = "with-tokio", feature = "with-async-std"))] -use std::pin::Pin; +use std::pin::{Pin, pin}; use time::OffsetDateTime; use time::format_description::well_known::Rfc2822; use url::Url; use crate::LONG_DATETIME; use crate::bucket::Bucket; -use crate::command::Command; +use crate::command::{Command, HttpMethod}; use crate::error::S3Error; use crate::signing; +use crate::utils::now_utc; use bytes::Bytes; use http::HeaderMap; use http::header::{ @@ -22,10 +27,16 @@ use http::header::{ use std::fmt::Write as _; #[cfg(feature = "with-async-std")] -use async_std::stream::Stream; +use async_std::stream::{Stream, StreamExt}; #[cfg(feature = "with-tokio")] -use tokio_stream::Stream; +use tokio_stream::{Stream, StreamExt}; + +#[cfg(feature = "with-async-std")] +use async_std::io::{Write as AsyncWrite, WriteExt as _}; + +#[cfg(feature = "with-tokio")] +use tokio::io::{AsyncWrite, AsyncWriteExt as _}; #[derive(Debug)] @@ -185,119 +196,207 @@ impl async_std::io::Read for ResponseDataStream { } } +mod sealed { + pub struct Sealed; +} + +#[cfg(not(feature = "sync"))] +pub trait ResponseBody: + http_body::Body + AsRef<[u8]> + Send, Error: Into + Send> + + Send + + 'static +{ + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed; + + fn into_bytes(self) -> impl Future> + Send + where + Self: Sized, + { + use futures_util::TryFutureExt as _; + self.collect().map_ok(|c| c.to_bytes()) + } +} + +#[cfg(not(feature = "sync"))] +impl ResponseBody for T +where + T: http_body::Body + Send + 'static, + ::Data: Into + AsRef<[u8]> + Send, + ::Error: Into + Send, +{ + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed { + sealed::Sealed + } +} + +#[cfg(feature = "sync")] +pub trait ResponseBody: std::io::Read { + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed; + + fn into_bytes(mut self) -> Result + where + Self: Sized, + { + let mut buf = Vec::new(); + self.read_to_end(&mut buf)?; + Ok(buf.into()) + } +} + +#[cfg(feature = "sync")] +impl ResponseBody for T { + fn unused_trait_should_have_only_blanket_impl() -> sealed::Sealed { + sealed::Sealed + } +} + #[maybe_async::maybe_async] -pub trait Request { - type Response; - type HeaderMap; - - async fn response(&self) -> Result; - async fn response_data(&self, etag: bool) -> Result; - #[cfg(feature = "with-tokio")] - async fn response_data_to_writer( - 
&self, - writer: &mut T, - ) -> Result; - #[cfg(feature = "with-async-std")] - async fn response_data_to_writer( - &self, - writer: &mut T, - ) -> Result; - #[cfg(feature = "sync")] - fn response_data_to_writer( - &self, - writer: &mut T, - ) -> Result; - #[cfg(any(feature = "with-async-std", feature = "with-tokio"))] - async fn response_data_to_stream(&self) -> Result; - async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>; - fn datetime(&self) -> OffsetDateTime; - fn bucket(&self) -> Bucket; - fn command(&self) -> Command<'_>; - fn path(&self) -> String; +pub(crate) async fn response_data( + response: http::Response, + etag: bool, +) -> Result { + let (mut head, body) = response.into_parts(); + let status_code = head.status.as_u16(); + // When etag=true, we extract the ETag header and return it as the body. + // This is used for PUT operations (regular puts, multipart chunks) where: + // 1. S3 returns an empty or non-useful response body + // 2. The ETag header contains the essential information we need + // 3. The calling code expects to get the ETag via response_data.as_str() + // + // Note: This approach means we discard any actual response body when etag=true, + // but for the operations that use this (PUTs), the body is typically empty + // or contains redundant information already available in headers. + // + // TODO: Refactor this to properly return the response body and access ETag + // from headers instead of replacing the body. This would be a breaking change. + let body_vec = if etag { + if let Some(etag) = head.headers.remove("ETag") { + Bytes::from(etag.to_str()?.to_string()) + } else { + Bytes::from("") + } + } else { + body.into_bytes().await.map_err(Into::::into)? + }; + let response_headers = head + .headers + .into_iter() + .filter_map(|(k, v)| { + k.map(|k| { + ( + k.to_string(), + v.to_str() + .unwrap_or("could-not-decode-header-value") + .to_string(), + ) + }) + }) + .collect::>(); + Ok(ResponseData::new(body_vec, status_code, response_headers)) +} + +#[cfg(any(feature = "with-tokio", feature = "with-async-std"))] +pub(crate) async fn response_data_to_writer( + response: http::Response, + writer: &mut T, +) -> Result +where + R: ResponseBody, + T: AsyncWrite + Send + Unpin + ?Sized, +{ + let status_code = response.status(); + let mut stream = pin!(response.into_body().into_data_stream()); + + while let Some(item) = stream.next().await { + writer.write_all(item.map_err(Into::into)?.as_ref()).await?; + } + + Ok(status_code.as_u16()) +} + +#[cfg(feature = "sync")] +pub(crate) fn response_data_to_writer( + response: http::Response, + writer: &mut T, +) -> Result +where + R: ResponseBody, + T: std::io::Write + Send + ?Sized, +{ + let status_code = response.status(); + let mut body = response.into_body(); + std::io::copy(&mut body, writer)?; + Ok(status_code.as_u16()) +} +struct BuildHelper<'temp, 'body, B> { + bucket: &'temp Bucket, + path: &'temp str, + command: Command<'body>, + datetime: OffsetDateTime, +} + +#[maybe_async::maybe_async] +impl<'temp, 'body, B> BuildHelper<'temp, 'body, B> { async fn signing_key(&self) -> Result, S3Error> { signing::signing_key( - &self.datetime(), + &self.datetime, &self - .bucket() + .bucket .secret_key() .await? .expect("Secret key must be provided to sign headers, found None"), - &self.bucket().region(), + &self.bucket.region(), "s3", ) } - fn request_body(&self) -> Result, S3Error> { - let result = if let Command::PutObject { content, .. 
} = self.command() { - Vec::from(content) - } else if let Command::PutObjectTagging { tags } = self.command() { - Vec::from(tags) - } else if let Command::UploadPart { content, .. } = self.command() { - Vec::from(content) - } else if let Command::CompleteMultipartUpload { data, .. } = &self.command() { - let body = data.to_string(); - body.as_bytes().to_vec() - } else if let Command::CreateBucket { config } = &self.command() { - if let Some(payload) = config.location_constraint_payload() { - Vec::from(payload) - } else { - Vec::new() - } - } else if let Command::PutBucketLifecycle { configuration, .. } = &self.command() { - quick_xml::se::to_string(configuration)?.as_bytes().to_vec() - } else if let Command::PutBucketCors { configuration, .. } = &self.command() { - let cors = configuration.to_string(); - cors.as_bytes().to_vec() - } else { - Vec::new() - }; - Ok(result) - } - fn long_date(&self) -> Result { - Ok(self.datetime().format(LONG_DATETIME)?) + Ok(self.datetime.format(LONG_DATETIME)?) } fn string_to_sign(&self, request: &str) -> Result { - signing::string_to_sign(&self.datetime(), &self.bucket().region(), request) + signing::string_to_sign(&self.datetime, &self.bucket.region(), request) } fn host_header(&self) -> String { - self.bucket().host() + self.bucket.host() } #[maybe_async::async_impl] async fn presigned(&self) -> Result { - let (expiry, custom_headers, custom_queries) = match self.command() { + let (expiry, custom_headers, custom_queries) = match &self.command { Command::PresignGet { expiry_secs, custom_queries, - } => (expiry_secs, None, custom_queries), + } => (expiry_secs, None, custom_queries.as_ref()), Command::PresignPut { expiry_secs, custom_headers, custom_queries, - } => (expiry_secs, custom_headers, custom_queries), + } => ( + expiry_secs, + custom_headers.as_ref(), + custom_queries.as_ref(), + ), Command::PresignDelete { expiry_secs } => (expiry_secs, None, None), _ => unreachable!(), }; let url = self - .presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref()) + .presigned_url_no_sig(*expiry, custom_headers, custom_queries) .await?; // Build the URL string preserving the original host (including standard ports) // The Url type drops standard ports when converting to string, but we need them // for signature validation - let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket().region() - { + let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket.region() { // Check if we need to preserve a standard port if (endpoint.contains(":80") && url.scheme() == "http" && url.port().is_none()) || (endpoint.contains(":443") && url.scheme() == "https" && url.port().is_none()) { // Rebuild the URL with the original host from the endpoint - let host = self.bucket().host(); + let host = self.bucket.host(); format!( "{}://{}{}{}", url.scheme(), @@ -315,41 +414,38 @@ pub trait Request { Ok(format!( "{}&X-Amz-Signature={}", url_str, - self.presigned_authorization(custom_headers.as_ref()) - .await? + self.presigned_authorization(custom_headers).await? )) } #[maybe_async::sync_impl] async fn presigned(&self) -> Result { - let (expiry, custom_headers, custom_queries) = match self.command() { + let (expiry, custom_headers, custom_queries) = match &self.command { Command::PresignGet { expiry_secs, custom_queries, - } => (expiry_secs, None, custom_queries), + } => (expiry_secs, None, custom_queries.as_ref()), Command::PresignPut { expiry_secs, custom_headers, .. 
- } => (expiry_secs, custom_headers, None), + } => (expiry_secs, custom_headers.as_ref(), None), Command::PresignDelete { expiry_secs } => (expiry_secs, None, None), _ => unreachable!(), }; - let url = - self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?; + let url = self.presigned_url_no_sig(*expiry, custom_headers, custom_queries)?; // Build the URL string preserving the original host (including standard ports) // The Url type drops standard ports when converting to string, but we need them // for signature validation - let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket().region() - { + let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket.region() { // Check if we need to preserve a standard port if (endpoint.contains(":80") && url.scheme() == "http" && url.port().is_none()) || (endpoint.contains(":443") && url.scheme() == "https" && url.port().is_none()) { // Rebuild the URL with the original host from the endpoint - let host = self.bucket().host(); + let host = self.bucket.host(); format!( "{}://{}{}{}", url.scheme(), @@ -367,7 +463,7 @@ pub trait Request { Ok(format!( "{}&X-Amz-Signature={}", url_str, - self.presigned_authorization(custom_headers.as_ref())? + self.presigned_authorization(custom_headers)? )) } @@ -393,24 +489,28 @@ pub trait Request { } async fn presigned_canonical_request(&self, headers: &HeaderMap) -> Result { - let (expiry, custom_headers, custom_queries) = match self.command() { + let (expiry, custom_headers, custom_queries) = match &self.command { Command::PresignGet { expiry_secs, custom_queries, - } => (expiry_secs, None, custom_queries), + } => (expiry_secs, None, custom_queries.as_ref()), Command::PresignPut { expiry_secs, custom_headers, custom_queries, - } => (expiry_secs, custom_headers, custom_queries), + } => ( + expiry_secs, + custom_headers.as_ref(), + custom_queries.as_ref(), + ), Command::PresignDelete { expiry_secs } => (expiry_secs, None, None), _ => unreachable!(), }; signing::canonical_request( - &self.command().http_verb().to_string(), + &self.command.http_verb().to_string(), &self - .presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref()) + .presigned_url_no_sig(*expiry, custom_headers, custom_queries) .await?, headers, "UNSIGNED-PAYLOAD", @@ -424,7 +524,7 @@ pub trait Request { custom_headers: Option<&HeaderMap>, custom_queries: Option<&HashMap>, ) -> Result { - let bucket = self.bucket(); + let bucket = self.bucket; let token = if let Some(security_token) = bucket.security_token().await? { Some(security_token) } else { @@ -434,9 +534,9 @@ pub trait Request { "{}{}{}", self.url()?, &signing::authorization_query_params_no_sig( - &self.bucket().access_key().await?.unwrap_or_default(), - &self.datetime(), - &self.bucket().region(), + &self.bucket.access_key().await?.unwrap_or_default(), + &self.datetime, + &self.bucket.region(), expiry, custom_headers, token.as_ref() @@ -454,7 +554,7 @@ pub trait Request { custom_headers: Option<&HeaderMap>, custom_queries: Option<&HashMap>, ) -> Result { - let bucket = self.bucket(); + let bucket = self.bucket; let token = if let Some(security_token) = bucket.security_token()? 
{ Some(security_token) } else { @@ -464,9 +564,9 @@ pub trait Request { "{}{}{}", self.url()?, &signing::authorization_query_params_no_sig( - &self.bucket().access_key()?.unwrap_or_default(), - &self.datetime(), - &self.bucket().region(), + &self.bucket.access_key()?.unwrap_or_default(), + &self.datetime, + &self.bucket.region(), expiry, custom_headers, token.as_ref() @@ -478,20 +578,20 @@ pub trait Request { } fn url(&self) -> Result { - let mut url_str = self.bucket().url(); + let mut url_str = self.bucket.url(); - if let Command::ListBuckets { .. } = self.command() { + if let Command::ListBuckets { .. } = self.command { return Ok(Url::parse(&url_str)?); } - if let Command::CreateBucket { .. } = self.command() { + if let Command::CreateBucket { .. } = self.command { return Ok(Url::parse(&url_str)?); } - let path = if self.path().starts_with('/') { - self.path()[1..].to_string() + let path = if self.path.starts_with('/') { + self.path[1..].to_string() } else { - self.path()[..].to_string() + self.path[..].to_string() }; url_str.push('/'); @@ -499,7 +599,7 @@ pub trait Request { // Append to url_path #[allow(clippy::collapsible_match)] - match self.command() { + match &self.command { Command::InitiateMultipartUpload { .. } | Command::ListMultipartUploads { .. } => { url_str.push_str("?uploads") } @@ -554,7 +654,7 @@ pub trait Request { let mut url = Url::parse(&url_str)?; - for (key, value) in &self.bucket().extra_query { + for (key, value) in &self.bucket.extra_query { url.query_pairs_mut().append_pair(key, value); } @@ -564,7 +664,7 @@ pub trait Request { continuation_token, start_after, max_keys, - } = self.command().clone() + } = self.command.clone() { let mut query_pairs = url.query_pairs_mut(); delimiter.map(|d| query_pairs.append_pair("delimiter", &d)); @@ -587,7 +687,7 @@ pub trait Request { delimiter, marker, max_keys, - } = self.command().clone() + } = self.command.clone() { let mut query_pairs = url.query_pairs_mut(); delimiter.map(|d| query_pairs.append_pair("delimiter", &d)); @@ -601,7 +701,7 @@ pub trait Request { } } - match self.command() { + match &self.command { Command::ListMultipartUploads { prefix, delimiter, @@ -614,7 +714,7 @@ pub trait Request { query_pairs.append_pair("prefix", prefix); } if let Some(key_marker) = key_marker { - query_pairs.append_pair("key-marker", &key_marker); + query_pairs.append_pair("key-marker", key_marker); } if let Some(max_uploads) = max_uploads { query_pairs.append_pair("max-uploads", max_uploads.to_string().as_str()); @@ -633,10 +733,10 @@ pub trait Request { fn canonical_request(&self, headers: &HeaderMap) -> Result { signing::canonical_request( - &self.command().http_verb().to_string(), + &self.command.http_verb().to_string(), &self.url()?, headers, - &self.command().sha256()?, + &self.command.sha256()?, ) } @@ -650,31 +750,29 @@ pub trait Request { let signed_header = signing::signed_header_string(headers); signing::authorization_header( &self - .bucket() + .bucket .access_key() .await? .expect("No access_key provided"), - &self.datetime(), - &self.bucket().region(), + &self.datetime, + &self.bucket.region(), &signed_header, &signature, ) } #[maybe_async::maybe_async] - async fn headers(&self) -> Result { + async fn add_headers(&self, headers: &mut HeaderMap) -> Result<(), S3Error> { // Generate this once, but it's used in more than one place. - let sha256 = self.command().sha256()?; + let sha256 = self.command.sha256()?; // Start with extra_headers, that way our headers replace anything with // the same name. 
- let mut headers = HeaderMap::new(); - - for (k, v) in self.bucket().extra_headers.iter() { + for (k, v) in self.bucket.extra_headers.iter() { if k.as_str().starts_with("x-amz-meta-") { // metadata is invalid on any multipart command other than initiate - match self.command() { + match self.command { Command::UploadPart { .. } | Command::AbortMultipartUpload { .. } | Command::CompleteMultipartUpload { .. } @@ -688,7 +786,7 @@ pub trait Request { } // Append custom headers for PUT request if any - if let Command::PutObject { custom_headers, .. } = self.command() + if let Command::PutObject { custom_headers, .. } = &self.command && let Some(custom_headers) = custom_headers { for (k, v) in custom_headers.iter() { @@ -700,7 +798,7 @@ pub trait Request { headers.insert(HOST, host_header.parse()?); - match self.command() { + match self.command { Command::CopyObject { from } => { headers.insert(HeaderName::from_static("x-amz-copy-source"), from.parse()?); } @@ -713,9 +811,9 @@ pub trait Request { _ => { headers.insert( CONTENT_LENGTH, - self.command().content_length()?.to_string().parse()?, + self.command.content_length()?.to_string().parse()?, ); - headers.insert(CONTENT_TYPE, self.command().content_type().parse()?); + headers.insert(CONTENT_TYPE, self.command.content_type().parse()?); } } headers.insert( @@ -727,34 +825,34 @@ pub trait Request { self.long_date()?.parse()?, ); - if let Some(session_token) = self.bucket().session_token().await? { + if let Some(session_token) = self.bucket.session_token().await? { headers.insert( HeaderName::from_static("x-amz-security-token"), session_token.parse()?, ); - } else if let Some(security_token) = self.bucket().security_token().await? { + } else if let Some(security_token) = self.bucket.security_token().await? { headers.insert( HeaderName::from_static("x-amz-security-token"), security_token.parse()?, ); } - if let Command::PutObjectTagging { tags } = self.command() { + if let Command::PutObjectTagging { tags } = self.command { let digest = md5::compute(tags); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); - } else if let Command::PutObject { content, .. } = self.command() { + } else if let Command::PutObject { content, .. } = self.command { let digest = md5::compute(content); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); - } else if let Command::UploadPart { content, .. } = self.command() { + } else if let Command::UploadPart { content, .. 
} = self.command { let digest = md5::compute(content); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); - } else if let Command::GetObject {} = self.command() { + } else if let Command::GetObject {} = self.command { headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?); // headers.insert(header::ACCEPT_CHARSET, HeaderValue::from_str("UTF-8")?); - } else if let Command::GetObjectRange { start, end } = self.command() { + } else if let Command::GetObjectRange { start, end } = self.command { headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?); let mut range = format!("bytes={}-", start); @@ -764,9 +862,9 @@ pub trait Request { } headers.insert(RANGE, range.parse()?); - } else if let Command::CreateBucket { ref config } = self.command() { - config.add_headers(&mut headers)?; - } else if let Command::PutBucketLifecycle { ref configuration } = self.command() { + } else if let Command::CreateBucket { ref config } = self.command { + config.add_headers(headers)?; + } else if let Command::PutBucketLifecycle { ref configuration } = self.command { let digest = md5::compute(to_string(configuration)?.as_bytes()); let hash = general_purpose::STANDARD.encode(digest.as_ref()); headers.insert(HeaderName::from_static("content-md5"), hash.parse()?); @@ -775,7 +873,7 @@ pub trait Request { expected_bucket_owner, configuration, .. - } = self.command() + } = &self.command { let digest = md5::compute(configuration.to_string().as_bytes()); let hash = general_purpose::STANDARD.encode(digest.as_ref()); @@ -787,7 +885,7 @@ pub trait Request { ); } else if let Command::GetBucketCors { expected_bucket_owner, - } = self.command() + } = &self.command { headers.insert( HeaderName::from_static("x-amz-expected-bucket-owner"), @@ -795,7 +893,7 @@ pub trait Request { ); } else if let Command::DeleteBucketCors { expected_bucket_owner, - } = self.command() + } = &self.command { headers.insert( HeaderName::from_static("x-amz-expected-bucket-owner"), @@ -804,7 +902,7 @@ pub trait Request { } else if let Command::GetObjectAttributes { expected_bucket_owner, .. - } = self.command() + } = &self.command { headers.insert( HeaderName::from_static("x-amz-expected-bucket-owner"), @@ -817,8 +915,8 @@ pub trait Request { } // This must be last, as it signs the other headers, omitted if no secret key is provided - if self.bucket().secret_key().await?.is_some() { - let authorization = self.authorization(&headers).await?; + if self.bucket.secret_key().await?.is_some() { + let authorization = self.authorization(headers).await?; headers.insert(AUTHORIZATION, authorization.parse()?); } @@ -828,12 +926,96 @@ pub trait Request { // range and can't be used again e.g. reply attacks. Adding this header // after the generation of the Authorization header leaves it out of // the signed headers. - headers.insert(DATE, self.datetime().format(&Rfc2822)?.parse()?); + headers.insert(DATE, self.datetime.format(&Rfc2822)?.parse()?); - Ok(headers) + Ok(()) } } +fn make_body(command: Command<'_>) -> Result, S3Error> { + let result = if let Command::PutObject { content, .. } = command { + content.into() + } else if let Command::PutObjectTagging { tags } = command { + tags.as_bytes().into() + } else if let Command::UploadPart { content, .. } = command { + content.into() + } else if let Command::CompleteMultipartUpload { data, .. 
} = &command { + let body = data.to_string(); + body.as_bytes().to_vec().into() + } else if let Command::CreateBucket { config } = &command { + if let Some(payload) = config.location_constraint_payload() { + payload.as_bytes().to_vec().into() + } else { + b"".into() + } + } else if let Command::PutBucketLifecycle { configuration, .. } = &command { + quick_xml::se::to_string(configuration)? + .as_bytes() + .to_vec() + .into() + } else if let Command::PutBucketCors { configuration, .. } = &command { + let cors = configuration.to_string(); + cors.as_bytes().to_vec().into() + } else { + b"".into() + }; + Ok(result) +} + +#[maybe_async::maybe_async] +pub(crate) async fn build_request<'body, B>( + bucket: &Bucket, + path: &str, + command: Command<'body>, +) -> Result>, S3Error> { + let method = match command.http_verb() { + HttpMethod::Delete => Method::DELETE, + HttpMethod::Get => Method::GET, + HttpMethod::Post => Method::POST, + HttpMethod::Put => Method::PUT, + HttpMethod::Head => Method::HEAD, + }; + + let mut request_builder = http::Request::builder().method(method); + let headers_builder = BuildHelper { + bucket, + path, + command, + datetime: now_utc(), + }; + headers_builder + .add_headers(request_builder.headers_mut().unwrap()) + .await?; + let url = headers_builder.url()?; + let uri_builder = http::Uri::builder() + .scheme(url.scheme()) + .authority(url.authority()); + let uri_builder = match url.query() { + None => uri_builder.path_and_query(url.path()), + Some(query) => uri_builder.path_and_query(format!("{}?{}", url.path(), query)), + }; + let request = request_builder + .uri(uri_builder.build()?) + .body(make_body(headers_builder.command)?)?; + Ok(request) +} + +#[maybe_async::maybe_async] +pub(crate) async fn build_presigned( + bucket: &Bucket, + path: &str, + command: Command<'_>, +) -> Result { + BuildHelper { + bucket, + path, + command, + datetime: now_utc(), + } + .presigned() + .await +} + #[cfg(all(test, feature = "with-tokio"))] mod tests { use super::*; @@ -1067,3 +1249,152 @@ mod async_std_tests { assert_eq!(output, b"First chunk\nSecond chunk\nThird chunk\n"); } } + +#[cfg(test)] +mod request_tests { + use super::build_request; + use crate::bucket::Bucket; + use crate::command::Command; + use awscreds::Credentials; + use http::header::{HOST, RANGE}; + + /// A trivial spinning async executor we can use to be independent + /// of tokio or async-std + #[cfg(not(feature = "sync"))] + mod test_executor { + #[derive(Default)] + pub(super) struct TestWaker(pub std::sync::atomic::AtomicBool); + + impl std::task::Wake for TestWaker { + fn wake(self: std::sync::Arc) { + self.0.store(true, std::sync::atomic::Ordering::Release); + } + } + } + + #[cfg(not(feature = "sync"))] + fn testawait(fut: F) -> F::Output { + let w = std::sync::Arc::new(test_executor::TestWaker::default()); + let waker: std::task::Waker = w.clone().into(); + let mut cx = std::task::Context::from_waker(&waker); + let mut fut = std::pin::pin!(fut); + loop { + if let std::task::Poll::Ready(o) = fut.as_mut().poll(&mut cx) { + return o; + } + if !w.0.swap(false, std::sync::atomic::Ordering::AcqRel) { + panic!("Future made no progress"); + } + } + } + + #[cfg(feature = "sync")] + fn testawait(val: T) -> T { + val + } + + // Fake keys - otherwise using Credentials::default will use actual user + // credentials if they exist. 
+ fn fake_credentials() -> Credentials { + let access_key = "AKIAIOSFODNN7EXAMPLE"; + let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; + Credentials::new(Some(access_key), Some(secret_key), None, None, None).unwrap() + } + + #[test] + fn url_uses_https_by_default() { + let region = "custom-region".parse().unwrap(); + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap(); + let path = "/my-first/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "https"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + + assert_eq!(*host, "my-first-bucket.custom-region".to_string()); + } + + #[test] + fn url_uses_https_by_default_path_style() { + let region = "custom-region".parse().unwrap(); + let bucket = Bucket::new("my-first-bucket", region, fake_credentials()) + .unwrap() + .with_path_style(); + let path = "/my-first/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "https"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + + assert_eq!(*host, "custom-region".to_string()); + } + + #[test] + fn url_uses_scheme_from_custom_region_if_defined() { + let region = "http://custom-region".parse().unwrap(); + let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap(); + let path = "/my-second/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "http"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + assert_eq!(*host, "my-second-bucket.custom-region".to_string()); + } + + #[test] + fn url_uses_scheme_from_custom_region_if_defined_with_path_style() { + let region = "http://custom-region".parse().unwrap(); + let bucket = Bucket::new("my-second-bucket", region, fake_credentials()) + .unwrap() + .with_path_style(); + let path = "/my-second/path"; + let request = testawait(build_request(&bucket, path, Command::GetObject)).unwrap(); + + assert_eq!(request.uri().scheme_str().unwrap(), "http"); + + let headers = request.headers(); + let host = headers.get(HOST).unwrap(); + assert_eq!(*host, "custom-region".to_string()); + } + + #[test] + fn test_get_object_range_header() { + let region = "http://custom-region".parse().unwrap(); + let bucket = Bucket::new("my-second-bucket", region, fake_credentials()) + .unwrap() + .with_path_style(); + let path = "/my-second/path"; + + let request = testawait(build_request( + &bucket, + path, + Command::GetObjectRange { + start: 0, + end: None, + }, + )) + .unwrap(); + let headers = request.headers(); + let range = headers.get(RANGE).unwrap(); + assert_eq!(range, "bytes=0-"); + + let request = testawait(build_request( + &bucket, + path, + Command::GetObjectRange { + start: 0, + end: Some(1), + }, + )) + .unwrap(); + let headers = request.headers(); + let range = headers.get(RANGE).unwrap(); + assert_eq!(range, "bytes=0-1"); + } +} diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index 498f19135b..9ab0b09d0c 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -1,22 +1,13 @@ extern crate base64; extern crate md5; -use bytes::Bytes; -use futures_util::TryStreamExt; -use maybe_async::maybe_async; -use std::collections::HashMap; -use std::str::FromStr as _; -use time::OffsetDateTime; +use 
std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; +use tower_service::Service; -use super::request_trait::{Request, ResponseData, ResponseDataStream}; -use crate::bucket::Bucket; -use crate::command::Command; -use crate::command::HttpMethod; +use super::backend::BackendRequestBody; use crate::error::S3Error; -use crate::retry; -use crate::utils::now_utc; - -use tokio_stream::StreamExt; #[derive(Clone, Debug, Default)] pub(crate) struct ClientOptions { @@ -28,8 +19,7 @@ pub(crate) struct ClientOptions { pub accept_invalid_hostnames: bool, } -#[cfg(feature = "with-tokio")] -pub(crate) fn client(options: &ClientOptions) -> Result { +fn client(options: &ClientOptions) -> Result { let client = reqwest::Client::builder(); let client = if let Some(timeout) = options.request_timeout { @@ -58,295 +48,177 @@ pub(crate) fn client(options: &ClientOptions) -> Result { - pub bucket: &'a Bucket, - pub path: &'a str, - pub command: Command<'a>, - pub datetime: OffsetDateTime, - pub sync: bool, -} - -#[maybe_async] -impl<'a> Request for ReqwestRequest<'a> { - type Response = reqwest::Response; - type HeaderMap = reqwest::header::HeaderMap; - - async fn response(&self) -> Result { - let headers = self - .headers() - .await? - .iter() - .map(|(k, v)| { - ( - reqwest::header::HeaderName::from_str(k.as_str()), - reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()), - ) - }) - .filter(|(k, v)| k.is_ok() && v.is_ok()) - .map(|(k, v)| (k.unwrap(), v.unwrap())) - .collect(); - - let client = self.bucket.http_client(); - - let method = match self.command.http_verb() { - HttpMethod::Delete => reqwest::Method::DELETE, - HttpMethod::Get => reqwest::Method::GET, - HttpMethod::Post => reqwest::Method::POST, - HttpMethod::Put => reqwest::Method::PUT, - HttpMethod::Head => reqwest::Method::HEAD, - }; - - let request = client - .request(method, self.url()?.as_str()) - .headers(headers) - .body(self.request_body()?); - - let request = request.build()?; - // println!("Request: {:?}", request); - - let response = client.execute(request).await?; - - if cfg!(feature = "fail-on-err") && !response.status().is_success() { - let status = response.status().as_u16(); - let text = response.text().await?; - return Err(S3Error::HttpFailWithBody(status, text)); - } - - Ok(response) - } - - async fn response_data(&self, etag: bool) -> Result { - let response = retry! {self.response().await }?; - let status_code = response.status().as_u16(); - let mut headers = response.headers().clone(); - let response_headers = headers - .clone() - .iter() - .map(|(k, v)| { - ( - k.to_string(), - v.to_str() - .unwrap_or("could-not-decode-header-value") - .to_string(), - ) - }) - .collect::>(); - // When etag=true, we extract the ETag header and return it as the body. - // This is used for PUT operations (regular puts, multipart chunks) where: - // 1. S3 returns an empty or non-useful response body - // 2. The ETag header contains the essential information we need - // 3. The calling code expects to get the ETag via response_data.as_str() - // - // Note: This approach means we discard any actual response body when etag=true, - // but for the operations that use this (PUTs), the body is typically empty - // or contains redundant information already available in headers. - // - // TODO: Refactor this to properly return the response body and access ETag - // from headers instead of replacing the body. This would be a breaking change. 
- let body_vec = if etag { - if let Some(etag) = headers.remove("ETag") { - Bytes::from(etag.to_str()?.to_string()) - } else { - Bytes::from("") +impl<T> Service<http::Request<BackendRequestBody<'_>>> for ReqwestBackend<T> +where + T: Service<reqwest::Request, Response = reqwest::Response, Error = reqwest::Error> + + Send + + 'static, + T::Future: Send, +{ + type Response = http::Response<reqwest::Body>; + type Error = S3Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.http_client.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, request: http::Request<BackendRequestBody<'_>>) -> Self::Future { + match request.map(|b| b.into_owned()).try_into() { + Ok::<reqwest::Request, reqwest::Error>(request) => { + let fut = self.http_client.call(request); + Box::pin(async move { + let response = fut.await?; + if cfg!(feature = "fail-on-err") && !response.status().is_success() { + let status = response.status().as_u16(); + let text = response.text().await?; + return Err(S3Error::HttpFailWithBody(status, text)); + } + Ok(response.into()) + }) } - } else { - response.bytes().await? - }; - Ok(ResponseData::new(body_vec, status_code, response_headers)) - } - - async fn response_data_to_writer<T: AsyncWrite + Send + Unpin + ?Sized>( - &self, - writer: &mut T, - ) -> Result<u16, S3Error> { - use tokio::io::AsyncWriteExt; - let response = retry! {self.response().await}?; - - let status_code = response.status(); - let mut stream = response.bytes_stream(); - - while let Some(item) = stream.next().await { - writer.write_all(&item?).await?; + Err(e) => Box::pin(std::future::ready(Err(e.into()))), } - - Ok(status_code.as_u16()) } +} - async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> { - let response = retry! {self.response().await}?; - let status_code = response.status(); - let stream = response.bytes_stream().map_err(S3Error::Reqwest); +#[derive(Clone, Debug, Default)] +pub struct ReqwestBackend<T = reqwest::Client> { + http_client: T, + client_options: ClientOptions, +} - Ok(ResponseDataStream { - bytes: Box::pin(stream), - status_code: status_code.as_u16(), +impl ReqwestBackend<reqwest::Client> { + pub fn with_request_timeout(&self, request_timeout: Option<Duration>) -> Result<Self, S3Error> { + let client_options = ClientOptions { + request_timeout, + ..self.client_options.clone() + }; + Ok(Self { + http_client: client(&client_options)?, + client_options, }) } - async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> { - let response = retry! {self.response().await}?; - let status_code = response.status().as_u16(); - let headers = response.headers().clone(); - Ok((headers, status_code)) - } - - fn datetime(&self) -> OffsetDateTime { - self.datetime - } - - fn bucket(&self) -> Bucket { - self.bucket.clone() - } - - fn command(&self) -> Command<'_> { - self.command.clone() - } - - fn path(&self) -> String { - self.path.to_string() + pub fn request_timeout(&self) -> Option<Duration> { + self.client_options.request_timeout + }
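Each with_* helper on ReqwestBackend clones the stored ClientOptions, rebuilds the reqwest client, and returns a fresh backend, so settings compose by chaining. A sketch (the proxy URL and helper name are placeholders):

fn tuned_backend() -> Result<ReqwestBackend, S3Error> {
    ReqwestBackend::default()
        // 30 s per-request timeout instead of the default of none.
        .with_request_timeout(Some(Duration::from_secs(30)))?
        // Route through a local proxy; with_proxy is defined just below.
        .with_proxy(reqwest::Proxy::all("http://localhost:8080")?)
}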
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use s3::bucket::Bucket;
+    /// # use s3::error::S3Error;
+    /// # use s3::creds::Credentials;
+    /// # use s3::Region;
+    /// # use std::str::FromStr;
+    ///
+    /// # fn example() -> Result<(), S3Error> {
+    /// let bucket = Bucket::new("my-bucket", Region::from_str("us-east-1")?, Credentials::default()?)?;
+    /// let bucket = bucket.with_backend(bucket.backend().with_dangereous_config(true, true)?);
+    /// # Ok(())
+    /// # }
+    /// ```
+    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
+    pub fn with_dangereous_config(
+        &self,
+        accept_invalid_certs: bool,
+        accept_invalid_hostnames: bool,
+    ) -> Result<Self, S3Error> {
+        let client_options = ClientOptions {
+            accept_invalid_certs,
+            accept_invalid_hostnames,
+            ..self.client_options.clone()
+        };
+        Ok(Self {
+            http_client: client(&client_options)?,
+            client_options,
+        })
     }
 
-impl<'a> ReqwestRequest<'a> {
-    pub async fn new(
-        bucket: &'a Bucket,
-        path: &'a str,
-        command: Command<'a>,
-    ) -> Result<ReqwestRequest<'a>, S3Error> {
-        bucket.credentials_refresh().await?;
+    pub fn with_proxy(&self, proxy: reqwest::Proxy) -> Result<Self, S3Error> {
+        let client_options = ClientOptions {
+            proxy: Some(proxy),
+            ..self.client_options.clone()
+        };
         Ok(Self {
-            bucket,
-            path,
-            command,
-            datetime: now_utc(),
-            sync: false,
+            http_client: client(&client_options)?,
+            client_options,
         })
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::bucket::Bucket;
-    use crate::command::Command;
-    use crate::request::Request;
-    use crate::request::tokio_backend::ReqwestRequest;
-    use awscreds::Credentials;
-    use http::header::{HOST, RANGE};
+    use super::*;
+    use http_body_util::BodyExt;
 
-    // Fake keys - otherwise using Credentials::default will use actual user
-    // credentials if they exist.
-    fn fake_credentials() -> Credentials {
-        let access_key = "AKIAIOSFODNN7EXAMPLE";
-        let secert_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
-        Credentials::new(Some(access_key), Some(secert_key), None, None, None).unwrap()
-    }
+    #[derive(Clone, Default)]
+    struct MockReqwestClient;
 
-    #[tokio::test]
-    async fn url_uses_https_by_default() {
-        let region = "custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();
-        let path = "/my-first/path";
-        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
-            .await
-            .unwrap();
-
-        assert_eq!(request.url().unwrap().scheme(), "https");
+    impl Service<reqwest::Request> for MockReqwestClient {
+        type Response = reqwest::Response;
+        type Error = reqwest::Error;
+        type Future =
+            Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
 
-        let headers = request.headers().await.unwrap();
-        let host = headers.get(HOST).unwrap();
+        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
 
-        assert_eq!(*host, "my-first-bucket.custom-region".to_string());
+        fn call(&mut self, mut r: reqwest::Request) -> Self::Future {
+            assert_eq!(r.method(), http::Method::POST);
+            assert_eq!(r.url().as_str(), "https://example.com/foo?bar=1");
+            assert_eq!(r.headers().get("h1").unwrap(), "v1");
+            assert_eq!(r.headers().get("h2").unwrap(), "v2");
+            Box::pin(async move {
+                let body = r.body_mut().take().unwrap().collect().await;
+                assert_eq!(body.unwrap().to_bytes().as_ref(), b"sneaky");
+                Ok(http::Response::builder()
+                    .body(reqwest::Body::from(""))
+                    .unwrap()
+                    .into())
+            })
+        }
     }
 
     #[tokio::test]
-    async fn url_uses_https_by_default_path_style() {
-        let region = "custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-first-bucket", region, fake_credentials())
-            .unwrap()
-            .with_path_style();
-        let path = "/my-first/path";
-        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
-            .await
+    async fn test_build() {
+        let http_request = http::Request::builder()
+            .uri("https://example.com/foo?bar=1")
+            .method(http::Method::POST)
+            .header("h1", "v1")
+            .header("h2", "v2")
+            .body(b"sneaky".into())
             .unwrap();
 
-        assert_eq!(request.url().unwrap().scheme(), "https");
-
-        let headers = request.headers().await.unwrap();
-        let host = headers.get(HOST).unwrap();
-
-        assert_eq!(*host, "custom-region".to_string());
-    }
-
-    #[tokio::test]
-    async fn url_uses_scheme_from_custom_region_if_defined() {
-        let region = "http://custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap();
-        let path = "/my-second/path";
-        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
+        let mut backend = ReqwestBackend {
+            http_client: MockReqwestClient,
+            ..Default::default()
+        };
+        crate::utils::service_ready::Ready::new(&mut backend)
             .await
-            .unwrap();
-
-        assert_eq!(request.url().unwrap().scheme(), "http");
-
-        let headers = request.headers().await.unwrap();
-        let host = headers.get(HOST).unwrap();
-        assert_eq!(*host, "my-second-bucket.custom-region".to_string());
-    }
-
-    #[tokio::test]
-    async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() {
-        let region = "http://custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
             .unwrap()
-            .with_path_style();
-        let path = "/my-second/path";
-        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
+            .call(http_request)
             .await
             .unwrap();
-
-        assert_eq!(request.url().unwrap().scheme(), "http");
-
-        let headers = request.headers().await.unwrap();
-        let host = headers.get(HOST).unwrap();
-        assert_eq!(*host, "custom-region".to_string());
-    }
-
-    #[tokio::test]
-    async fn test_get_object_range_header() {
-        let region = "http://custom-region".parse().unwrap();
-        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
-            .unwrap()
-            .with_path_style();
-        let path = "/my-second/path";
-
-        let request = ReqwestRequest::new(
-            &bucket,
-            path,
-            Command::GetObjectRange {
-                start: 0,
-                end: None,
-            },
-        )
-        .await
-        .unwrap();
-        let headers = request.headers().await.unwrap();
-        let range = headers.get(RANGE).unwrap();
-        assert_eq!(range, "bytes=0-");
-
-        let request = ReqwestRequest::new(
-            &bucket,
-            path,
-            Command::GetObjectRange {
-                start: 0,
-                end: Some(1),
-            },
-        )
-        .await
-        .unwrap();
-        let headers = request.headers().await.unwrap();
-        let range = headers.get(RANGE).unwrap();
-        assert_eq!(range, "bytes=0-1");
     }
 }
diff --git a/s3/src/utils/mod.rs b/s3/src/utils/mod.rs
index 04918ccf19..8b07a35b69 100644
--- a/s3/src/utils/mod.rs
+++ b/s3/src/utils/mod.rs
@@ -1,3 +1,6 @@
+#[cfg(test)]
+pub mod testing;
+
 mod time_utils;
 pub use time_utils::*;
 
@@ -427,6 +430,40 @@ macro_rules! retry {
     }};
 }
 
+/// Like tower::util::ServiceExt::ready but avoiding a dependency on tower.
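+///
+/// A minimal usage sketch (`svc` is any `Service<R>`; the names are illustrative,
+/// not part of this API):
+///
+/// ```ignore
+/// // Wait until the service reports readiness, then get the `&mut` back out
+/// // and issue the call.
+/// let svc = service_ready::Ready::new(&mut svc).await?;
+/// let response = svc.call(request).await?;
+/// ```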
+#[cfg(not(feature = "sync"))]
+pub(crate) mod service_ready {
+    use std::future::Future;
+    use std::marker::PhantomData;
+    use std::pin::Pin;
+    use std::task::{Context, Poll};
+    use tower_service::Service;
+
+    pub struct Ready<'a, T, R>(pub Option<&'a mut T>, pub PhantomData<fn() -> R>);
+
+    impl<'a, T, R> Ready<'a, T, R> {
+        pub fn new(inner: &'a mut T) -> Self {
+            Self(Some(inner), PhantomData)
+        }
+    }
+
+    impl<'a, T: Service<R>, R> Future for Ready<'a, T, R> {
+        type Output = Result<&'a mut T, T::Error>;
+
+        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+            match self
+                .0
+                .as_mut()
+                .expect("poll after Poll::Ready")
+                .poll_ready(cx)
+            {
+                Poll::Pending => Poll::Pending,
+                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
+                Poll::Ready(Ok(())) => Poll::Ready(Ok(self.0.take().unwrap())),
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::utils::etag_for_path;
diff --git a/s3/src/utils/testing.rs b/s3/src/utils/testing.rs
new file mode 100644
index 0000000000..791c070ba8
--- /dev/null
+++ b/s3/src/utils/testing.rs
@@ -0,0 +1,36 @@
+#[derive(Clone)]
+pub struct AlwaysFailBackend;
+
+#[cfg(not(feature = "sync"))]
+impl<R> tower_service::Service<R> for AlwaysFailBackend {
+    type Response = http::Response<http_body_util::Empty<bytes::Bytes>>;
+    type Error = http::Error;
+    type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(
+        &mut self,
+        _: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>> {
+        std::task::Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, _: R) -> Self::Future {
+        std::future::ready(
+            http::Response::builder()
+                .status(http::StatusCode::IM_A_TEAPOT)
+                .body(http_body_util::Empty::new()),
+        )
+    }
+}
+
+#[cfg(feature = "sync")]
+impl<R> crate::request::backend::SyncService<R> for AlwaysFailBackend {
+    type Response = http::Response<&'static [u8]>;
+    type Error = http::Error;
+
+    fn call(&mut self, _: R) -> Result<Self::Response, Self::Error> {
+        http::Response::builder()
+            .status(http::StatusCode::IM_A_TEAPOT)
+            .body(b"")
+    }
+}
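+
+// A minimal sketch of the intended test usage (assuming a `bucket: Bucket` and
+// the `with_backend` constructor from bucket.rs; the object key is illustrative):
+//
+//     let failing = bucket.with_backend(AlwaysFailBackend);
+//     let err = failing.get_object("/some/key").await.unwrap_err();
+//     // With the `fail-on-err` feature enabled, the 418 status is expected to
+//     // surface as an error rather than a successful response.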