Skip to content

Commit

Permalink
fix oauth-client crate
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 8, 2024
1 parent 8c3be27 commit f6e168f
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 24 deletions.
7 changes: 5 additions & 2 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use crate::types::{
TryIntoOAuthClientMetadata,
};
use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values};
use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver};
use atrium_common::resolver::Resolver;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_xrpc::HttpClient;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
Expand Down Expand Up @@ -145,7 +146,9 @@ where
} else {
self.client_metadata.redirect_uris[0].clone()
};
let (metadata, identity) = self.resolver.resolve(input.as_ref()).await?;
let result = self.resolver.resolve(input.as_ref()).await?;
let (metadata, identity) =
result.ok_or_else(|| Error::Identity(atrium_identity::Error::NotFound))?;
let Some(dpop_key) = Self::generate_dpop_key(&metadata) else {
return Err(Error::Authorize("none of the algorithms worked".into()));
};
Expand Down
42 changes: 26 additions & 16 deletions atrium-oauth/oauth-client/src/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
use atrium_common::resolver::CachedResolver;
use atrium_common::resolver::Resolver;
use atrium_common::resolver::ThrottledResolver;
use atrium_common::types::cached::r#impl::Cache;
use atrium_common::types::cached::r#impl::CacheImpl;
use atrium_common::types::cached::CacheConfig;
use atrium_common::types::cached::Cacheable;
use atrium_common::types::throttled::Throttleable;
mod oauth_authorization_server_resolver;
mod oauth_protected_resource_resolver;

Expand All @@ -7,10 +15,7 @@ use crate::types::{OAuthAuthorizationServerMetadata, OAuthProtectedResourceMetad
use atrium_identity::identity_resolver::{
IdentityResolver, IdentityResolverConfig, ResolvedIdentity,
};
use atrium_identity::resolver::{
Cacheable, CachedResolver, CachedResolverConfig, Throttleable, ThrottledResolver,
};
use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver};
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_identity::{Error, Result};
use atrium_xrpc::HttpClient;
use std::marker::PhantomData;
Expand All @@ -19,13 +24,13 @@ use std::time::Duration;

#[derive(Clone, Debug)]
pub struct OAuthAuthorizationServerMetadataResolverConfig {
pub cache: CachedResolverConfig,
pub cache: CacheConfig,
}

impl Default for OAuthAuthorizationServerMetadataResolverConfig {
fn default() -> Self {
Self {
cache: CachedResolverConfig {
cache: CacheConfig {
max_capacity: Some(100),
time_to_live: Some(Duration::from_secs(60)),
},
Expand All @@ -35,13 +40,13 @@ impl Default for OAuthAuthorizationServerMetadataResolverConfig {

#[derive(Clone, Debug)]
pub struct OAuthProtectedResourceMetadataResolverConfig {
pub cache: CachedResolverConfig,
pub cache: CacheConfig,
}

impl Default for OAuthProtectedResourceMetadataResolverConfig {
fn default() -> Self {
Self {
cache: CachedResolverConfig {
cache: CacheConfig {
max_capacity: Some(100),
time_to_live: Some(Duration::from_secs(60)),
},
Expand Down Expand Up @@ -81,11 +86,11 @@ where
let protected_resource_resolver =
DefaultOAuthProtectedResourceResolver::new(http_client.clone())
.throttled()
.cached(config.authorization_server_metadata.cache);
.cached(CacheImpl::new(config.authorization_server_metadata.cache));
let authorization_server_resolver =
DefaultOAuthAuthorizationServerResolver::new(http_client.clone())
.throttled()
.cached(config.protected_resource_metadata.cache);
.cached(CacheImpl::new(config.protected_resource_metadata.cache));
Self {
identity_resolver: IdentityResolver::new(IdentityResolverConfig {
did_resolver: config.did_resolver,
Expand All @@ -108,7 +113,9 @@ where
&self,
issuer: impl AsRef<str>,
) -> Result<OAuthAuthorizationServerMetadata> {
self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await
let result =
self.authorization_server_resolver.resolve(&issuer.as_ref().to_string()).await?;
result.ok_or_else(|| Error::NotFound)
}
async fn resolve_from_service(&self, input: &str) -> Result<OAuthAuthorizationServerMetadata> {
// Assume first that input is a PDS URL (as required by ATPROTO)
Expand All @@ -122,15 +129,17 @@ where
&self,
input: &str,
) -> Result<(OAuthAuthorizationServerMetadata, ResolvedIdentity)> {
let identity = self.identity_resolver.resolve(input).await?;
let result = self.identity_resolver.resolve(input).await;
let identity = result.and_then(|result| result.ok_or_else(|| Error::NotFound))?;
let metadata = self.get_resource_server_metadata(&identity.pds).await?;
Ok((metadata, identity))
}
async fn get_resource_server_metadata(
&self,
pds: &str,
) -> Result<OAuthAuthorizationServerMetadata> {
let rs_metadata = self.protected_resource_resolver.resolve(&pds.to_string()).await?;
let result = self.protected_resource_resolver.resolve(&pds.to_string()).await?;
let rs_metadata = result.ok_or_else(|| Error::NotFound)?;
// ATPROTO requires one, and only one, authorization server entry
// > That document MUST contain a single item in the authorization_servers array.
// https://github.com/bluesky-social/proposals/tree/main/0004-oauth#server-metadata
Expand Down Expand Up @@ -182,16 +191,17 @@ where
{
type Input = str;
type Output = (OAuthAuthorizationServerMetadata, Option<ResolvedIdentity>);
type Error = Error;

async fn resolve(&self, input: &Self::Input) -> Result<Self::Output> {
async fn resolve(&self, input: &Self::Input) -> Result<Option<Self::Output>> {
// Allow using an entryway, or PDS url, directly as login input (e.g.
// when the user forgot their handle, or when the handle does not
// resolve to a DID)
Ok(if input.starts_with("https://") {
(self.resolve_from_service(input.as_ref()).await?, None)
Some((self.resolve_from_service(input.as_ref()).await?, None))
} else {
let (metadata, identity) = self.resolve_from_identity(input).await?;
(metadata, Some(identity))
Some((metadata, Some(identity)))
})
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::resolver::Resolver;
use crate::types::OAuthAuthorizationServerMetadata;
use atrium_identity::{Error, Resolver, Result};
use atrium_identity::{Error, Result};
use atrium_xrpc::http::uri::Builder;
use atrium_xrpc::http::{Request, StatusCode, Uri};
use atrium_xrpc::HttpClient;
Expand All @@ -21,8 +22,9 @@ where
{
type Input = String;
type Output = OAuthAuthorizationServerMetadata;
type Error = Error;

async fn resolve(&self, issuer: &Self::Input) -> Result<Self::Output> {
async fn resolve(&self, issuer: &Self::Input) -> Result<Option<Self::Output>> {
let uri = Builder::from(issuer.parse::<Uri>()?)
.path_and_query("/.well-known/oauth-authorization-server")
.build()?;
Expand All @@ -36,7 +38,7 @@ where
let metadata = serde_json::from_slice::<OAuthAuthorizationServerMetadata>(res.body())?;
// https://datatracker.ietf.org/doc/html/rfc8414#section-3.3
if &metadata.issuer == issuer {
Ok(metadata)
Ok(Some(metadata))
} else {
Err(Error::AuthorizationServerMetadata(format!(
"invalid issuer: {}",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::types::OAuthProtectedResourceMetadata;
use atrium_identity::{Error, Resolver, Result};
use atrium_common::resolver::Resolver;
use atrium_identity::{Error, Result};
use atrium_xrpc::http::uri::Builder;
use atrium_xrpc::http::{Request, StatusCode, Uri};
use atrium_xrpc::HttpClient;
Expand All @@ -21,8 +22,9 @@ where
{
type Input = String;
type Output = OAuthProtectedResourceMetadata;
type Error = Error;

async fn resolve(&self, resource: &Self::Input) -> Result<Self::Output> {
async fn resolve(&self, resource: &Self::Input) -> Result<Option<Self::Output>> {
let uri = Builder::from(resource.parse::<Uri>()?)
.path_and_query("/.well-known/oauth-protected-resource")
.build()?;
Expand All @@ -36,7 +38,7 @@ where
let metadata = serde_json::from_slice::<OAuthProtectedResourceMetadata>(res.body())?;
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-resource-metadata-08#section-3.3
if &metadata.resource == resource {
Ok(metadata)
Ok(Some(metadata))
} else {
Err(Error::ProtectedResourceMetadata(format!(
"invalid resource: {}",
Expand Down

0 comments on commit f6e168f

Please sign in to comment.