diff --git a/src/client.rs b/src/client.rs index 4078d13..9d51f82 100644 --- a/src/client.rs +++ b/src/client.rs @@ -56,3 +56,35 @@ pub trait Client: Sync + Send { Ok(response) } } + +/// Client that wraps another client and adds a bearer token authentication header to each request. +pub struct BearerTokenAuthClient { + token: String, + client: C, +} + +impl BearerTokenAuthClient { + /// Construct from an arbitrary client and a bearer token. + pub fn new(client: C, token: &str) -> Self { + Self { client, token: token.to_string() } + } +} + +#[async_trait::async_trait] +impl Client for BearerTokenAuthClient { + fn base(&self) -> &str { + self.client.base() + } + + async fn send(&self, mut req: Request>) -> Result>, ClientError> { + let bearer = format!("Bearer {}", self.token); + match http::HeaderValue::from_str(&bearer) { + Ok(mut bearer) => { + bearer.set_sensitive(true); + req.headers_mut().insert("Authorization", bearer); + self.client.send(req).await + } + Err(e) => Err(ClientError::GenericError { source: e.into() }) + } + } +} diff --git a/src/clients/reqwest.rs b/src/clients/reqwest.rs index ea7de16..101345b 100644 --- a/src/clients/reqwest.rs +++ b/src/clients/reqwest.rs @@ -1,7 +1,7 @@ //! Contains an implementation of [Client][crate::client::Client] being backed //! by the [reqwest](https://docs.rs/reqwest/) crate. -use crate::{client::Client as RustifyClient, errors::ClientError}; +use crate::{client::{Client as RustifyClient, BearerTokenAuthClient}, errors::ClientError}; use async_trait::async_trait; use http::{Request, Response}; use std::convert::TryFrom; @@ -102,3 +102,10 @@ impl RustifyClient for Client { .map_err(|e| ClientError::ResponseError { source: e.into() }) } } + +impl BearerTokenAuthClient { + /// Construct from a default client using a given base URL and a bearer token. + pub fn default(base: &str, token: &str) -> Self { + BearerTokenAuthClient::new(Client::default(base), token) + } +} diff --git a/tests/endpoint.rs b/tests/endpoint.rs index 8e97b49..eec3a28 100644 --- a/tests/endpoint.rs +++ b/tests/endpoint.rs @@ -5,7 +5,7 @@ use std::fmt::Debug; use common::{Middle, TestGenericWrapper, TestResponse, TestServer}; use derive_builder::Builder; use httpmock::prelude::*; -use rustify::endpoint::Endpoint; +use rustify::{client::BearerTokenAuthClient, endpoint::Endpoint}; use rustify_derive::Endpoint; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::json; @@ -349,3 +349,22 @@ async fn test_complex() { assert!(r.is_ok()); assert_eq!(r.unwrap().parse().unwrap().age, 30); } + +#[test(tokio::test)] +async fn test_bearer_token_auth_client() { + #[derive(Endpoint)] + #[endpoint(path = "test/path")] + struct Test {} + + let t = TestServer::default(); + let e = Test {}; + let m = t.server.mock(|when, then| { + when.header("Authorization", "Bearer 1234567890"); + then.status(200); + }); + let client = BearerTokenAuthClient::new(t.client, "1234567890"); + let r = e.exec(&client).await; + + m.assert(); + assert!(r.is_ok()); +}