Skip to content

Commit 3482910

Browse files
authored
Add custom authenticator (#24)
* Add custom authenticator * Format * . * Format * Add test * . * Format * . * Clean up * PR comments * Fix trait name * . * .
1 parent 5210ce1 commit 3482910

File tree

4 files changed

+228
-25
lines changed

4 files changed

+228
-25
lines changed

crates/catalog/rest/src/catalog.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use std::any::Any;
2121
use std::collections::HashMap;
2222
use std::future::Future;
2323
use std::str::FromStr;
24+
use std::sync::Arc;
2425

2526
use async_trait::async_trait;
2627
use iceberg::io::{self, FileIO};
@@ -38,7 +39,8 @@ use tokio::sync::OnceCell;
3839
use typed_builder::TypedBuilder;
3940

4041
use crate::client::{
41-
HttpClient, deserialize_catalog_response, deserialize_unexpected_catalog_error,
42+
CustomAuthenticator, HttpClient, deserialize_catalog_response,
43+
deserialize_unexpected_catalog_error,
4244
};
4345
use crate::types::{
4446
CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest,
@@ -67,6 +69,7 @@ impl Default for RestCatalogBuilder {
6769
warehouse: None,
6870
props: HashMap::new(),
6971
client: None,
72+
authenticator: None,
7073
})
7174
}
7275
}
@@ -124,6 +127,24 @@ impl RestCatalogBuilder {
124127
self.0.client = Some(client);
125128
self
126129
}
130+
131+
/// Set a custom token authenticator.
132+
///
133+
/// The authenticator will be used to obtain tokens instead of using static tokens
134+
/// or OAuth credentials.
135+
///
136+
/// # Example
137+
/// ```ignore
138+
/// let authenticator = Arc::new(MyAuthenticator::new());
139+
/// let catalog = RestCatalogBuilder::default()
140+
/// .with_token_authenticator(authenticator)
141+
/// .load("rest", config)
142+
/// .await?;
143+
/// ```
144+
pub fn with_token_authenticator(mut self, authenticator: Arc<dyn CustomAuthenticator>) -> Self {
145+
self.0.authenticator = Some(authenticator);
146+
self
147+
}
127148
}
128149

129150
/// Rest catalog configuration.
@@ -142,6 +163,9 @@ pub(crate) struct RestCatalogConfig {
142163

143164
#[builder(default)]
144165
client: Option<Client>,
166+
167+
#[builder(default)]
168+
authenticator: Option<Arc<dyn CustomAuthenticator>>,
145169
}
146170

147171
impl RestCatalogConfig {
@@ -349,7 +373,13 @@ impl RestCatalog {
349373
async fn context(&self) -> Result<&RestContext> {
350374
self.ctx
351375
.get_or_try_init(|| async {
352-
let client = HttpClient::new(&self.user_config)?;
376+
let mut client = HttpClient::new(&self.user_config)?;
377+
378+
// Set authenticator if one was configured
379+
if let Some(authenticator) = &self.user_config.authenticator {
380+
client = client.with_authenticator(authenticator.clone());
381+
}
382+
353383
let catalog_config = RestCatalog::load_config(&client, &self.user_config).await?;
354384
let config = self.user_config.clone().merge_with_config(catalog_config);
355385
let client = client.update_with(&config)?;

crates/catalog/rest/src/client.rs

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use std::collections::HashMap;
1919
use std::fmt::{Debug, Formatter};
20+
use std::sync::Arc;
2021

2122
use http::StatusCode;
2223
use iceberg::{Error, ErrorKind, Result};
@@ -28,6 +29,17 @@ use tokio::sync::Mutex;
2829
use crate::RestCatalogConfig;
2930
use crate::types::{ErrorResponse, TokenResponse};
3031

32+
/// Trait for custom token authentication.
33+
///
34+
/// Implement this trait to provide custom token generation/refresh logic
35+
/// instead of using OAuth credentials.
36+
#[async_trait::async_trait]
37+
pub trait CustomAuthenticator: Send + Sync + Debug {
38+
/// Get or refresh the authentication token.
39+
/// Called when the client needs a token for authentication.
40+
async fn get_token(&self) -> Result<String>;
41+
}
42+
3143
pub(crate) struct HttpClient {
3244
client: Client,
3345

@@ -39,6 +51,8 @@ pub(crate) struct HttpClient {
3951
token_endpoint: String,
4052
/// The credential to be used for authentication.
4153
credential: Option<(Option<String>, String)>,
54+
/// Custom token authenticator (takes precedence over credential/token)
55+
authenticator: Option<Arc<dyn CustomAuthenticator>>,
4256
/// Extra headers to be added to each request.
4357
extra_headers: HeaderMap,
4458
/// Extra oauth parameters to be added to each authentication request.
@@ -63,6 +77,7 @@ impl HttpClient {
6377
token: Mutex::new(cfg.token()),
6478
token_endpoint: cfg.get_token_endpoint(),
6579
credential: cfg.credential(),
80+
authenticator: None,
6681
extra_headers,
6782
extra_oauth_params: cfg.extra_oauth_params(),
6883
})
@@ -86,6 +101,7 @@ impl HttpClient {
86101
self.token_endpoint
87102
},
88103
credential: cfg.credential().or(self.credential),
104+
authenticator: self.authenticator,
89105
extra_headers,
90106
extra_oauth_params: if !cfg.extra_oauth_params().is_empty() {
91107
cfg.extra_oauth_params()
@@ -174,6 +190,27 @@ impl HttpClient {
174190
Ok(auth_res.access_token)
175191
}
176192

193+
/// Set a custom token authenticator.
194+
///
195+
/// When set, the authenticator will be called to get tokens instead of using
196+
/// static tokens or OAuth credentials. This allows for custom token management
197+
/// such as reading from files, APIs, or other custom sources.
198+
pub fn with_authenticator(mut self, authenticator: Arc<dyn CustomAuthenticator>) -> Self {
199+
self.authenticator = Some(authenticator);
200+
self
201+
}
202+
203+
/// Add bearer token to request authorization header.
204+
fn set_bearer_token(req: &mut Request, token: &str, error_msg: &str) -> Result<()> {
205+
req.headers_mut().insert(
206+
http::header::AUTHORIZATION,
207+
format!("Bearer {token}")
208+
.parse()
209+
.map_err(|e| Error::new(ErrorKind::DataInvalid, error_msg).with_source(e))?,
210+
);
211+
Ok(())
212+
}
213+
177214
/// Invalidate the current token without generating a new one. On the next request, the client
178215
/// will attempt to generate a new token.
179216
pub(crate) async fn invalidate_token(&self) -> Result<()> {
@@ -195,18 +232,24 @@ impl HttpClient {
195232

196233
/// Authenticates the request by adding a bearer token to the authorization header.
197234
///
198-
/// This method supports three authentication modes:
235+
/// This method supports four authentication modes (in order of precedence):
199236
///
200-
/// 1. **No authentication** - Skip authentication when both `credential` and `token` are missing.
201-
/// 2. **Token authentication** - Use the provided `token` directly for authentication.
202-
/// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it for authentication.
237+
/// 1. **Custom authenticator** - If set, use the custom CustomAuthenticator to get tokens.
238+
/// 2. **Token authentication** - Use the provided static `token` directly.
239+
/// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it.
240+
/// 4. **No authentication** - Skip authentication when none of the above are available.
203241
///
204-
/// When both `credential` and `token` are present, `token` takes precedence.
205-
///
206-
/// # TODO: Support automatic token refreshing.
242+
/// When an authenticator is provided, it takes precedence over static tokens and credentials.
207243
async fn authenticate(&self, req: &mut Request) -> Result<()> {
244+
// Try authenticator first (highest priority)
245+
if let Some(authenticator) = &self.authenticator {
246+
let token = authenticator.get_token().await?;
247+
Self::set_bearer_token(req, &token, "Invalid custom token")?;
248+
return Ok(());
249+
}
250+
208251
// Clone the token from lock without holding the lock for entire function.
209-
let token = self.token.lock().await.clone();
252+
let token: Option<String> = self.token.lock().await.clone();
210253

211254
if self.credential.is_none() && token.is_none() {
212255
return Ok(());
@@ -224,18 +267,7 @@ impl HttpClient {
224267
}
225268
};
226269

227-
// Insert token in request.
228-
req.headers_mut().insert(
229-
http::header::AUTHORIZATION,
230-
format!("Bearer {token}").parse().map_err(|e| {
231-
Error::new(
232-
ErrorKind::DataInvalid,
233-
"Invalid token received from catalog server!",
234-
)
235-
.with_source(e)
236-
})?,
237-
);
238-
270+
Self::set_bearer_token(req, &token, "Invalid token received from catalog server!")?;
239271
Ok(())
240272
}
241273

crates/catalog/rest/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@ mod client;
5656
mod types;
5757

5858
pub use catalog::*;
59+
pub use client::CustomAuthenticator;

crates/catalog/rest/tests/rest_catalog_test.rs

Lines changed: 143 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@
1919
2020
use std::collections::HashMap;
2121
use std::net::SocketAddr;
22-
use std::sync::RwLock;
22+
use std::sync::{Arc, Mutex, RwLock};
2323

24+
use async_trait::async_trait;
2425
use ctor::{ctor, dtor};
2526
use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};
2627
use iceberg::transaction::{ApplyTransactionAction, Transaction};
27-
use iceberg::{Catalog, CatalogBuilder, Namespace, NamespaceIdent, TableCreation, TableIdent};
28-
use iceberg_catalog_rest::{REST_CATALOG_PROP_URI, RestCatalog, RestCatalogBuilder};
28+
use iceberg::{
29+
Catalog, CatalogBuilder, Namespace, NamespaceIdent, Result as IcebergResult, TableCreation,
30+
TableIdent,
31+
};
32+
use iceberg_catalog_rest::{
33+
CustomAuthenticator, REST_CATALOG_PROP_URI, RestCatalog, RestCatalogBuilder,
34+
};
2935
use iceberg_test_utils::docker::DockerCompose;
3036
use iceberg_test_utils::{normalize_test_name, set_up};
3137
use port_scanner::scan_port_addr;
@@ -449,3 +455,137 @@ async fn test_register_table() {
449455
table_registered.identifier().to_string()
450456
);
451457
}
458+
459+
#[derive(Debug)]
460+
struct CountingAuthenticator {
461+
count: Arc<Mutex<usize>>,
462+
}
463+
464+
#[async_trait]
465+
impl CustomAuthenticator for CountingAuthenticator {
466+
async fn get_token(&self) -> IcebergResult<String> {
467+
let mut c = self.count.lock().unwrap();
468+
*c += 1;
469+
// Return a unique token each time to ensure dynamic generation
470+
Ok(format!("token_{}", *c))
471+
}
472+
}
473+
474+
async fn get_catalog_with_authenticator(
475+
authenticator: Arc<dyn CustomAuthenticator>,
476+
) -> RestCatalog {
477+
set_up();
478+
479+
let rest_catalog_ip = {
480+
let guard = DOCKER_COMPOSE_ENV.read().unwrap();
481+
let docker_compose = guard.as_ref().unwrap();
482+
docker_compose.get_container_ip("rest")
483+
};
484+
485+
let rest_socket_addr = SocketAddr::new(rest_catalog_ip, REST_CATALOG_PORT);
486+
while !scan_port_addr(rest_socket_addr) {
487+
info!("Waiting for 1s rest catalog to ready...");
488+
sleep(std::time::Duration::from_millis(1000)).await;
489+
}
490+
491+
RestCatalogBuilder::default()
492+
.with_token_authenticator(authenticator)
493+
.load(
494+
"rest",
495+
HashMap::from([(
496+
REST_CATALOG_PROP_URI.to_string(),
497+
format!("http://{rest_socket_addr}"),
498+
)]),
499+
)
500+
.await
501+
.unwrap()
502+
}
503+
504+
#[tokio::test]
505+
async fn test_authenticator_token_refresh() {
506+
// Track how many times tokens were requested
507+
let token_request_count = Arc::new(Mutex::new(0));
508+
let token_request_count_clone = token_request_count.clone();
509+
510+
let authenticator = Arc::new(CountingAuthenticator {
511+
count: token_request_count_clone,
512+
});
513+
514+
let catalog_with_auth = get_catalog_with_authenticator(authenticator).await;
515+
516+
// Perform multiple operations that should trigger token requests
517+
let ns1 = Namespace::with_properties(
518+
NamespaceIdent::from_strs(["test_refresh_1"]).unwrap(),
519+
HashMap::new(),
520+
);
521+
catalog_with_auth
522+
.create_namespace(ns1.name(), HashMap::new())
523+
.await
524+
.unwrap();
525+
526+
let ns2 = Namespace::with_properties(
527+
NamespaceIdent::from_strs(["test_refresh_2"]).unwrap(),
528+
HashMap::new(),
529+
);
530+
catalog_with_auth
531+
.create_namespace(ns2.name(), HashMap::new())
532+
.await
533+
.unwrap();
534+
535+
// Verify authenticator was called multiple times
536+
let count = *token_request_count.lock().unwrap();
537+
assert!(
538+
count >= 2,
539+
"Authenticator should have been called at least twice, but was called {} times",
540+
count
541+
);
542+
}
543+
544+
#[tokio::test]
545+
async fn test_authenticator_persists_across_operations() {
546+
let operation_count = Arc::new(Mutex::new(0));
547+
let operation_count_clone = operation_count.clone();
548+
549+
let authenticator = Arc::new(CountingAuthenticator {
550+
count: operation_count_clone,
551+
});
552+
553+
let catalog_with_auth = get_catalog_with_authenticator(authenticator).await;
554+
555+
// Create a namespace
556+
let ns = Namespace::with_properties(
557+
NamespaceIdent::from_strs(["test_persist", "auth"]).unwrap(),
558+
HashMap::new(),
559+
);
560+
catalog_with_auth
561+
.create_namespace(ns.name(), HashMap::new())
562+
.await
563+
.unwrap();
564+
565+
let count_after_create = *operation_count.lock().unwrap();
566+
567+
// List the namespace children (should use the same authenticator)
568+
// We need to list children of "test_persist" to find "auth"
569+
let list_result = catalog_with_auth
570+
.list_namespaces(Some(&NamespaceIdent::from_strs(["test_persist"]).unwrap()))
571+
.await
572+
.unwrap();
573+
assert!(
574+
list_result.contains(&NamespaceIdent::from_strs(["test_persist", "auth"]).unwrap()),
575+
"Namespace {:?} not found in list {:?}",
576+
ns.name(),
577+
list_result
578+
);
579+
580+
let count_after_list = *operation_count.lock().unwrap();
581+
582+
// Verify authenticator was used for both operations
583+
assert!(
584+
count_after_create > 0,
585+
"Authenticator should be used for create"
586+
);
587+
assert!(
588+
count_after_list > count_after_create,
589+
"Authenticator should be used for list operation too"
590+
);
591+
}

0 commit comments

Comments
 (0)