Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 160 additions & 15 deletions src/auth/src/credentials/idtoken/mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ use crate::credentials::mds::{
METADATA_ROOT,
};
use crate::errors::CredentialsError;
use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use crate::token_cache::TokenCache;
use crate::{
Expand All @@ -79,6 +80,9 @@ use crate::{
credentials::idtoken::{IDTokenCredentials, parse_id_token_from_str},
};
use async_trait::async_trait;
use gax::backoff_policy::BackoffPolicyArg;
use gax::retry_policy::RetryPolicyArg;
use gax::retry_throttler::RetryThrottlerArg;
use http::{Extensions, HeaderValue};
use reqwest::Client;
use std::sync::Arc;
Expand Down Expand Up @@ -135,6 +139,7 @@ pub struct Builder {
pub(crate) format: Option<Format>,
licenses: Option<String>,
target_audience: String,
retry_builder: RetryTokenProviderBuilder,
}

impl Builder {
Expand All @@ -149,6 +154,7 @@ impl Builder {
endpoint: None,
licenses: None,
target_audience: target_audience.into(),
retry_builder: RetryTokenProviderBuilder::default(),
}
}

Expand Down Expand Up @@ -222,26 +228,97 @@ impl Builder {
self
}

fn build_token_provider(self) -> MDSTokenProvider {
let final_endpoint: String;
/// Configure the retry policy for fetching tokens.
///
/// The retry policy controls how to handle retries, and sets limits on
/// the number of attempts or the total time spent retrying.
///
/// # Example
///
/// ```no_run
/// # use google_cloud_auth::credentials::idtoken;
/// use gax::retry_policy::{AlwaysRetry, RetryPolicyExt};
///
/// let audience = "https://my-service.a.run.app";
/// let credentials = idtoken::mds::Builder::new(audience)
/// .with_retry_policy(AlwaysRetry.with_attempt_limit(3))
/// .build();
/// ```
pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_retry_policy(v.into());
self
}

// Determine the endpoint and whether it was overridden
if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
// Check GCE_METADATA_HOST environment variable first
final_endpoint = format!("http://{host_from_env}");
} else if let Some(builder_endpoint) = self.endpoint {
// Else, check if an endpoint was provided to the mds::Builder
final_endpoint = builder_endpoint;
} else {
// Else, use the default metadata root
final_endpoint = METADATA_ROOT.to_string();
};
/// Configure the retry backoff policy.
///
/// The backoff policy controls how long to wait in between retry attempts.
///
/// # Example
///
/// ```no_run
/// # use google_cloud_auth::credentials::idtoken;
/// use gax::exponential_backoff::ExponentialBackoff;
///
/// let audience = "https://my-service.a.run.app";
/// let credentials = idtoken::mds::Builder::new(audience)
/// .with_backoff_policy(ExponentialBackoff::default())
/// .build();
/// ```
pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
self
}

/// Configure the retry throttler.
///
/// Advanced applications may want to configure a retry throttler to
/// [Address Cascading Failures] and when [Handling Overload] conditions.
/// The authentication library throttles its retry loop, using a policy to
/// control the throttling algorithm. Use this method to fine tune or
/// customize the default retry throttler.
///
/// [Handling Overload]: https://sre.google/sre-book/handling-overload/
/// [Address Cascading Failures]: https://sre.google/sre-book/addressing-cascading-failures/
///
/// # Example
///
/// ```no_run
/// # use google_cloud_auth::credentials::idtoken;
/// use gax::retry_throttler::AdaptiveThrottler;
///
/// let audience = "https://my-service.a.run.app";
/// let credentials = idtoken::mds::Builder::new(audience)
/// .with_retry_throttler(AdaptiveThrottler::default())
/// .build();
/// ```
pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
self
}

fn build_token_provider(self) -> TokenProviderWithRetry<MDSTokenProvider> {
let final_endpoint = self.resolve_endpoint();

MDSTokenProvider {
let tp = MDSTokenProvider {
format: self.format,
licenses: self.licenses,
endpoint: final_endpoint,
target_audience: self.target_audience,
};
self.retry_builder.build(tp)
}

fn resolve_endpoint(&self) -> String {
// Determine the endpoint
if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
// Check GCE_METADATA_HOST environment variable first
format!("http://{host_from_env}")
} else if let Some(builder_endpoint) = self.endpoint.clone() {
// Else, check if an endpoint was provided to the mds::Builder
builder_endpoint
} else {
// Else, use the default metadata root
METADATA_ROOT.to_string()
}
}

Expand Down Expand Up @@ -308,7 +385,11 @@ impl TokenProvider for MDSTokenProvider {
mod tests {
use super::*;
use crate::credentials::idtoken::tests::generate_test_id_token;
use crate::credentials::tests::find_source_error;
use crate::credentials::tests::{
find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy,
get_mock_retry_throttler,
};
use httptest::cycle;
use httptest::matchers::{all_of, contains, request, url_decoded};
use httptest::responders::status_code;
use httptest::{Expectation, Server};
Expand All @@ -319,6 +400,70 @@ mod tests {

type TestResult = anyhow::Result<()>;

#[tokio::test]
#[parallel]
async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
let server = Server::run();
let audience = "test-audience";
server.expect(
Expectation::matching(all_of![
request::path(format!("{MDS_DEFAULT_URI}/identity")),
request::query(url_decoded(contains(("audience", audience)))),
])
.times(1)
.respond_with(status_code(401)),
);

let creds = Builder::new(audience)
.with_endpoint(format!("http://{}", server.addr()))
.with_retry_policy(get_mock_auth_retry_policy(3))
.with_backoff_policy(get_mock_backoff_policy())
.with_retry_throttler(get_mock_retry_throttler())
.build()?;

let err = creds.id_token().await.unwrap_err();
let source = find_source_error::<reqwest::Error>(&err);
assert!(
matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
"{err:?}"
);

Ok(())
}

#[tokio::test]
#[parallel]
async fn test_mds_retries_for_success() -> TestResult {
let server = Server::run();
let audience = "test-audience";
let token_string = generate_test_id_token(audience);

server.expect(
Expectation::matching(all_of![
request::path(format!("{MDS_DEFAULT_URI}/identity")),
request::query(url_decoded(contains(("audience", audience)))),
])
.times(3)
.respond_with(cycle![
status_code(503).body("try-again"),
status_code(503).body("try-again"),
status_code(200).body(token_string.clone()),
]),
);

let creds = Builder::new(audience)
.with_endpoint(format!("http://{}", server.addr()))
.with_retry_policy(get_mock_auth_retry_policy(3))
.with_backoff_policy(get_mock_backoff_policy())
.with_retry_throttler(get_mock_retry_throttler())
.build()?;

let id_token = creds.id_token().await?;
assert_eq!(id_token, token_string);

Ok(())
}

#[tokio::test]
#[test_case(Format::Standard)]
#[test_case(Format::Full)]
Expand Down
Loading