diff --git a/docs-website/sidebars.js b/docs-website/sidebars.js index be5232de10240a..dd44358ca3eaa6 100644 --- a/docs-website/sidebars.js +++ b/docs-website/sidebars.js @@ -953,6 +953,7 @@ module.exports = { }, "docs/authentication/introducing-metadata-service-authentication", "docs/authentication/personal-access-tokens", + "docs/authentication/external-oauth-providers", ], }, { diff --git a/docs/authentication/README.md b/docs/authentication/README.md index 4c61c0c020f9d5..5acd2999233960 100644 --- a/docs/authentication/README.md +++ b/docs/authentication/README.md @@ -45,6 +45,9 @@ and programmatic calls to DataHub APIs. There are two types of tokens that are i 2. **Personal Access Tokens**: These are tokens generated via the DataHub settings panel useful for interacting with DataHub APIs. They can be used to automate processes like enriching documentation, ownership, tags, and more on DataHub. Learn more about Personal Access Tokens [here](personal-access-tokens.md). +3. **OAuth Provider Tokens**: JWT tokens issued by external OAuth2/OIDC providers (like Okta, Auth0, Azure AD) can be used + for service-to-service authentication. This enables seamless integration with existing OAuth infrastructure and is ideal + for automated services and applications. Learn more about OAuth Provider authentication [here](external-oauth-providers.md). To learn more about DataHub's backend authentication, check out [Introducing Metadata Service Authentication](introducing-metadata-service-authentication.md). diff --git a/docs/authentication/external-oauth-providers.md b/docs/authentication/external-oauth-providers.md new file mode 100644 index 00000000000000..61619b0f01d380 --- /dev/null +++ b/docs/authentication/external-oauth-providers.md @@ -0,0 +1,255 @@ +# External OAuth Authentication + +DataHub supports authenticating API requests using JWT tokens from external identity providers like Okta, Azure AD, Google Identity, and others. This is perfect for service-to-service authentication where your applications need to call DataHub APIs. + +## Overview + +When you configure OAuth authentication, DataHub will: + +1. Accept JWT tokens from your trusted identity provider +2. Validate the token signature and claims +3. Automatically create service accounts for authenticated users +4. Grant API access based on DataHub's permission system + +## Configuration + +Configure OAuth authentication by setting these environment variables in your DataHub deployment: + +Set these environment variables for the `datahub-gms` service: + +```bash +# Enable OAuth authentication +EXTERNAL_OAUTH_ENABLED=true + +# Required: Trusted JWT issuers (comma-separated) +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://auth.example.com,https://okta.company.com + +# Required: Allowed JWT audiences (comma-separated) +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=datahub-api,my-service-id + +# Required: JWKS endpoint for signature verification +EXTERNAL_OAUTH_JWKS_URI=https://auth.example.com/.well-known/jwks.json + +# Optional: JWT claim containing user ID (default: "sub") +EXTERNAL_OAUTH_USER_ID_CLAIM=sub + +# Optional: Signing algorithm (default: "RS256") +EXTERNAL_OAUTH_ALGORITHM=RS256 +``` + +### Docker Compose Example + +```yaml +version: "3.8" +services: + datahub-gms: + image: acryldata/datahub-gms:latest + environment: + # External OAuth Configuration + - EXTERNAL_OAUTH_ENABLED=true + - EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://my-okta-domain.okta.com/oauth2/default + - EXTERNAL_OAUTH_ALLOWED_AUDIENCES=0oa1234567890abcdef + - EXTERNAL_OAUTH_JWKS_URI=https://my-okta-domain.okta.com/oauth2/default/v1/keys + - EXTERNAL_OAUTH_USER_ID_CLAIM=sub + - EXTERNAL_OAUTH_ALGORITHM=RS256 + + # Standard DataHub settings + - DATAHUB_GMS_HOST=0.0.0.0 + - DATAHUB_GMS_PORT=8080 + # ... other configurations +``` + +### Kubernetes Example + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: datahub-gms +spec: + template: + spec: + containers: + - name: datahub-gms + image: acryldata/datahub-gms:latest + env: + - name: EXTERNAL_OAUTH_ENABLED + value: "true" + - name: EXTERNAL_OAUTH_TRUSTED_ISSUERS + value: "https://login.microsoftonline.com/tenant-id/v2.0" + - name: EXTERNAL_OAUTH_ALLOWED_AUDIENCES + value: "api://datahub-prod" + - name: EXTERNAL_OAUTH_JWKS_URI + value: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys" + # ... other environment variables +``` + +### Multiple Providers + +To support multiple OAuth providers, use comma-separated values: + +```bash +# Multiple issuers and audiences +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://okta.company.com,https://auth0.company.com +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=datahub-prod,datahub-staging,service-account-id + +# Single JWKS URI (if providers share keys) or discovery URI +EXTERNAL_OAUTH_JWKS_URI=https://okta.company.com/.well-known/jwks.json + +# Or use discovery URI to auto-derive JWKS +EXTERNAL_OAUTH_DISCOVERY_URI=https://okta.company.com/.well-known/openid-configuration +``` + +### Discovery URI vs JWKS URI + +You can specify either: + +- **JWKS URI**: Direct endpoint to signing keys (recommended for production) +- **Discovery URI**: OIDC discovery document URL (DataHub will auto-derive JWKS URI) + +```bash +# Option 1: Direct JWKS URI (faster, more reliable) +EXTERNAL_OAUTH_JWKS_URI=https://auth.example.com/.well-known/jwks.json + +# Option 2: Discovery URI (convenient, auto-derives JWKS) +EXTERNAL_OAUTH_DISCOVERY_URI=https://auth.example.com/.well-known/openid-configuration +``` + +## Provider Examples + +### Okta + +```bash +EXTERNAL_OAUTH_ENABLED=true +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://your-domain.okta.com/oauth2/default +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=0oa1234567890abcdef +EXTERNAL_OAUTH_JWKS_URI=https://your-domain.okta.com/oauth2/default/v1/keys +``` + +### Auth0 + +```bash +EXTERNAL_OAUTH_ENABLED=true +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://your-domain.auth0.com/ +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=https://your-api-identifier/ +EXTERNAL_OAUTH_JWKS_URI=https://your-domain.auth0.com/.well-known/jwks.json +``` + +### Azure AD / Microsoft Entra + +```bash +EXTERNAL_OAUTH_ENABLED=true +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://login.microsoftonline.com/your-tenant-id/v2.0 +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=api://your-app-id +EXTERNAL_OAUTH_JWKS_URI=https://login.microsoftonline.com/your-tenant-id/discovery/v2.0/keys +``` + +### Google Cloud Identity + +```bash +EXTERNAL_OAUTH_ENABLED=true +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://accounts.google.com +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=your-client-id.apps.googleusercontent.com +EXTERNAL_OAUTH_JWKS_URI=https://www.googleapis.com/oauth2/v3/certs +``` + +### Keycloak + +```bash +EXTERNAL_OAUTH_ENABLED=true +EXTERNAL_OAUTH_TRUSTED_ISSUERS=https://keycloak.company.com/realms/datahub +EXTERNAL_OAUTH_ALLOWED_AUDIENCES=datahub-client +EXTERNAL_OAUTH_JWKS_URI=https://keycloak.company.com/realms/datahub/protocol/openid-connect/certs +``` + +## Using OAuth Tokens + +Once configured, include your JWT token in the Authorization header when making API requests: + +```bash +curl -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + -H "Content-Type: application/json" \ + https://your-datahub.com/api/graphql \ + -d '{"query": "{ corpUsers { total } }"}' +``` + +For Python applications: + +```python +import requests + +headers = { + 'Authorization': f'Bearer {your_jwt_token}', + 'Content-Type': 'application/json' +} + +response = requests.post( + 'https://your-datahub.com/api/graphql', + headers=headers, + json={'query': '{ corpUsers { total } }'} +) +``` + +## Best Practices + +- Use HTTPS for all JWKS URIs and discovery endpoints +- Use specific audience values (not wildcards) for better security +- Use short-lived tokens (< 1 hour recommended) +- Separate environments with different audiences (prod/staging/dev) +- Enable debug logging during setup: `DATAHUB_GMS_LOG_LEVEL=DEBUG` + +## Troubleshooting + +### Common Issues + +**"OAuth authenticator is not configured"** + +- Make sure `EXTERNAL_OAUTH_ENABLED=true` is set +- Verify all required environment variables are configured + +**"No configured OAuth provider matches token issuer"** + +- Check that your JWT issuer exactly matches `EXTERNAL_OAUTH_TRUSTED_ISSUERS` + +**"Invalid or missing audience claim"** + +- Verify your JWT audience is listed in `EXTERNAL_OAUTH_ALLOWED_AUDIENCES` + +**"Failed to load signing keys"** + +- Test your JWKS URI directly: `curl https://your-provider/.well-known/jwks.json` +- Check network connectivity from DataHub to your OAuth provider + +### Debugging + +Enable debug logging to see detailed OAuth messages: + +```bash +# Set environment variable +DATAHUB_GMS_LOG_LEVEL=DEBUG + +# Check logs +docker logs datahub-gms | grep -i oauth +``` + +### Testing Your Setup + +Decode your JWT token to verify the claims: + +```bash +# Replace with your actual token +echo "YOUR_JWT_TOKEN" | cut -d. -f2 | base64 -d | jq +``` + +Make sure the `iss` (issuer) and `aud` (audience) claims match your configuration. + +## Advanced Options + +You can customize which JWT claim contains the user ID: + +```bash +# Use email claim instead of default 'sub' +EXTERNAL_OAUTH_USER_ID_CLAIM=email +``` + +OAuth users are automatically created as service accounts with usernames like `__oauth_{issuer_domain}_{subject}`. diff --git a/metadata-io/src/main/java/com/linkedin/metadata/service/ServiceAccountService.java b/metadata-io/src/main/java/com/linkedin/metadata/service/ServiceAccountService.java new file mode 100644 index 00000000000000..b97eaa41811770 --- /dev/null +++ b/metadata-io/src/main/java/com/linkedin/metadata/service/ServiceAccountService.java @@ -0,0 +1,244 @@ +package com.linkedin.metadata.service; + +import static com.linkedin.metadata.Constants.*; + +import com.linkedin.common.AuditStamp; +import com.linkedin.common.Origin; +import com.linkedin.common.OriginType; +import com.linkedin.common.SubTypes; +import com.linkedin.common.urn.CorpuserUrn; +import com.linkedin.common.urn.UrnUtils; +import com.linkedin.data.template.RecordTemplate; +import com.linkedin.data.template.StringArray; +import com.linkedin.events.metadata.ChangeType; +import com.linkedin.identity.CorpUserInfo; +import com.linkedin.metadata.aspect.batch.AspectsBatch; +import com.linkedin.metadata.entity.EntityService; +import com.linkedin.metadata.entity.ebean.batch.AspectsBatchImpl; +import com.linkedin.metadata.utils.GenericRecordUtils; +import com.linkedin.mxe.MetadataChangeProposal; +import io.datahubproject.metadata.context.OperationContext; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nonnull; +import lombok.extern.slf4j.Slf4j; + +/** + * Service for managing DataHub service accounts, particularly those created from OAuth/OIDC tokens. + * Handles the creation and management of service account users with proper aspects and metadata. + * + *

Service accounts created by this service are marked with: + * + *

+ */ +@Slf4j +public class ServiceAccountService { + + static final String USER_ID_PREFIX = "__oauth_"; + static final String DEFAULT_USER_CLAIM = "sub"; + + /** + * Creates a unique service account user ID from issuer and subject information. The structure + * ensures uniqueness across IdPs in case migrations happen. + * + * @param issuer The issuer URL from the JWT token + * @param subject The subject (user identifier) from the JWT token + * @return Unique service account user ID + */ + public String buildServiceUserUrn(@Nonnull String issuer, @Nonnull String subject) { + String sanitizedIssuer = issuer.replaceAll("https?://", "").replaceAll("[^a-zA-Z0-9]", "_"); + return String.format("%s%s_%s", USER_ID_PREFIX, sanitizedIssuer, subject); + } + + /** + * Creates a service account with the specified name and origin information. + * + * @param userId The unique service account user ID + * @param displayName The display name for the service account + * @param originType The origin type (e.g., EXTERNAL, NATIVE) + * @param externalType Additional origin information (e.g., issuer URL) + * @param entityService The entity service for persistence + * @param operationContext The operation context for the request + * @return true if the service account was created successfully, false if it already exists + */ + public boolean createServiceAccount( + @Nonnull String userId, + @Nonnull String displayName, + @Nonnull OriginType originType, + @Nonnull String externalType, + @Nonnull EntityService entityService, + @Nonnull OperationContext operationContext) { + + try { + CorpuserUrn userUrn = new CorpuserUrn(userId); + + // Check if user already exists + boolean userExists = entityService.exists(operationContext, userUrn, false); + if (userExists) { + log.debug("Service account user already exists: {}", userUrn); + return false; + } + + log.info("Creating new service account user: {}", userUrn); + + // Create the aspects for the new service account + List aspectsToIngest = + createServiceAccountAspects(userUrn, displayName, originType, externalType); + + // Ingest synchronously to ensure user is immediately available + AspectsBatch aspectsBatch = + AspectsBatchImpl.builder() + .mcps( + aspectsToIngest, createSystemAuditStamp(), operationContext.getRetrieverContext()) + .build(operationContext); + + entityService.ingestAspects(operationContext, aspectsBatch, false, true); + + log.info("Successfully created service account user: {}", userId); + return true; + + } catch (Exception e) { + log.error("Failed to create service account user: {}. Error: {}", userId, e.getMessage()); + throw new RuntimeException("Failed to create service account: " + e.getMessage(), e); + } + } + + /** + * Creates a service account from OAuth/OIDC token information. Ensures that a service account + * user exists in DataHub. If the user doesn't exist, creates a new user with CorpUserInfo, + * SubTypes, and Origin aspects. + * + * @param userId The unique service account user ID + * @param issuer The issuer URL from the JWT token + * @param subject The subject (user identifier) from the JWT token + * @param entityService The entity service for persistence + * @param operationContext The operation context for the request + * @return true if the service account was created, false if it already exists + */ + public boolean ensureServiceAccountExists( + @Nonnull String userId, + @Nonnull String issuer, + @Nonnull String subject, + @Nonnull EntityService entityService, + @Nonnull OperationContext operationContext) { + + try { + CorpuserUrn userUrn = new CorpuserUrn(userId); + + // Check if user already exists + boolean userExists = entityService.exists(operationContext, userUrn, false); + if (userExists) { + log.debug("Service account user already exists: {}", userUrn); + return false; + } + + log.info("Creating new service account user: {}", userUrn); + + String displayName = String.format("Service Account: %s @ %s", subject, issuer); + + // Create the aspects for the new service account + List aspectsToIngest = + createServiceAccountAspects(userUrn, displayName, OriginType.EXTERNAL, issuer); + + // Ingest synchronously to ensure user is immediately available + AspectsBatch aspectsBatch = + AspectsBatchImpl.builder() + .mcps( + aspectsToIngest, createSystemAuditStamp(), operationContext.getRetrieverContext()) + .build(operationContext); + + entityService.ingestAspects(operationContext, aspectsBatch, false, true); + + log.info("Successfully created service account user: {} from issuer: {}", userId, issuer); + return true; + + } catch (Exception e) { + // Don't fail authentication if user creation fails - treat as side-effect + log.error( + "Failed to create service account user: {} from issuer: {}. Error: {}", + userId, + issuer, + e.getMessage()); + return false; + } + } + + /** + * Creates the required aspects for a new service account user. + * + * @param userUrn The URN of the user to create + * @param displayName The display name for the service account + * @param originType The origin type (e.g., EXTERNAL, NATIVE) + * @param externalType Additional origin information (e.g., issuer URL) + * @return List of MetadataChangeProposal objects representing the aspects to ingest + */ + public List createServiceAccountAspects( + @Nonnull CorpuserUrn userUrn, + @Nonnull String displayName, + @Nonnull OriginType originType, + @Nonnull String externalType) { + + List aspects = new ArrayList<>(); + + // 1. CorpUserInfo aspect - basic user information + CorpUserInfo corpUserInfo = new CorpUserInfo(); + corpUserInfo.setActive(true); + corpUserInfo.setDisplayName(displayName); + corpUserInfo.setTitle("OAuth Service Account"); + + aspects.add(createMetadataChangeProposal(userUrn, CORP_USER_INFO_ASPECT_NAME, corpUserInfo)); + + // 2. SubTypes aspect - mark as SERVICE + SubTypes subTypes = new SubTypes(); + StringArray typeNames = new StringArray(); + typeNames.add("SERVICE"); + subTypes.setTypeNames(typeNames); + + aspects.add(createMetadataChangeProposal(userUrn, SUB_TYPES_ASPECT_NAME, subTypes)); + + // 3. Origin aspect - mark with origin information + Origin origin = new Origin(); + origin.setType(originType); + origin.setExternalType(externalType); + + aspects.add(createMetadataChangeProposal(userUrn, ORIGIN_ASPECT_NAME, origin)); + + return aspects; + } + + /** + * Helper method to create a MetadataChangeProposal for an aspect. + * + * @param userUrn The URN of the user + * @param aspectName The name of the aspect + * @param aspect The aspect data + * @return MetadataChangeProposal for the aspect + */ + public MetadataChangeProposal createMetadataChangeProposal( + @Nonnull CorpuserUrn userUrn, @Nonnull String aspectName, @Nonnull RecordTemplate aspect) { + + MetadataChangeProposal mcp = new MetadataChangeProposal(); + mcp.setEntityUrn(userUrn); + mcp.setEntityType(userUrn.getEntityType()); + mcp.setAspectName(aspectName); + mcp.setAspect(GenericRecordUtils.serializeAspect(aspect)); + mcp.setChangeType(ChangeType.UPSERT); + + return mcp; + } + + /** + * Creates an AuditStamp for system-level operations. + * + * @return AuditStamp with system context + */ + public AuditStamp createSystemAuditStamp() { + return new AuditStamp() + .setTime(System.currentTimeMillis()) + .setActor(UrnUtils.getUrn(SYSTEM_ACTOR)); + } +} diff --git a/metadata-io/src/test/java/com/linkedin/metadata/service/ServiceAccountServiceTest.java b/metadata-io/src/test/java/com/linkedin/metadata/service/ServiceAccountServiceTest.java new file mode 100644 index 00000000000000..c8612088db2c8e --- /dev/null +++ b/metadata-io/src/test/java/com/linkedin/metadata/service/ServiceAccountServiceTest.java @@ -0,0 +1,316 @@ +package com.linkedin.metadata.service; + +import static com.linkedin.metadata.Constants.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.testng.Assert.*; + +import com.linkedin.common.OriginType; +import com.linkedin.common.urn.CorpuserUrn; +import com.linkedin.identity.CorpUserInfo; +import com.linkedin.metadata.aspect.batch.AspectsBatch; +import com.linkedin.metadata.entity.EntityService; +import com.linkedin.mxe.MetadataChangeProposal; +import io.datahubproject.metadata.context.OperationContext; +import io.datahubproject.test.metadata.context.TestOperationContexts; +import java.util.List; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class ServiceAccountServiceTest { + + @Mock private EntityService mockEntityService; + + private ServiceAccountService serviceAccountService; + private OperationContext operationContext; + + @BeforeMethod + public void setUp() { + MockitoAnnotations.openMocks(this); + serviceAccountService = new ServiceAccountService(); + operationContext = TestOperationContexts.systemContextNoSearchAuthorization(); + } + + @Test + public void testBuildServiceUserUrn() { + // Arrange + String issuer = "https://auth.example.com"; + String subject = "service-account-123"; + + // Act + String userUrn = serviceAccountService.buildServiceUserUrn(issuer, subject); + + // Assert + assertNotNull(userUrn); + assertTrue(userUrn.startsWith("__oauth_")); + assertTrue(userUrn.contains("auth_example_com")); + assertTrue(userUrn.contains("service-account-123")); + assertEquals(userUrn, "__oauth_auth_example_com_service-account-123"); + } + + @Test + public void testBuildServiceUserUrnWithComplexIssuer() { + // Arrange + String issuer = "https://my-sso.company.com/oauth2"; + String subject = "svc_datahub"; + + // Act + String userUrn = serviceAccountService.buildServiceUserUrn(issuer, subject); + + // Assert + assertNotNull(userUrn); + assertEquals(userUrn, "__oauth_my_sso_company_com_oauth2_svc_datahub"); + } + + @Test + public void testCreateServiceAccountAspects() { + // Arrange + CorpuserUrn userUrn = new CorpuserUrn("test-service-account"); + String displayName = "Test Service Account"; + OriginType originType = OriginType.EXTERNAL; + String externalType = "https://auth.example.com"; + + // Act + List aspects = + serviceAccountService.createServiceAccountAspects( + userUrn, displayName, originType, externalType); + + // Assert + assertNotNull(aspects); + assertEquals(aspects.size(), 3); + + // Verify CorpUserInfo aspect + MetadataChangeProposal corpUserInfoMcp = + aspects.stream() + .filter(mcp -> CORP_USER_INFO_ASPECT_NAME.equals(mcp.getAspectName())) + .findFirst() + .orElse(null); + assertNotNull(corpUserInfoMcp); + assertEquals(corpUserInfoMcp.getEntityUrn(), userUrn); + assertEquals(corpUserInfoMcp.getEntityType(), "corpuser"); + + // Verify SubTypes aspect + MetadataChangeProposal subTypesMcp = + aspects.stream() + .filter(mcp -> SUB_TYPES_ASPECT_NAME.equals(mcp.getAspectName())) + .findFirst() + .orElse(null); + assertNotNull(subTypesMcp); + + // Verify Origin aspect + MetadataChangeProposal originMcp = + aspects.stream() + .filter(mcp -> ORIGIN_ASPECT_NAME.equals(mcp.getAspectName())) + .findFirst() + .orElse(null); + assertNotNull(originMcp); + } + + @Test + public void testCreateServiceAccountSuccess() { + // Arrange + String userId = "__oauth_auth_example_com_service123"; + String displayName = "Service Account: service123 @ https://auth.example.com"; + OriginType originType = OriginType.EXTERNAL; + String externalType = "https://auth.example.com"; + + CorpuserUrn expectedUrn = new CorpuserUrn(userId); + + // Mock that user doesn't exist + when(mockEntityService.exists(eq(operationContext), eq(expectedUrn), eq(false))) + .thenReturn(false); + + // Act + boolean result = + serviceAccountService.createServiceAccount( + userId, displayName, originType, externalType, mockEntityService, operationContext); + + // Assert + assertTrue(result); + verify(mockEntityService, times(1)).exists(eq(operationContext), eq(expectedUrn), eq(false)); + verify(mockEntityService, times(1)) + .ingestAspects(eq(operationContext), any(AspectsBatch.class), eq(false), eq(true)); + } + + @Test + public void testCreateServiceAccountAlreadyExists() { + // Arrange + String userId = "__oauth_auth_example_com_service123"; + String displayName = "Service Account: service123 @ https://auth.example.com"; + OriginType originType = OriginType.EXTERNAL; + String externalType = "https://auth.example.com"; + + CorpuserUrn expectedUrn = new CorpuserUrn(userId); + + // Mock that user already exists + when(mockEntityService.exists(eq(operationContext), eq(expectedUrn), eq(false))) + .thenReturn(true); + + // Act + boolean result = + serviceAccountService.createServiceAccount( + userId, displayName, originType, externalType, mockEntityService, operationContext); + + // Assert + assertFalse(result); + verify(mockEntityService, times(1)).exists(eq(operationContext), eq(expectedUrn), eq(false)); + verify(mockEntityService, never()) + .ingestAspects(eq(operationContext), any(AspectsBatch.class), eq(false), eq(true)); + } + + @Test + public void testEnsureServiceAccountExistsFromTokenInfo() { + // Arrange + String userId = "__oauth_auth_example_com_service123"; + String issuer = "https://auth.example.com"; + String subject = "service123"; + + CorpuserUrn expectedUrn = new CorpuserUrn(userId); + + // Mock that user doesn't exist + when(mockEntityService.exists(eq(operationContext), eq(expectedUrn), eq(false))) + .thenReturn(false); + + // Act + boolean result = + serviceAccountService.ensureServiceAccountExists( + userId, issuer, subject, mockEntityService, operationContext); + + // Assert + assertTrue(result); + verify(mockEntityService, times(1)).exists(eq(operationContext), eq(expectedUrn), eq(false)); + verify(mockEntityService, times(1)) + .ingestAspects(eq(operationContext), any(AspectsBatch.class), eq(false), eq(true)); + } + + @Test + public void testEnsureServiceAccountExistsAlreadyExists() { + // Arrange + String userId = "__oauth_auth_example_com_service123"; + String issuer = "https://auth.example.com"; + String subject = "service123"; + + CorpuserUrn expectedUrn = new CorpuserUrn(userId); + + // Mock that user already exists + when(mockEntityService.exists(eq(operationContext), eq(expectedUrn), eq(false))) + .thenReturn(true); + + // Act + boolean result = + serviceAccountService.ensureServiceAccountExists( + userId, issuer, subject, mockEntityService, operationContext); + + // Assert + assertFalse(result); + verify(mockEntityService, times(1)).exists(eq(operationContext), eq(expectedUrn), eq(false)); + verify(mockEntityService, never()) + .ingestAspects(eq(operationContext), any(AspectsBatch.class), eq(false), eq(true)); + } + + @Test + public void testEnsureServiceAccountExistsHandlesErrors() { + // Arrange + String userId = "__oauth_auth_example_com_service123"; + String issuer = "https://auth.example.com"; + String subject = "service123"; + + CorpuserUrn expectedUrn = new CorpuserUrn(userId); + + // Mock that user doesn't exist + when(mockEntityService.exists(eq(operationContext), eq(expectedUrn), eq(false))) + .thenReturn(false); + + // Mock ingestion failure + doThrow(new RuntimeException("Ingestion failed")) + .when(mockEntityService) + .ingestAspects(eq(operationContext), any(AspectsBatch.class), eq(false), eq(true)); + + // Act + boolean result = + serviceAccountService.ensureServiceAccountExists( + userId, issuer, subject, mockEntityService, operationContext); + + // Assert + assertFalse(result); // Should return false on error but not throw exception + verify(mockEntityService, times(1)).exists(eq(operationContext), eq(expectedUrn), eq(false)); + verify(mockEntityService, times(1)) + .ingestAspects(eq(operationContext), any(AspectsBatch.class), eq(false), eq(true)); + } + + @Test + public void testCreateSystemAuditStamp() { + // Act + var auditStamp = serviceAccountService.createSystemAuditStamp(); + + // Assert + assertNotNull(auditStamp); + assertNotNull(auditStamp.getTime()); + assertNotNull(auditStamp.getActor()); + assertTrue(auditStamp.getTime() > 0); + } + + @Test + public void testCreateMetadataChangeProposal() { + // Arrange + CorpuserUrn userUrn = new CorpuserUrn("test-user"); + CorpUserInfo corpUserInfo = new CorpUserInfo(); + corpUserInfo.setActive(true); + corpUserInfo.setDisplayName("Test User"); + + // Act + MetadataChangeProposal mcp = + serviceAccountService.createMetadataChangeProposal( + userUrn, CORP_USER_INFO_ASPECT_NAME, corpUserInfo); + + // Assert + assertNotNull(mcp); + assertEquals(mcp.getEntityUrn(), userUrn); + assertEquals(mcp.getEntityType(), "corpuser"); + assertEquals(mcp.getAspectName(), CORP_USER_INFO_ASPECT_NAME); + assertNotNull(mcp.getAspect()); + assertNotNull(mcp.getChangeType()); + } + + @Test + public void testUserIdUniquenessAcrossIssuers() { + // Arrange + String issuer1 = "https://auth.company1.com"; + String issuer2 = "https://auth.company2.com"; + String subject = "service-account"; + + // Act + String userId1 = serviceAccountService.buildServiceUserUrn(issuer1, subject); + String userId2 = serviceAccountService.buildServiceUserUrn(issuer2, subject); + + // Assert + assertNotNull(userId1); + assertNotNull(userId2); + assertNotEquals(userId1, userId2); + assertTrue(userId1.contains("auth_company1_com")); + assertTrue(userId2.contains("auth_company2_com")); + } + + @Test + public void testIssuerSanitization() { + // Test various issuer formats are properly sanitized + String subject = "test"; + + // Test HTTPS URL + String issuer1 = "https://auth.example.com/oauth2/v1"; + String result1 = serviceAccountService.buildServiceUserUrn(issuer1, subject); + assertTrue(result1.contains("auth_example_com_oauth2_v1")); + + // Test HTTP URL + String issuer2 = "http://localhost:8080/auth"; + String result2 = serviceAccountService.buildServiceUserUrn(issuer2, subject); + assertTrue(result2.contains("localhost_8080_auth")); + + // Test special characters + String issuer3 = "https://auth-server.example.com:443/oauth2"; + String result3 = serviceAccountService.buildServiceUserUrn(issuer3, subject); + assertTrue(result3.contains("auth_server_example_com_443_oauth2")); + } +} diff --git a/metadata-io/src/test/java/com/linkedin/metadata/system_info/collectors/PropertiesCollectorConfigurationTest.java b/metadata-io/src/test/java/com/linkedin/metadata/system_info/collectors/PropertiesCollectorConfigurationTest.java index f9c3b100472252..bfb543d008ccb8 100644 --- a/metadata-io/src/test/java/com/linkedin/metadata/system_info/collectors/PropertiesCollectorConfigurationTest.java +++ b/metadata-io/src/test/java/com/linkedin/metadata/system_info/collectors/PropertiesCollectorConfigurationTest.java @@ -133,7 +133,23 @@ public PropertiesCollector propertiesCollector(Environment environment) { // IAM authentication flags "*.postgresUseIamAuth", - "*.opensearchUseAwsIamAuth"); + "*.opensearchUseAwsIamAuth", + + // Bulk rules + "featureFlags.*", + "*.*nabled", + "*.*.*nabled", + "*.*.*.*nabled", + "*.*.*.*.*nabled", + "*.consumerGroupSuffix", + "*.*.consumerGroupSuffix", + "*.*.*.consumerGroupSuffix", + "authentication.authenticators[*].configs.trustedIssuers", + "authentication.authenticators[*].configs.allowedAudiences", + "authentication.authenticators[*].configs.jwksUri", + "authentication.authenticators[*].configs.userIdClaim", + "authentication.authenticators[*].configs.algorithm", + "authentication.authenticators[*].configs.discoveryUri"); /** * Property keys that should NOT be redacted. Add new non-sensitive properties here when they are diff --git a/metadata-models/src/main/pegasus/com/linkedin/settings/global/GlobalSettingsInfo.pdl b/metadata-models/src/main/pegasus/com/linkedin/settings/global/GlobalSettingsInfo.pdl index 411f1e6c15eaf1..9268441544acf2 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/settings/global/GlobalSettingsInfo.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/settings/global/GlobalSettingsInfo.pdl @@ -13,6 +13,11 @@ record GlobalSettingsInfo { */ sso: optional SsoSettings + /** + * Settings related to the oauth authentication provider + */ + oauth: optional OAuthSettings + /** * Settings related to the Views Feature */ diff --git a/metadata-models/src/main/pegasus/com/linkedin/settings/global/OAuthProvider.pdl b/metadata-models/src/main/pegasus/com/linkedin/settings/global/OAuthProvider.pdl new file mode 100644 index 00000000000000..edd8d4e0d32b21 --- /dev/null +++ b/metadata-models/src/main/pegasus/com/linkedin/settings/global/OAuthProvider.pdl @@ -0,0 +1,38 @@ +namespace com.linkedin.settings.global + +/** + * An OAuth Provider. This provides information required to validate inbound + * requests with OAuth 2.0 bearer tokens. + */ +record OAuthProvider { + /** + * Whether this OAuth provider is enabled. + */ + enabled: boolean + /** + * The name of this OAuth provider. This is used for display purposes only. + */ + name: string + /** + * The URI of the JSON Web Key Set (JWKS) endpoint for this OAuth provider. + */ + jwksUri: optional string + /** + * The expected issuer (iss) claim in the JWTs issued by this OAuth provider. + */ + issuer: string + /** + * The expected audience (aud) claim in the JWTs issued by this OAuth provider. + */ + audience: string + /** + * The JWT signing algorithm required for this provider. + * Prevents algorithm confusion attacks. Common values: RS256, RS384, RS512, PS256, ES256 + */ + algorithm: string = "RS256" + /** + * The JWT claim to use as the user identifier for this provider. + * Different providers use different claims (sub, email, preferred_username, etc.) + */ + userIdClaim: string = "sub" +} diff --git a/metadata-models/src/main/pegasus/com/linkedin/settings/global/OAuthSettings.pdl b/metadata-models/src/main/pegasus/com/linkedin/settings/global/OAuthSettings.pdl new file mode 100644 index 00000000000000..b85e767e385a4a --- /dev/null +++ b/metadata-models/src/main/pegasus/com/linkedin/settings/global/OAuthSettings.pdl @@ -0,0 +1,11 @@ +namespace com.linkedin.settings.global + +/** + * Trust oauth providers to use for authentication. + */ +record OAuthSettings { + /** + * Trusted OAuth Providers + */ + providers: array[OAuthProvider] +} diff --git a/metadata-service/auth-filter/src/main/java/com/datahub/auth/authentication/filter/AuthenticationExtractionFilter.java b/metadata-service/auth-filter/src/main/java/com/datahub/auth/authentication/filter/AuthenticationExtractionFilter.java index 6581eaccc62337..101327253efbf0 100644 --- a/metadata-service/auth-filter/src/main/java/com/datahub/auth/authentication/filter/AuthenticationExtractionFilter.java +++ b/metadata-service/auth-filter/src/main/java/com/datahub/auth/authentication/filter/AuthenticationExtractionFilter.java @@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableMap; import com.linkedin.gms.factory.config.ConfigurationProvider; import com.linkedin.metadata.entity.EntityService; +import io.datahubproject.metadata.context.OperationContext; import jakarta.inject.Named; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -84,6 +85,10 @@ public class AuthenticationExtractionFilter extends OncePerRequestFilter { @Named("dataHubTokenService") private StatefulTokenService _tokenService; + @Autowired + @Named("systemOperationContext") + private OperationContext _systemOperationContext; + @Value("#{new Boolean('${authentication.logAuthenticatorExceptions}')}") private boolean _logAuthenticatorExceptions; @@ -114,7 +119,12 @@ private void buildAuthenticatorChain() { final AuthenticatorContext authenticatorContext = new AuthenticatorContext( ImmutableMap.of( - ENTITY_SERVICE, this._entityService, TOKEN_SERVICE, this._tokenService)); + ENTITY_SERVICE, + this._entityService, + TOKEN_SERVICE, + this._tokenService, + "systemOperationContext", + this._systemOperationContext)); if (isAuthEnabled) { log.info("Auth is enabled. Building extraction authenticator chain..."); diff --git a/metadata-service/auth-filter/src/test/java/com/datahub/auth/authentication/AuthExtractionTestConfiguration.java b/metadata-service/auth-filter/src/test/java/com/datahub/auth/authentication/AuthExtractionTestConfiguration.java index 9ff534abb72331..2c2bb53a75f192 100644 --- a/metadata-service/auth-filter/src/test/java/com/datahub/auth/authentication/AuthExtractionTestConfiguration.java +++ b/metadata-service/auth-filter/src/test/java/com/datahub/auth/authentication/AuthExtractionTestConfiguration.java @@ -12,6 +12,7 @@ import com.linkedin.metadata.config.DataHubConfiguration; import com.linkedin.metadata.config.PluginConfiguration; import com.linkedin.metadata.entity.EntityService; +import io.datahubproject.metadata.context.OperationContext; import io.datahubproject.test.metadata.context.TestOperationContexts; import java.util.List; import java.util.Map; @@ -34,6 +35,11 @@ public EntityService entityService() { return mock(EntityService.class); } + @Bean("systemOperationContext") + public OperationContext systemOperationContext() { + return TestOperationContexts.systemContextNoSearchAuthorization(); + } + @Bean("dataHubTokenService") public StatefulTokenService statefulTokenService( ConfigurationProvider configurationProvider, EntityService entityService) { @@ -83,7 +89,12 @@ public ConfigurationProvider configurationProvider() { } @Bean - @DependsOn({"configurationProvider", "dataHubTokenService", "entityService"}) + @DependsOn({ + "configurationProvider", + "dataHubTokenService", + "entityService", + "systemOperationContext" + }) public AuthenticationExtractionFilter authenticationExtractionFilter() throws ServletException { return new AuthenticationExtractionFilter(); } diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/DataHubOAuthAuthenticator.java b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/DataHubOAuthAuthenticator.java new file mode 100644 index 00000000000000..333fe1a07c3de2 --- /dev/null +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/DataHubOAuthAuthenticator.java @@ -0,0 +1,331 @@ +package com.datahub.authentication.authenticator; + +import static com.datahub.authentication.AuthenticationConstants.*; +import static com.linkedin.metadata.Constants.*; + +import com.datahub.authentication.Actor; +import com.datahub.authentication.ActorType; +import com.datahub.authentication.Authentication; +import com.datahub.authentication.AuthenticationException; +import com.datahub.authentication.AuthenticationRequest; +import com.datahub.authentication.AuthenticatorContext; +import com.datahub.authentication.token.DataHubOAuthSigningKeyResolver; +import com.datahub.plugins.auth.authentication.Authenticator; +import com.linkedin.common.OriginType; +import com.linkedin.common.urn.CorpuserUrn; +import com.linkedin.metadata.entity.EntityService; +import com.linkedin.metadata.service.ServiceAccountService; +import com.linkedin.mxe.MetadataChangeProposal; +import com.linkedin.settings.global.OAuthProvider; +import io.datahubproject.metadata.context.OperationContext; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jws; +import io.jsonwebtoken.Jwts; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.extern.slf4j.Slf4j; + +/** + * Authenticator that validates OAuth2 / OIDC JWT tokens using a unified OAuth provider approach. + * Supports both static configuration (application.yaml) and dynamic configuration (GlobalSettings). + * Primary use case is to authenticate service accounts via OAuth tokens (not individual users). + * + *

Static Configuration (application.yaml): + * + *

    + *
  • trustedIssuers: List of trusted JWT issuers (must match the 'iss' claim in JWT tokens) + *
  • allowedAudiences: List of allowed JWT audiences (must match the 'aud' claim in JWT tokens) + *
  • jwksUri: URI to fetch JWT signing keys, OR discoveryUri to auto-derive JWKS URI + *
  • Optional: userIdClaim (defaults to "sub") + *
  • Optional: algorithm (defaults to "RS256") + *
+ * + *

Dynamic Configuration (GlobalSettings): OAuth providers can also be + * configured dynamically through GlobalSettings.oauth.providers. Dynamic providers are refreshed + * automatically every minute. + * + *

Unified Provider Chain: The authenticator maintains a unified list of OAuth + * providers (static + dynamic) and validates tokens by finding the first matching provider based on + * issuer and audience claims. + * + *

This authenticator creates service account users automatically in DataHub when they first + * authenticate, with proper SubTypes (SERVICE) and Origin aspects. + */ +@Slf4j +public class DataHubOAuthAuthenticator implements Authenticator { + + static final String USER_ID_PREFIX = "__oauth_"; + static final String DEFAULT_USER_CLAIM = "sub"; + + // Configuration fields + private EntityService entityService; + private OperationContext systemOperationContext; + private String userIdClaim; + private String algorithm; + + // Service dependencies + private ServiceAccountService serviceAccountService; + private OAuthConfigurationFetcher configurationFetcher; + + @Override + public void init( + @Nonnull final Map config, @Nullable final AuthenticatorContext context) { + Objects.requireNonNull(config, "Config parameter cannot be null"); + Objects.requireNonNull(context, "Context parameter cannot be null"); + + // Check if OAuth authentication is enabled + boolean enabled = Boolean.parseBoolean(config.getOrDefault("enabled", "false").toString()); + if (!enabled) { + log.info("OAuth authentication is disabled via configuration. Skipping initialization."); + return; + } + + // Get EntityService from context + if (!context.data().containsKey(ENTITY_SERVICE)) { + throw new IllegalArgumentException( + "Unable to initialize DataHubOAuthAuthenticator, entity service reference not found."); + } + final Object entityServiceObj = context.data().get(ENTITY_SERVICE); + if (!(entityServiceObj instanceof EntityService)) { + throw new RuntimeException( + "Unable to initialize DataHubOAuthAuthenticator, entity service reference is not of type: " + + "EntityService.class, found: " + + entityServiceObj.getClass()); + } + this.entityService = (EntityService) entityServiceObj; + + // Get system operation context + if (!context.data().containsKey("systemOperationContext")) { + throw new IllegalArgumentException( + "Unable to initialize DataHubOAuthAuthenticator, system operation context not found."); + } + final Object systemOpContextObj = context.data().get("systemOperationContext"); + if (!(systemOpContextObj instanceof OperationContext)) { + throw new RuntimeException( + "Unable to initialize DataHubOAuthAuthenticator, system operation context is not of type: " + + "OperationContext.class, found: " + + systemOpContextObj.getClass()); + } + this.systemOperationContext = (OperationContext) systemOpContextObj; + + // Initialize services + this.serviceAccountService = new ServiceAccountService(); + this.configurationFetcher = new OAuthConfigurationFetcher(); + + // Load static configuration + loadStaticConfiguration(config); + } + + private void loadStaticConfiguration(@Nonnull final Map config) { + try { + log.debug("Loading OAuth settings from static configuration"); + + // Load basic settings + this.userIdClaim = (String) config.getOrDefault("userIdClaim", DEFAULT_USER_CLAIM); + this.algorithm = (String) config.getOrDefault("algorithm", "RS256"); + + // Initialize configuration fetcher + this.configurationFetcher.initialize(config, entityService, systemOperationContext); + + // Log initial configuration status + if (this.configurationFetcher.isConfigured()) { + List providers = this.configurationFetcher.getCachedConfiguration(); + for (OAuthProvider provider : providers) { + log.debug( + "OAuth Provider - Name: {}, Issuer: {}, Audience: {}, JWKS URI: {}", + provider.getName(), + provider.getIssuer(), + provider.getAudience(), + provider.getJwksUri()); + } + } else { + log.warn( + "OAuth authenticator configuration incomplete. Please provide trustedIssuers, allowedAudiences, " + + "and either jwksUri or discoveryUri in application.yaml, or configure OAuth providers in GlobalSettings."); + } + + } catch (Exception e) { + log.error("Failed to load OAuth static configuration", e); + } + } + + @Override + public Authentication authenticate(@Nonnull AuthenticationRequest context) + throws AuthenticationException { + Objects.requireNonNull(context); + + // Check if the authenticator is properly configured + if (configurationFetcher == null || !configurationFetcher.isConfigured()) { + throw new AuthenticationException( + "OAuth authenticator is not configured. Please configure either SSO settings in GlobalSettings or provide static configuration."); + } + + try { + String jwtToken = context.getRequestHeaders().get(AUTHORIZATION_HEADER_NAME); + + log.info("Request headers are: {}", context.getRequestHeaders()); + + if (jwtToken == null + || (!jwtToken.startsWith("Bearer ") && !jwtToken.startsWith("bearer "))) { + throw new AuthenticationException("Invalid Authorization header"); + } + + String token = getToken(jwtToken); + + // Parse JWT to extract issuer and audience (without signature verification) + String[] tokenParts = token.split("\\."); + if (tokenParts.length != 3) { + throw new AuthenticationException("Invalid JWT token format"); + } + + String payload = new String(java.util.Base64.getUrlDecoder().decode(tokenParts[1])); + com.fasterxml.jackson.databind.JsonNode payloadJson = + new com.fasterxml.jackson.databind.ObjectMapper().readTree(payload); + + String issuer = payloadJson.has("iss") ? payloadJson.get("iss").asText() : null; + if (issuer == null) { + throw new AuthenticationException("Missing issuer claim in JWT token"); + } + + // Get audience(s) from token + List audiences = new ArrayList<>(); + if (payloadJson.has("aud")) { + com.fasterxml.jackson.databind.JsonNode audNode = payloadJson.get("aud"); + if (audNode.isArray()) { + audNode.forEach(node -> audiences.add(node.asText())); + } else { + audiences.add(audNode.asText()); + } + } + + if (audiences.isEmpty()) { + throw new AuthenticationException("Missing audience claim in JWT token"); + } + + // Find matching OAuth provider from configuration fetcher + OAuthProvider matchingProvider = configurationFetcher.findMatchingProvider(issuer, audiences); + if (matchingProvider == null) { + throw new AuthenticationException( + "No configured OAuth provider matches token issuer '" + + issuer + + "' and audiences " + + audiences); + } + + log.debug( + "Using OAuth provider '{}' to validate token from issuer '{}'", + matchingProvider.getName(), + issuer); + + // Get provider-specific algorithm and userIdClaim, falling back to global defaults + String providerAlgorithm = matchingProvider.getAlgorithm().trim(); + String providerUserIdClaim = matchingProvider.getUserIdClaim().trim(); + + log.debug( + "Using algorithm '{}' and userIdClaim '{}' for provider '{}'", + providerAlgorithm, + providerUserIdClaim, + matchingProvider.getName()); + + // Validate JWT signature using the matching provider's JWKS + HashSet trustedIssuers = new HashSet<>(); + trustedIssuers.add(issuer); + + Jws claims = + Jwts.parserBuilder() + .setSigningKeyResolver( + new DataHubOAuthSigningKeyResolver( + trustedIssuers, matchingProvider.getJwksUri(), providerAlgorithm)) + .build() + .parseClaimsJws(token); + + Claims body = claims.getBody(); + + // Extract subject (userIdClaim) + final String subject = body.get(providerUserIdClaim, String.class); + if (subject == null) { + throw new AuthenticationException("Missing required claim: " + providerUserIdClaim); + } + + // Build unique service account user ID + final String userId = serviceAccountService.buildServiceUserUrn(issuer, subject); + + // Ensure service account exists in DataHub (create if needed) + serviceAccountService.ensureServiceAccountExists( + userId, issuer, subject, entityService, systemOperationContext); + + // TODO: distinguish USER vs SERVICE based on scope or custom claim + ActorType actorType = ActorType.USER; + return new Authentication(new Actor(actorType, userId), jwtToken); + } catch (Exception e) { + throw new AuthenticationException("OAuth token validation failed: " + e.getMessage()); + } + } + + private String getToken(String jwtToken) { + var tokenArray = jwtToken.split(" "); + return tokenArray.length == 1 ? tokenArray[0] : tokenArray[1]; + } + + /** Delegation method for test compatibility. */ + private String buildServiceUserUrn(Claims body) { + String issuer = body.getIssuer(); + String subject = body.get(userIdClaim, String.class); + return serviceAccountService.buildServiceUserUrn(issuer, subject); + } + + /** Delegation method for test compatibility. */ + private void ensureServiceAccountExists(@Nonnull String userId, @Nonnull Claims claims) { + String issuer = claims.getIssuer(); + String subject = claims.get(userIdClaim, String.class); + serviceAccountService.ensureServiceAccountExists( + userId, issuer, subject, entityService, systemOperationContext); + } + + /** Delegation method for test compatibility. */ + private List createServiceAccountAspects( + @Nonnull CorpuserUrn userUrn, @Nonnull Claims claims) { + String issuer = claims.getIssuer(); + String subject = claims.get(userIdClaim, String.class); + String displayName = String.format("Service Account: %s @ %s", subject, issuer); + return serviceAccountService.createServiceAccountAspects( + userUrn, displayName, OriginType.EXTERNAL, issuer); + } + + /** + * Returns the current OAuth provider configuration. This method is provided for backward + * compatibility with existing tests. + * + * @return List of OAuth providers + */ + public List getOAuthProviders() { + return configurationFetcher != null + ? configurationFetcher.getCachedConfiguration() + : new ArrayList<>(); + } + + /** + * Forces a refresh of the OAuth provider configuration from GlobalSettings. This method is + * primarily intended for testing scenarios where dynamic configuration changes need to be applied + * immediately. + * + * @return List of OAuth providers after forced refresh + */ + public List forceRefreshOAuthProviders() { + return configurationFetcher != null + ? configurationFetcher.forceRefreshConfiguration() + : new ArrayList<>(); + } + + /** Cleanup method to shutdown the scheduler when the authenticator is destroyed. */ + public void destroy() { + if (this.configurationFetcher != null) { + this.configurationFetcher.destroy(); + } + } +} diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/JwksUriResolver.java b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/JwksUriResolver.java new file mode 100644 index 00000000000000..3b267ae3c3cea3 --- /dev/null +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/JwksUriResolver.java @@ -0,0 +1,100 @@ +package com.datahub.authentication.authenticator; + +import lombok.extern.slf4j.Slf4j; + +/** + * Utility class for resolving JWKS URIs from OAuth/OIDC discovery endpoints. Handles both standard + * OIDC discovery and fallback scenarios. + */ +@Slf4j +public class JwksUriResolver { + + /** + * Derives JWKS URI from a discovery URI by fetching the discovery document. Falls back to + * standard patterns if discovery document is unavailable. + * + * @param discoveryUri The OIDC discovery URI or base URL + * @return The JWKS URI for fetching signing keys + */ + public static String deriveJwksUriFromDiscoveryUri(String discoveryUri) { + try { + // Handle different formats of discovery URIs + String discoveryEndpoint = discoveryUri.trim(); + + // Remove trailing slash + if (discoveryEndpoint.endsWith("/")) { + discoveryEndpoint = discoveryEndpoint.substring(0, discoveryEndpoint.length() - 1); + } + + // If it's not a full discovery endpoint, construct one + if (!discoveryEndpoint.endsWith("/.well-known/openid-configuration")) { + discoveryEndpoint = discoveryEndpoint + "/.well-known/openid-configuration"; + } + + log.debug("Fetching discovery document from: {}", discoveryEndpoint); + + // Fetch the discovery document to get the actual JWKS URI + java.net.http.HttpClient client = java.net.http.HttpClient.newHttpClient(); + java.net.http.HttpRequest request = + java.net.http.HttpRequest.newBuilder() + .uri(java.net.URI.create(discoveryEndpoint)) + .build(); + + java.net.http.HttpResponse response = + client.send(request, java.net.http.HttpResponse.BodyHandlers.ofString()); + + if (response.statusCode() != 200) { + log.warn( + "Failed to fetch discovery document from {}, status: {}", + discoveryEndpoint, + response.statusCode()); + // Fallback to standard pattern + return deriveJwksUriFallback(discoveryUri); + } + + // Parse the discovery document + org.json.JSONObject discoveryDoc = new org.json.JSONObject(response.body()); + + if (discoveryDoc.has("jwks_uri")) { + String jwksUri = discoveryDoc.getString("jwks_uri"); + log.debug("Found JWKS URI in discovery document: {}", jwksUri); + return jwksUri; + } else { + log.warn("No jwks_uri found in discovery document from {}", discoveryEndpoint); + return deriveJwksUriFallback(discoveryUri); + } + + } catch (Exception e) { + log.error("Failed to fetch discovery document from {}: {}", discoveryUri, e.getMessage()); + log.debug("Discovery document fetch error details", e); + // Fallback to standard pattern + return deriveJwksUriFallback(discoveryUri); + } + } + + /** + * Provides a fallback JWKS URI using standard OIDC patterns when discovery is unavailable. + * + * @param discoveryUri The original discovery URI or base URL + * @return Fallback JWKS URI using standard patterns + */ + public static String deriveJwksUriFallback(String discoveryUri) { + // Fallback to standard OIDC pattern when discovery document is unavailable + String baseUri = discoveryUri.trim(); + + // Remove trailing slash + if (baseUri.endsWith("/")) { + baseUri = baseUri.substring(0, baseUri.length() - 1); + } + + // If it's a full discovery endpoint, derive base + if (baseUri.endsWith("/.well-known/openid-configuration")) { + baseUri = baseUri.replace("/.well-known/openid-configuration", ""); + } + + // Standard OIDC JWKS endpoint + String fallbackUri = baseUri + "/.well-known/jwks.json"; + log.debug("Using fallback JWKS URI: {}", fallbackUri); + return fallbackUri; + } +} diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/OAuthConfigurationFetcher.java b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/OAuthConfigurationFetcher.java new file mode 100644 index 00000000000000..b9a414b0b4a0c3 --- /dev/null +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/authenticator/OAuthConfigurationFetcher.java @@ -0,0 +1,383 @@ +package com.datahub.authentication.authenticator; + +import static com.linkedin.metadata.Constants.*; + +import com.linkedin.metadata.entity.EntityService; +import com.linkedin.settings.global.GlobalSettingsInfo; +import com.linkedin.settings.global.OAuthProvider; +import com.linkedin.settings.global.OAuthSettings; +import io.datahubproject.metadata.context.OperationContext; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.extern.slf4j.Slf4j; + +/** + * Service for fetching and managing OAuth provider configurations from both static configuration + * and dynamic GlobalSettings. Provides caching and background refresh capabilities. + * + *

This service maintains a unified list of OAuth providers from: + * + *

    + *
  • Static configuration from application.yaml + *
  • Dynamic configuration from GlobalSettings entity + *
+ * + *

The service automatically refreshes dynamic configuration on a scheduled interval and provides + * methods for both cached access and forced refresh. + */ +@Slf4j +public class OAuthConfigurationFetcher { + + private static final int REFRESH_INTERVAL_MINUTES = 1; + + // Configuration dependencies + private EntityService entityService; + private OperationContext systemOperationContext; + + // Unified OAuth provider configuration (static + dynamic) + private volatile List oauthProviders = new ArrayList<>(); + private ScheduledExecutorService scheduler; + private volatile boolean isConfigured = false; + + /** + * Initializes the configuration fetcher with dependencies and loads initial configuration. + * + * @param staticConfig Static configuration from application.yaml + * @param entityService EntityService for accessing GlobalSettings + * @param systemOperationContext System operation context for entity operations + */ + public void initialize( + @Nonnull Map staticConfig, + @Nonnull EntityService entityService, + @Nonnull OperationContext systemOperationContext) { + + this.entityService = entityService; + this.systemOperationContext = systemOperationContext; + + // Initialize unified provider list with static configuration + this.oauthProviders = new ArrayList<>(); + createStaticOAuthProviders(staticConfig); + + // Load initial dynamic configuration from GlobalSettings and merge + loadDynamicConfiguration(); + + // Set up scheduled refresh of dynamic configuration + setupDynamicConfigurationRefresh(); + + // Validate final configuration + this.isConfigured = validateConfiguration(); + + if (this.isConfigured) { + log.info( + "OAuth configuration fetcher initialized with {} OAuth provider(s)", + this.oauthProviders.size()); + for (OAuthProvider provider : this.oauthProviders) { + log.debug( + "OAuth Provider - Name: {}, Issuer: {}, Audience: {}, JWKS URI: {}", + provider.getName(), + provider.getIssuer(), + provider.getAudience(), + provider.getJwksUri()); + } + } else { + log.warn( + "OAuth configuration incomplete. Please provide trustedIssuers, allowedAudiences, " + + "and either jwksUri or discoveryUri in application.yaml, or configure OAuth providers in GlobalSettings."); + } + } + + /** + * Returns the cached OAuth provider configuration. + * + * @return List of OAuth providers from both static and dynamic sources + */ + public List getCachedConfiguration() { + return new ArrayList<>(this.oauthProviders); + } + + /** + * Forces a refresh of dynamic configuration from GlobalSettings and returns the updated + * configuration. + * + * @return List of OAuth providers after forced refresh + */ + public List forceRefreshConfiguration() { + log.debug("Forcing refresh of OAuth configuration"); + loadDynamicConfiguration(); + return getCachedConfiguration(); + } + + /** + * Returns whether the configuration fetcher has valid OAuth providers configured. + * + * @return true if at least one enabled OAuth provider is configured + */ + public boolean isConfigured() { + return this.isConfigured; + } + + /** + * Finds a matching OAuth provider from the unified list based on issuer and audience. + * + * @param issuer The JWT issuer to match + * @param audiences List of JWT audiences to match + * @return Matching OAuth provider or null if no match found + */ + @Nullable + public OAuthProvider findMatchingProvider( + @Nonnull String issuer, @Nonnull List audiences) { + for (OAuthProvider provider : this.oauthProviders) { + // Skip disabled providers + if (!Boolean.TRUE.equals(provider.data().get("enabled"))) { + continue; + } + + // Check if issuer matches + if (!issuer.equals(provider.getIssuer())) { + continue; + } + + // Check if any token audience matches the provider's audience + String providerAudience = provider.getAudience(); + if (audiences.contains(providerAudience)) { + log.debug( + "Found matching provider '{}' for issuer '{}' and audience '{}'", + provider.getName(), + issuer, + providerAudience); + return provider; + } + } + + log.debug("No matching provider found for issuer '{}' and audiences {}", issuer, audiences); + return null; + } + + /** Cleanup method to shutdown the scheduler when the fetcher is destroyed. */ + public void destroy() { + if (this.scheduler != null && !this.scheduler.isShutdown()) { + this.scheduler.shutdown(); + try { + if (!this.scheduler.awaitTermination(5, TimeUnit.SECONDS)) { + this.scheduler.shutdownNow(); + } + } catch (InterruptedException e) { + this.scheduler.shutdownNow(); + Thread.currentThread().interrupt(); + } + log.info("OAuth configuration refresh scheduler shutdown"); + } + } + + /** + * Creates static OAuth providers from application.yaml configuration. + * + * @param config Static configuration map + */ + private void createStaticOAuthProviders(@Nonnull Map config) { + // Load trusted issuers + List trustedIssuers = new ArrayList<>(); + if (config.containsKey("trustedIssuers")) { + String issuersStr = (String) config.get("trustedIssuers"); + if (issuersStr != null && !issuersStr.trim().isEmpty()) { + trustedIssuers = Arrays.asList(issuersStr.split(",")); + } + } + + // Load allowed audiences + List allowedAudiences = new ArrayList<>(); + if (config.containsKey("allowedAudiences")) { + String audiencesStr = (String) config.get("allowedAudiences"); + if (audiencesStr != null && !audiencesStr.trim().isEmpty()) { + allowedAudiences = Arrays.asList(audiencesStr.split(",")); + } + } + + // Load JWKS URI - either directly or derive from discoveryUri + String jwksUri = (String) config.get("jwksUri"); + if (jwksUri == null || jwksUri.trim().isEmpty()) { + String discoveryUri = (String) config.get("discoveryUri"); + if (discoveryUri != null && !discoveryUri.trim().isEmpty()) { + jwksUri = JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + } + } + + // Load optional algorithm and userIdClaim for static providers + String staticAlgorithm = (String) config.getOrDefault("algorithm", "RS256"); + String staticUserIdClaim = (String) config.getOrDefault("userIdClaim", "sub"); + + // Only create static providers if we have sufficient configuration + if (!trustedIssuers.isEmpty() + && !allowedAudiences.isEmpty() + && jwksUri != null + && !jwksUri.trim().isEmpty()) { + // Create a provider for each issuer/audience combination + for (String issuer : trustedIssuers) { + for (String audience : allowedAudiences) { + OAuthProvider provider = new OAuthProvider(); + provider.setName("static_" + issuer.replaceAll("[^a-zA-Z0-9]", "_")); + provider.setIssuer(issuer.trim()); + provider.setAudience(audience.trim()); + provider.setJwksUri(jwksUri.trim()); + provider.setAlgorithm(staticAlgorithm); + provider.setUserIdClaim(staticUserIdClaim); + provider.setEnabled(true); + this.oauthProviders.add(provider); + } + } + log.info( + "Created {} static OAuth provider(s) from configuration", this.oauthProviders.size()); + } else { + log.debug("No valid static OAuth configuration found - static providers not created"); + } + } + + /** Loads dynamic OAuth provider configuration from GlobalSettings. */ + private void loadDynamicConfiguration() { + try { + log.debug("Loading dynamic OAuth configuration from GlobalSettings"); + + GlobalSettingsInfo globalSettings = getGlobalSettings(); + if (globalSettings == null || !globalSettings.hasOauth()) { + log.debug("No OAuth settings found in GlobalSettings"); + removeDynamicProviders(); + return; + } + + OAuthSettings oauthSettings = globalSettings.getOauth(); + if (!oauthSettings.hasProviders() || oauthSettings.getProviders().isEmpty()) { + log.debug("No OAuth providers configured in GlobalSettings"); + removeDynamicProviders(); + return; + } + + // Remove existing dynamic providers to refresh them + removeDynamicProviders(); + + // Add enabled dynamic providers to the unified list + List enabledProviders = new ArrayList<>(); + for (OAuthProvider provider : oauthSettings.getProviders()) { + if (Boolean.TRUE.equals(provider.isEnabled())) { + enabledProviders.add(provider); + this.oauthProviders.add(provider); + log.debug( + "Added dynamic OAuth provider: {} (issuer: {}, audience: {})", + provider.getName(), + provider.getIssuer(), + provider.getAudience()); + } else { + log.debug("Skipping disabled OAuth provider: {}", provider.getName()); + } + } + + log.debug( + "Successfully loaded {} enabled OAuth provider(s) from GlobalSettings (total providers: {})", + enabledProviders.size(), + this.oauthProviders.size()); + + // Re-validate configuration after dynamic config changes + boolean previouslyConfigured = this.isConfigured; + this.isConfigured = validateConfiguration(); + + if (!previouslyConfigured && this.isConfigured) { + log.info("OAuth configuration fetcher now configured with dynamic providers"); + } else if (previouslyConfigured && !this.isConfigured) { + log.warn("OAuth configuration is no longer valid after dynamic configuration update"); + } + + } catch (Exception e) { + log.error("Failed to load dynamic OAuth configuration from GlobalSettings", e); + // On error, remove dynamic configuration to avoid stale data + removeDynamicProviders(); + + // Re-validate configuration after clearing dynamic config + this.isConfigured = validateConfiguration(); + } + } + + /** Removes dynamic providers from the unified list. */ + private void removeDynamicProviders() { + // Remove any dynamic providers (those without "static_" prefix in name) + this.oauthProviders.removeIf(provider -> !provider.getName().startsWith("static_")); + } + + /** + * Validates the current OAuth provider configuration. + * + * @return true if at least one enabled provider is configured + */ + private boolean validateConfiguration() { + // Check if we have any enabled OAuth providers + long enabledProviders = + this.oauthProviders.stream() + .filter(provider -> Boolean.TRUE.equals(provider.isEnabled())) + .count(); + + boolean valid = enabledProviders > 0; + + if (!valid) { + log.debug( + "Configuration validation failed: no enabled OAuth providers found (total providers: {})", + this.oauthProviders.size()); + } else { + log.debug( + "Configuration validation passed: {} enabled OAuth provider(s) found", enabledProviders); + } + + return valid; + } + + /** Sets up scheduled refresh of dynamic OAuth configuration. */ + private void setupDynamicConfigurationRefresh() { + if (this.scheduler != null) { + this.scheduler.shutdown(); + } + + this.scheduler = + Executors.newScheduledThreadPool( + 1, + r -> { + Thread t = new Thread(r, "oauth-config-refresher"); + t.setDaemon(true); + return t; + }); + + this.scheduler.scheduleAtFixedRate( + this::loadDynamicConfiguration, + REFRESH_INTERVAL_MINUTES, + REFRESH_INTERVAL_MINUTES, + TimeUnit.MINUTES); + + log.info("Scheduled OAuth configuration refresh every {} minute(s)", REFRESH_INTERVAL_MINUTES); + } + + /** + * Retrieves GlobalSettings from the entity service. + * + * @return GlobalSettingsInfo or null if not found + */ + @Nullable + private GlobalSettingsInfo getGlobalSettings() { + try { + Object globalSettingsAspect = + this.entityService.getLatestAspect( + this.systemOperationContext, GLOBAL_SETTINGS_URN, GLOBAL_SETTINGS_INFO_ASPECT_NAME); + + if (globalSettingsAspect instanceof GlobalSettingsInfo) { + return (GlobalSettingsInfo) globalSettingsAspect; + } + + return null; + } catch (Exception e) { + log.warn("Failed to retrieve GlobalSettings", e); + return null; + } + } +} diff --git a/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/DataHubOAuthSigningKeyResolver.java b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/DataHubOAuthSigningKeyResolver.java new file mode 100644 index 00000000000000..a86e740409d472 --- /dev/null +++ b/metadata-service/auth-impl/src/main/java/com/datahub/authentication/token/DataHubOAuthSigningKeyResolver.java @@ -0,0 +1,107 @@ +package com.datahub.authentication.token; + +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.JwsHeader; +import io.jsonwebtoken.SigningKeyResolverAdapter; +import java.math.BigInteger; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.security.Key; +import java.security.KeyFactory; +import java.security.PublicKey; +import java.security.spec.RSAPublicKeySpec; +import java.util.Base64; +import java.util.HashSet; +import org.json.JSONArray; +import org.json.JSONObject; + +/** Resolves signing keys from OAuth2 / OIDC JWKS endpoints. */ +public class DataHubOAuthSigningKeyResolver extends SigningKeyResolverAdapter { + + private final HttpClient client; + private final HashSet trustedIssuers; + private final String jwksUri; + private final String algorithm; + + public DataHubOAuthSigningKeyResolver( + HashSet trustedIssuers, String jwksUri, String algorithm) { + this(trustedIssuers, jwksUri, algorithm, HttpClient.newHttpClient()); + } + + // Constructor for testing with custom HttpClient + public DataHubOAuthSigningKeyResolver( + HashSet trustedIssuers, String jwksUri, String algorithm, HttpClient httpClient) { + this.trustedIssuers = trustedIssuers; + this.jwksUri = jwksUri; + this.algorithm = algorithm; + this.client = httpClient; + } + + @Override + public Key resolveSigningKey(JwsHeader jwsHeader, Claims claims) { + try { + if (!trustedIssuers.contains(claims.getIssuer())) { + throw new RuntimeException("Invalid issuer: " + claims.getIssuer()); + } + + // Validate algorithm matches expected algorithm + String headerAlgorithm = jwsHeader.getAlgorithm(); + if (!algorithm.equals(headerAlgorithm)) { + throw new RuntimeException( + "Invalid algorithm: expected " + algorithm + " but got " + headerAlgorithm); + } + + String keyId = jwsHeader.getKeyId(); + return loadPublicKey(jwksUri, keyId, algorithm); + } catch (Exception e) { + throw new RuntimeException("Unable to resolve signing key: " + e.getMessage(), e); + } + } + + private PublicKey loadPublicKey(String jwksUri, String keyId, String algorithm) throws Exception { + HttpRequest request = HttpRequest.newBuilder().uri(URI.create(jwksUri)).build(); + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString()); + + var body = new JSONObject(response.body()); + + JSONArray keys = body.getJSONArray("keys"); + + for (int i = 0; i < keys.length(); i++) { + var token = keys.getJSONObject(i); + if (keyId.equals(token.getString("kid"))) { + return getPublicKey(token, algorithm); + } + } + throw new Exception("No matching key found in JWKS for kid=" + keyId); + } + + private PublicKey getPublicKey(JSONObject token, String algorithm) throws Exception { + String keyType = token.getString("kty"); + + // Validate key type is compatible with algorithm + if (algorithm.startsWith("RS") || algorithm.startsWith("PS")) { + // RSA algorithms (RS256, RS384, RS512, PS256, PS384, PS512) + if (!"RSA".equals(keyType)) { + throw new Exception( + "Algorithm " + algorithm + " requires RSA key type, but got: " + keyType); + } + } else if (algorithm.startsWith("ES")) { + // ECDSA algorithms (ES256, ES384, ES512) + if (!"EC".equals(keyType)) { + throw new Exception( + "Algorithm " + algorithm + " requires EC key type, but got: " + keyType); + } + throw new Exception("ECDSA algorithms not yet supported"); + } else { + throw new Exception("Unsupported algorithm: " + algorithm); + } + + // Currently only RSA keys are supported + KeyFactory kf = KeyFactory.getInstance("RSA"); + BigInteger modulus = new BigInteger(1, Base64.getUrlDecoder().decode(token.getString("n"))); + BigInteger exponent = new BigInteger(1, Base64.getUrlDecoder().decode(token.getString("e"))); + return kf.generatePublic(new RSAPublicKeySpec(modulus, exponent)); + } +} diff --git a/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/DataHubOAuthAuthenticatorTest.java b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/DataHubOAuthAuthenticatorTest.java new file mode 100644 index 00000000000000..dc31f339f5dd0d --- /dev/null +++ b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/DataHubOAuthAuthenticatorTest.java @@ -0,0 +1,1040 @@ +package com.datahub.authentication.authenticator; + +import static com.datahub.authentication.AuthenticationConstants.AUTHORIZATION_HEADER_NAME; +import static com.datahub.authentication.AuthenticationConstants.ENTITY_SERVICE; +import static com.linkedin.metadata.Constants.CORP_USER_INFO_ASPECT_NAME; +import static com.linkedin.metadata.Constants.GLOBAL_SETTINGS_INFO_ASPECT_NAME; +import static com.linkedin.metadata.Constants.GLOBAL_SETTINGS_URN; +import static com.linkedin.metadata.Constants.ORIGIN_ASPECT_NAME; +import static com.linkedin.metadata.Constants.SUB_TYPES_ASPECT_NAME; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import com.datahub.authentication.AuthenticationException; +import com.datahub.authentication.AuthenticationRequest; +import com.datahub.authentication.AuthenticatorContext; +import com.linkedin.common.urn.CorpuserUrn; +import com.linkedin.events.metadata.ChangeType; +import com.linkedin.metadata.aspect.batch.AspectsBatch; +import com.linkedin.metadata.entity.EntityService; +import com.linkedin.mxe.MetadataChangeProposal; +import com.linkedin.settings.global.GlobalSettingsInfo; +import com.linkedin.settings.global.OAuthProvider; +import com.linkedin.settings.global.OAuthSettings; +import io.datahubproject.metadata.context.OperationContext; +import io.datahubproject.test.metadata.context.TestOperationContexts; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import javax.crypto.spec.SecretKeySpec; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class DataHubOAuthAuthenticatorTest { + + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final String TEST_DISCOVERY_URI = + "https://auth.example.com/.well-known/openid-configuration"; + private static final String TEST_USER_NAME_CLAIM = "preferred_username"; + private static final String TEST_ALGORITHM = "RS256"; + private static final String TEST_ISSUER = "https://auth.example.com"; + private static final String TEST_JWKS_URI = "https://auth.example.com/.well-known/jwks.json"; + + private DataHubOAuthAuthenticator authenticator; + private EntityService mockEntityService; + private OperationContext mockOperationContext; + private AuthenticatorContext authenticatorContext; + + @BeforeMethod + public void setUp() { + authenticator = new DataHubOAuthAuthenticator(); + mockEntityService = mock(EntityService.class); + + // Use TestOperationContexts to create a proper OperationContext with all dependencies + mockOperationContext = TestOperationContexts.systemContextNoSearchAuthorization(); + + // Set up authenticator context + Map contextData = new HashMap<>(); + contextData.put(ENTITY_SERVICE, mockEntityService); + contextData.put("systemOperationContext", mockOperationContext); + authenticatorContext = new AuthenticatorContext(contextData); + } + + @Test + public void testInitSuccessWithStaticConfig() { + // Arrange + Map config = createValidStaticConfig(); + + // Act + authenticator.init(config, authenticatorContext); + + // Assert - If no exception is thrown, initialization was successful + assertNotNull(authenticator); + } + + @Test + public void testInitSucceedsButAuthenticationFailsWhenNotConfigured() + throws AuthenticationException { + // Static config is empty/invalid - missing required fields but OAuth enabled + Map invalidStaticConfig = new HashMap<>(); + invalidStaticConfig.put("enabled", "true"); + + // Act - Init should succeed without throwing exception + authenticator.init(invalidStaticConfig, authenticatorContext); + + // But authentication should fail gracefully + AuthenticationRequest request = new AuthenticationRequest(Collections.emptyMap()); + + try { + authenticator.authenticate(request); + assertNotNull(null, "Expected AuthenticationException to be thrown"); + } catch (AuthenticationException e) { + assertTrue( + e.getMessage().contains("OAuth authenticator is not configured"), + "Should fail with not configured message, got: " + e.getMessage()); + } + } + + @Test + public void testInitFailureNullConfig() { + // Act & Assert + try { + authenticator.init(null, authenticatorContext); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertNotNull(e.getMessage()); + } + } + + @Test + public void testInitFailureNullContext() { + // Act & Assert + try { + authenticator.init(Collections.emptyMap(), null); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertNotNull(e.getMessage()); + } + } + + @Test + public void testInitFailureMissingEntityService() { + // Arrange + Map contextData = new HashMap<>(); + contextData.put("systemOperationContext", mockOperationContext); + AuthenticatorContext invalidContext = new AuthenticatorContext(contextData); + + Map config = new HashMap<>(); + config.put("enabled", "true"); // Enable to trigger EntityService check + + // Act & Assert + try { + authenticator.init(config, invalidContext); + assertNotNull(null, "Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertNotNull(e.getMessage()); + } + } + + @Test + public void testAuthenticateFailureMissingAuthorizationHeader() throws AuthenticationException { + // Arrange + Map config = createValidStaticConfig(); + authenticator.init(config, authenticatorContext); + + AuthenticationRequest request = new AuthenticationRequest(Collections.emptyMap()); + + // Act & Assert + try { + authenticator.authenticate(request); + assertNotNull(null, "Expected AuthenticationException to be thrown"); + } catch (AuthenticationException e) { + assertTrue( + e.getMessage().contains("Invalid Authorization header"), + "Should fail with invalid authorization header message, got: " + e.getMessage()); + } + } + + @Test + public void testAuthenticateFailureInvalidAuthorizationHeader() throws AuthenticationException { + // Arrange + Map config = createValidStaticConfig(); + authenticator.init(config, authenticatorContext); + + Map headers = new HashMap<>(); + headers.put(AUTHORIZATION_HEADER_NAME, "InvalidHeader withoutBearer"); + AuthenticationRequest request = new AuthenticationRequest(headers); + + // Act & Assert + try { + authenticator.authenticate(request); + assertNotNull(null, "Expected AuthenticationException to be thrown"); + } catch (AuthenticationException e) { + assertTrue( + e.getMessage().contains("Invalid Authorization header"), + "Should fail with invalid authorization header message, got: " + e.getMessage()); + } + } + + @Test + public void testAuthenticateWithValidJWT() throws AuthenticationException { + // Arrange + Map staticConfig = createValidStaticConfig(); + authenticator.init(staticConfig, authenticatorContext); + + // Create a valid JWT token (though it will fail signature verification in real scenario) + String validJwtToken = createValidJwtToken(); + Map headers = new HashMap<>(); + headers.put(AUTHORIZATION_HEADER_NAME, "Bearer " + validJwtToken); + AuthenticationRequest request = new AuthenticationRequest(headers); + + // Act & Assert + // This test would require more complex setup to mock the JWT verification process + // For now, we'll test the basic flow up to token parsing + try { + authenticator.authenticate(request); + // If we get here without exception, that's unexpected since we don't have a real JWKS setup + // But the test validates the configuration loading works + } catch (AuthenticationException e) { + // Expected since we don't have a real JWT setup for this test + assertNotNull(e.getMessage()); + } + } + + private Map createValidStaticConfig() { + Map config = new HashMap<>(); + config.put("enabled", "true"); // Enable OAuth authentication + config.put("userIdClaim", "sub"); + config.put("trustedIssuers", TEST_ISSUER); + config.put("allowedAudiences", TEST_CLIENT_ID); + config.put("jwksUri", TEST_JWKS_URI); + config.put("algorithm", "RS256"); + return config; + } + + private String createValidJwtToken() { + // Create a simple JWT token for testing + // In a real scenario, this would be created by the OIDC provider + // Use a 256-bit key to avoid WeakKeyException + String secret = "test-secret-key-for-jwt-signing-that-is-long-enough-for-hmac-sha256-algorithm"; + SecretKeySpec key = new SecretKeySpec(secret.getBytes(), "HmacSHA256"); + + Map claims = new HashMap<>(); + claims.put("sub", "test-user"); + claims.put("aud", List.of(TEST_CLIENT_ID)); + claims.put("iss", TEST_ISSUER); + claims.put("exp", new Date(System.currentTimeMillis() + 3600000)); // 1 hour from now + claims.put(TEST_USER_NAME_CLAIM, "test-user@example.com"); + + return Jwts.builder() + .setClaims(claims) + .setHeaderParam("kid", "test-key-id") + .signWith(key, SignatureAlgorithm.HS256) + .compact(); + } + + @Test + public void testMultipleAudienceConfiguration() { + // Arrange - Static config with multiple audiences (comma-separated) + Map staticConfigWithMultipleAudiences = new HashMap<>(); + staticConfigWithMultipleAudiences.put("enabled", "true"); + staticConfigWithMultipleAudiences.put("trustedIssuers", TEST_ISSUER); + staticConfigWithMultipleAudiences.put( + "allowedAudiences", "audience-1,additional-client-1,additional-client-2"); + staticConfigWithMultipleAudiences.put("jwksUri", TEST_JWKS_URI); + + // Act + authenticator.init(staticConfigWithMultipleAudiences, authenticatorContext); + + // Assert - verify OAuth providers were created for each audience + List actualProviders = authenticator.getOAuthProviders(); + + assertNotNull(actualProviders); + assertEquals(actualProviders.size(), 3); // 3 providers (1 issuer × 3 audiences) + + // Collect all audiences from providers + HashSet providerAudiences = new HashSet<>(); + for (OAuthProvider provider : actualProviders) { + assertEquals(provider.getIssuer(), TEST_ISSUER); // Same issuer for all + providerAudiences.add(provider.getAudience()); + assertTrue(provider.getName().startsWith("static_")); + } + + assertEquals(providerAudiences.size(), 3); + assertTrue(providerAudiences.contains("audience-1")); + assertTrue(providerAudiences.contains("additional-client-1")); + assertTrue(providerAudiences.contains("additional-client-2")); + } + + @Test + public void testStaticConfigOnlyAudiences() { + // Arrange - Static config only + Map staticConfig = createValidStaticConfig(); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Assert - verify only static config providers are used + List actualProviders = authenticator.getOAuthProviders(); + + assertNotNull(actualProviders); + assertEquals(actualProviders.size(), 1); // Only one static provider + + OAuthProvider staticProvider = actualProviders.get(0); + assertTrue(staticProvider.getName().startsWith("static_")); + assertEquals(staticProvider.getIssuer(), TEST_ISSUER); + assertEquals(staticProvider.getAudience(), TEST_CLIENT_ID); + } + + @Test + public void testServiceAccountUserIdGeneration() { + // Test that the user ID generation logic works correctly without full JWT authentication + Map staticConfig = createValidStaticConfig(); + authenticator.init(staticConfig, authenticatorContext); + + // Use reflection to test the buildServiceUserUrn method directly + try { + // Create a test JWT claims object + java.lang.reflect.Method buildMethod = + DataHubOAuthAuthenticator.class.getDeclaredMethod( + "buildServiceUserUrn", io.jsonwebtoken.Claims.class); + buildMethod.setAccessible(true); + + // Create mock claims + io.jsonwebtoken.Claims mockClaims = mock(io.jsonwebtoken.Claims.class); + when(mockClaims.getIssuer()).thenReturn("https://auth.example.com"); + when(mockClaims.get("sub", String.class)).thenReturn("service-account-123"); + + // Act + String userId = (String) buildMethod.invoke(authenticator, mockClaims); + + // Assert + assertNotNull(userId); + assertTrue(userId.startsWith("__oauth_")); + assertTrue(userId.contains("auth_example_com")); + assertTrue(userId.contains("service-account-123")); + + } catch (Exception e) { + assertNotNull(null, "Failed to test user ID generation: " + e.getMessage()); + } + } + + @Test + public void testServiceAccountAspectCreation() { + // Test the aspect creation logic without full JWT authentication + Map staticConfig = createValidStaticConfig(); + authenticator.init(staticConfig, authenticatorContext); + + try { + // Use reflection to test the createServiceAccountAspects method directly + java.lang.reflect.Method createMethod = + DataHubOAuthAuthenticator.class.getDeclaredMethod( + "createServiceAccountAspects", CorpuserUrn.class, io.jsonwebtoken.Claims.class); + createMethod.setAccessible(true); + + // Create test data + CorpuserUrn testUrn = new CorpuserUrn("__oauth_auth_example_com_service123"); + io.jsonwebtoken.Claims mockClaims = mock(io.jsonwebtoken.Claims.class); + when(mockClaims.getIssuer()).thenReturn("https://auth.example.com"); + when(mockClaims.get("sub", String.class)).thenReturn("service123"); + + // Act + @SuppressWarnings("unchecked") + List aspects = + (List) createMethod.invoke(authenticator, testUrn, mockClaims); + + // Assert + assertNotNull(aspects); + assertEquals(aspects.size(), 3); // CorpUserInfo, SubTypes, Origin + + // Verify each aspect + boolean hasCorpUserInfo = false; + boolean hasSubTypes = false; + boolean hasOrigin = false; + + for (MetadataChangeProposal mcp : aspects) { + assertEquals(mcp.getEntityUrn(), testUrn); + assertEquals(mcp.getEntityType(), "corpuser"); + assertEquals(mcp.getChangeType(), ChangeType.UPSERT); + + if (CORP_USER_INFO_ASPECT_NAME.equals(mcp.getAspectName())) { + hasCorpUserInfo = true; + } else if (SUB_TYPES_ASPECT_NAME.equals(mcp.getAspectName())) { + hasSubTypes = true; + } else if (ORIGIN_ASPECT_NAME.equals(mcp.getAspectName())) { + hasOrigin = true; + } + } + + assertTrue(hasCorpUserInfo, "Should create CorpUserInfo aspect"); + assertTrue(hasSubTypes, "Should create SubTypes aspect"); + assertTrue(hasOrigin, "Should create Origin aspect"); + + } catch (Exception e) { + assertNotNull(null, "Failed to test aspect creation: " + e.getMessage()); + } + } + + @Test + public void testEnsureServiceAccountExistsWithNewUser() { + // Test the ensureServiceAccountExists logic with a new user + Map staticConfig = createValidStaticConfig(); + + // Mock user doesn't exist initially + when(mockEntityService.exists(eq(mockOperationContext), any(CorpuserUrn.class), eq(false))) + .thenReturn(false); + + authenticator.init(staticConfig, authenticatorContext); + + try { + // Use reflection to test the ensureServiceAccountExists method directly + java.lang.reflect.Method ensureMethod = + DataHubOAuthAuthenticator.class.getDeclaredMethod( + "ensureServiceAccountExists", String.class, io.jsonwebtoken.Claims.class); + ensureMethod.setAccessible(true); + + // Create test data + String userId = "__oauth_auth_example_com_service123"; + io.jsonwebtoken.Claims mockClaims = mock(io.jsonwebtoken.Claims.class); + when(mockClaims.getIssuer()).thenReturn("https://auth.example.com"); + when(mockClaims.get("sub", String.class)).thenReturn("service123"); + + // Act + ensureMethod.invoke(authenticator, userId, mockClaims); + + // Verify user existence was checked + verify(mockEntityService, times(1)) + .exists(eq(mockOperationContext), any(CorpuserUrn.class), eq(false)); + + // Verify aspects were ingested + verify(mockEntityService, times(1)) + .ingestAspects(eq(mockOperationContext), any(AspectsBatch.class), eq(false), eq(true)); + + } catch (Exception e) { + assertNotNull(null, "Failed to test service account creation: " + e.getMessage()); + } + } + + @Test + public void testServiceAccountUserIdUniqueness() { + // Test that different issuers produce different user IDs even with same subject + Map staticConfig = createValidStaticConfig(); + authenticator.init(staticConfig, authenticatorContext); + + try { + // Use reflection to test user ID generation with different issuers + java.lang.reflect.Method buildMethod = + DataHubOAuthAuthenticator.class.getDeclaredMethod( + "buildServiceUserUrn", io.jsonwebtoken.Claims.class); + buildMethod.setAccessible(true); + + // Create mock claims for first issuer + io.jsonwebtoken.Claims mockClaims1 = mock(io.jsonwebtoken.Claims.class); + when(mockClaims1.getIssuer()).thenReturn("https://auth.company1.com"); + when(mockClaims1.get("sub", String.class)).thenReturn("service-account-123"); + + // Create mock claims for second issuer + io.jsonwebtoken.Claims mockClaims2 = mock(io.jsonwebtoken.Claims.class); + when(mockClaims2.getIssuer()).thenReturn("https://auth.company2.com"); + when(mockClaims2.get("sub", String.class)).thenReturn("service-account-123"); + + // Act + String userId1 = (String) buildMethod.invoke(authenticator, mockClaims1); + String userId2 = (String) buildMethod.invoke(authenticator, mockClaims2); + + // Assert different user IDs are generated + assertNotNull(userId1); + assertNotNull(userId2); + assertTrue(!userId1.equals(userId2), "Different issuers should generate different user IDs"); + + // Both should contain the issuer information + assertTrue( + userId1.contains("auth_company1_com"), "User ID should contain sanitized issuer 1"); + assertTrue( + userId2.contains("auth_company2_com"), "User ID should contain sanitized issuer 2"); + + // Both should have the OAuth prefix and subject + assertTrue(userId1.startsWith("__oauth_"), "User ID 1 should have OAuth prefix"); + assertTrue(userId2.startsWith("__oauth_"), "User ID 2 should have OAuth prefix"); + assertTrue(userId1.contains("service-account-123"), "User ID 1 should contain subject"); + assertTrue(userId2.contains("service-account-123"), "User ID 2 should contain subject"); + + } catch (Exception e) { + assertNotNull(null, "Failed to test user ID uniqueness: " + e.getMessage()); + } + } + + @Test + public void testEnsureServiceAccountExistsWithExistingUser() { + // Test the ensureServiceAccountExists logic with an existing user + Map staticConfig = createValidStaticConfig(); + + // Mock user already exists + when(mockEntityService.exists(eq(mockOperationContext), any(CorpuserUrn.class), eq(false))) + .thenReturn(true); + + authenticator.init(staticConfig, authenticatorContext); + + try { + // Use reflection to test the ensureServiceAccountExists method directly + java.lang.reflect.Method ensureMethod = + DataHubOAuthAuthenticator.class.getDeclaredMethod( + "ensureServiceAccountExists", String.class, io.jsonwebtoken.Claims.class); + ensureMethod.setAccessible(true); + + // Create test data + String userId = "__oauth_auth_example_com_service123"; + io.jsonwebtoken.Claims mockClaims = mock(io.jsonwebtoken.Claims.class); + when(mockClaims.getIssuer()).thenReturn("https://auth.example.com"); + when(mockClaims.get("sub", String.class)).thenReturn("service123"); + + // Act + ensureMethod.invoke(authenticator, userId, mockClaims); + + // Verify user existence was checked + verify(mockEntityService, times(1)) + .exists(eq(mockOperationContext), any(CorpuserUrn.class), eq(false)); + + // Verify aspects were NOT ingested (user already exists) + verify(mockEntityService, never()) + .ingestAspects(eq(mockOperationContext), any(AspectsBatch.class), eq(false), eq(true)); + + } catch (Exception e) { + assertNotNull(null, "Failed to test existing service account handling: " + e.getMessage()); + } + } + + @Test + public void testServiceAccountCreationErrorHandling() { + // Test that service account creation failures are handled gracefully + Map staticConfig = createValidStaticConfig(); + + // Mock user doesn't exist initially + when(mockEntityService.exists(eq(mockOperationContext), any(CorpuserUrn.class), eq(false))) + .thenReturn(false); + + // Mock aspect ingestion failure + doThrow(new RuntimeException("Ingestion failed")) + .when(mockEntityService) + .ingestAspects(eq(mockOperationContext), any(AspectsBatch.class), eq(false), eq(true)); + + authenticator.init(staticConfig, authenticatorContext); + + try { + // Use reflection to test the ensureServiceAccountExists method directly + java.lang.reflect.Method ensureMethod = + DataHubOAuthAuthenticator.class.getDeclaredMethod( + "ensureServiceAccountExists", String.class, io.jsonwebtoken.Claims.class); + ensureMethod.setAccessible(true); + + // Create test data + String userId = "__oauth_auth_example_com_service123"; + io.jsonwebtoken.Claims mockClaims = mock(io.jsonwebtoken.Claims.class); + when(mockClaims.getIssuer()).thenReturn("https://auth.example.com"); + when(mockClaims.get("sub", String.class)).thenReturn("service123"); + + // Act - should not throw exception even though ingestion fails + ensureMethod.invoke(authenticator, userId, mockClaims); + + // Verify user existence was checked + verify(mockEntityService, times(1)) + .exists(eq(mockOperationContext), any(CorpuserUrn.class), eq(false)); + + // Verify aspects ingestion was attempted (but failed) + verify(mockEntityService, times(1)) + .ingestAspects(eq(mockOperationContext), any(AspectsBatch.class), eq(false), eq(true)); + + } catch (Exception e) { + assertNotNull(null, "Failed to test error handling: " + e.getMessage()); + } + } + + private String createJwtTokenWithClaims(String issuer, String subject, String audience) { + // Use a 256-bit key to avoid WeakKeyException + String secretKey = "mySecretKeyThatIsLongEnoughFor256BitHmacSha256AlgorithmToWork"; + SecretKeySpec signingKey = new SecretKeySpec(secretKey.getBytes(), "HmacSHA256"); + + return Jwts.builder() + .setIssuer(issuer) + .setSubject(subject) + .setAudience(audience) + .claim("preferred_username", subject) + .setIssuedAt(new Date()) + .setExpiration(new Date(System.currentTimeMillis() + 3600000)) // 1 hour + .signWith(signingKey, SignatureAlgorithm.HS256) + .compact(); + } + + // ==================== DYNAMIC CONFIGURATION TESTS ==================== + + @Test + public void testDynamicConfigurationLoading() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + // Create mock GlobalSettings with OAuth providers + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Assert + verify(mockEntityService, times(1)) + .getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + assertNotNull(authenticator); + } + + @Test + public void testDynamicConfigurationWithNoGlobalSettings() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + // Mock that no GlobalSettings exist + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Assert - Should still initialize successfully with static config + verify(mockEntityService, times(1)) + .getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + assertNotNull(authenticator); + } + + @Test + public void testDynamicConfigurationWithEmptyOAuthProviders() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + // Create GlobalSettings with no OAuth providers + GlobalSettingsInfo globalSettings = new GlobalSettingsInfo(); + OAuthSettings oauthSettings = new OAuthSettings(); + oauthSettings.setProviders(new com.linkedin.settings.global.OAuthProviderArray()); + globalSettings.setOauth(oauthSettings); + + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Assert + verify(mockEntityService, times(1)) + .getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + assertNotNull(authenticator); + } + + @Test + public void testDynamicConfigurationWithDisabledProviders() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + // Create GlobalSettings with disabled OAuth provider + GlobalSettingsInfo globalSettings = new GlobalSettingsInfo(); + OAuthSettings oauthSettings = new OAuthSettings(); + com.linkedin.settings.global.OAuthProviderArray providers = + new com.linkedin.settings.global.OAuthProviderArray(); + + OAuthProvider disabledProvider = new OAuthProvider(); + disabledProvider.data().put("enabled", Boolean.FALSE); + disabledProvider.setName("disabled-provider"); + disabledProvider.setIssuer("https://disabled.example.com"); + disabledProvider.setAudience("disabled-audience"); + disabledProvider.setJwksUri("https://disabled.example.com/jwks"); + providers.add(disabledProvider); + + oauthSettings.setProviders(providers); + globalSettings.setOauth(oauthSettings); + + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Assert - Should not load disabled providers + verify(mockEntityService, times(1)) + .getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + assertNotNull(authenticator); + } + + @Test + public void testDynamicConfigurationErrorHandling() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + // Mock EntityService to throw exception when loading GlobalSettings + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenThrow(new RuntimeException("GlobalSettings loading failed")); + + // Act - Should not throw exception, should handle gracefully + authenticator.init(staticConfig, authenticatorContext); + + // Assert - Should still initialize with static config despite dynamic config error + verify(mockEntityService, times(1)) + .getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + assertNotNull(authenticator); + } + + @Test + public void testAuthenticationWithDynamicProviderOnly() { + // Arrange + Map staticConfig = new HashMap<>(); + staticConfig.put("enabled", "true"); + staticConfig.put("userIdClaim", "sub"); + // Intentionally omit static OAuth configuration + + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + authenticator.init(staticConfig, authenticatorContext); + + // Create a valid JWT token for the dynamic provider + String token = + createJwtTokenWithClaims("https://dynamic.example.com", "test-user", "dynamic-audience"); + + Map headers = new HashMap<>(); + headers.put(AUTHORIZATION_HEADER_NAME, "Bearer " + token); + AuthenticationRequest request = new AuthenticationRequest(headers); + + // Act & Assert - Should authenticate successfully using dynamic provider + try { + var result = authenticator.authenticate(request); + assertNotNull(result); + assertEquals(result.getActor().getId(), "__oauth_dynamic.example.com_test-user"); + } catch (AuthenticationException e) { + // Note: This test may fail if JWT signature verification is required + // In real scenarios, proper JWT signing keys would be configured + assertTrue(e.getMessage().contains("OAuth token validation failed")); + } + } + + @Test + public void testAuthenticationWithStaticAndDynamicProviders() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + authenticator.init(staticConfig, authenticatorContext); + + // Test with static provider token + String staticToken = createJwtTokenWithClaims(TEST_ISSUER, "static-user", "static-audience"); + Map staticHeaders = new HashMap<>(); + staticHeaders.put(AUTHORIZATION_HEADER_NAME, "Bearer " + staticToken); + AuthenticationRequest staticRequest = new AuthenticationRequest(staticHeaders); + + // Test with dynamic provider token + String dynamicToken = + createJwtTokenWithClaims("https://dynamic.example.com", "dynamic-user", "dynamic-audience"); + Map dynamicHeaders = new HashMap<>(); + dynamicHeaders.put(AUTHORIZATION_HEADER_NAME, "Bearer " + dynamicToken); + AuthenticationRequest dynamicRequest = new AuthenticationRequest(dynamicHeaders); + + // Act & Assert - Both should work (subject to JWT signature validation) + try { + var staticResult = authenticator.authenticate(staticRequest); + assertNotNull(staticResult); + } catch (AuthenticationException e) { + assertTrue(e.getMessage().contains("OAuth token validation failed")); + } + + try { + var dynamicResult = authenticator.authenticate(dynamicRequest); + assertNotNull(dynamicResult); + } catch (AuthenticationException e) { + assertTrue(e.getMessage().contains("OAuth token validation failed")); + } + } + + @Test + public void testScheduledConfigurationRefresh() throws InterruptedException { + // Arrange + Map staticConfig = createValidStaticConfig(); + + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Wait a short time to ensure scheduler is set up + Thread.sleep(100); + + // Assert - Verify that the scheduler was set up (initial load + scheduled refresh setup) + verify(mockEntityService, times(1)) + .getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + + // Clean up + authenticator.destroy(); + } + + @Test + public void testCleanupDestroy() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + authenticator.init(staticConfig, authenticatorContext); + + // Act + authenticator.destroy(); + + // Assert - No exception should be thrown, scheduler should be shut down gracefully + assertNotNull(authenticator); + } + + @Test + public void testDynamicReconfigurationFromUnconfiguredToConfigured() + throws AuthenticationException { + // Arrange - Start with empty static config (no providers configured) + Map emptyStaticConfig = new HashMap<>(); + emptyStaticConfig.put("enabled", "true"); + + // Mock that no GlobalSettings exist initially + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act - Initialize with empty config + authenticator.init(emptyStaticConfig, authenticatorContext); + + // Assert - Authentication should fail initially (not configured) + AuthenticationRequest request = + new AuthenticationRequest( + Map.of(AUTHORIZATION_HEADER_NAME, "Bearer " + createValidJwtToken())); + + try { + authenticator.authenticate(request); + assertNotNull(null, "Expected AuthenticationException to be thrown"); + } catch (AuthenticationException e) { + assertTrue( + e.getMessage().contains("OAuth authenticator is not configured"), + "Should fail with not configured message, got: " + e.getMessage()); + } + + // Arrange - Now mock that GlobalSettings with OAuth providers exist + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act - Force refresh the configuration to simulate dynamic update + authenticator.forceRefreshOAuthProviders(); // This will force refresh from GlobalSettings + + // Create a new request with proper token for dynamic provider + Map dynamicClaims = new HashMap<>(); + dynamicClaims.put("sub", "dynamic-test-user"); + dynamicClaims.put("aud", List.of("dynamic-audience")); + dynamicClaims.put("iss", "https://dynamic.example.com"); + dynamicClaims.put("exp", new Date(System.currentTimeMillis() + 3600000)); + + String secret = "test-secret-key-for-jwt-signing-that-is-long-enough-for-hmac-sha256-algorithm"; + SecretKeySpec key = new SecretKeySpec(secret.getBytes(), "HmacSHA256"); + + String dynamicToken = + Jwts.builder() + .setClaims(dynamicClaims) + .setHeaderParam("kid", "dynamic-key-id") + .signWith(key, SignatureAlgorithm.HS256) + .compact(); + + AuthenticationRequest dynamicRequest = + new AuthenticationRequest(Map.of(AUTHORIZATION_HEADER_NAME, "Bearer " + dynamicToken)); + + // Note: This test verifies the configuration state check works correctly. + // The actual JWT validation would fail due to mocked JWKS, but we're testing + // that it gets past the "not configured" check and fails later in the process. + try { + authenticator.authenticate(dynamicRequest); + // If we get here, configuration check passed (which is what we're testing) + assertNotNull(null, "Expected to fail at JWT validation, not configuration check"); + } catch (AuthenticationException e) { + // Should NOT be the "not configured" error anymore + assertTrue( + !e.getMessage().contains("OAuth authenticator is not configured"), + "Should not fail with 'not configured' message after dynamic configuration is loaded. Got: " + + e.getMessage()); + // Should fail with JWT validation error instead (which is expected since we're using mock + // JWKS) + assertTrue( + e.getMessage().contains("OAuth token validation failed") + || e.getMessage().contains("Unable to resolve signing key") + || e.getMessage().contains("No configured OAuth provider matches"), + "Should fail with JWT validation error, got: " + e.getMessage()); + } + } + + @Test + public void testProviderSpecificAlgorithmAndUserIdClaim() { + // Arrange + Map staticConfig = createValidStaticConfig(); + + // Create GlobalSettings with providers that have different algorithm and userIdClaim + GlobalSettingsInfo globalSettings = new GlobalSettingsInfo(); + OAuthSettings oauthSettings = new OAuthSettings(); + com.linkedin.settings.global.OAuthProviderArray providers = + new com.linkedin.settings.global.OAuthProviderArray(); + + // Create provider with custom algorithm and userIdClaim + OAuthProvider customProvider = new OAuthProvider(); + customProvider.data().put("enabled", Boolean.TRUE); + customProvider.setName("custom-provider"); + customProvider.setIssuer("https://custom.example.com"); + customProvider.setAudience("custom-audience"); + customProvider.setJwksUri("https://custom.example.com/jwks"); + customProvider.setAlgorithm("RS384"); // Different from default RS256 + customProvider.setUserIdClaim("email"); // Different from default sub + providers.add(customProvider); + + oauthSettings.setProviders(providers); + globalSettings.setOauth(oauthSettings); + + when(mockEntityService.getLatestAspect( + eq(mockOperationContext), + eq(GLOBAL_SETTINGS_URN), + eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + authenticator.init(staticConfig, authenticatorContext); + + // Assert - Get providers and verify custom fields are set + List loadedProviders = authenticator.getOAuthProviders(); + assertEquals(2, loadedProviders.size()); // 1 static + 1 dynamic + + // Find the custom provider + OAuthProvider foundCustomProvider = + loadedProviders.stream() + .filter(p -> "custom-provider".equals(p.getName())) + .findFirst() + .orElse(null); + + assertNotNull(foundCustomProvider, "Custom provider should be found"); + assertEquals( + "RS384", foundCustomProvider.getAlgorithm(), "Algorithm should be provider-specific"); + assertEquals( + "email", foundCustomProvider.getUserIdClaim(), "UserIdClaim should be provider-specific"); + + // Verify static provider uses defaults from config + OAuthProvider staticProvider = + loadedProviders.stream() + .filter(p -> p.getName().startsWith("static_")) + .findFirst() + .orElse(null); + + assertNotNull(staticProvider, "Static provider should be found"); + assertEquals( + "RS256", staticProvider.getAlgorithm(), "Static provider should use config algorithm"); + assertEquals( + "sub", staticProvider.getUserIdClaim(), "Static provider should use config userIdClaim"); + } + + // Helper method to create GlobalSettings with OAuth providers + private GlobalSettingsInfo createGlobalSettingsWithOAuthProviders() { + GlobalSettingsInfo globalSettings = new GlobalSettingsInfo(); + OAuthSettings oauthSettings = new OAuthSettings(); + com.linkedin.settings.global.OAuthProviderArray providers = + new com.linkedin.settings.global.OAuthProviderArray(); + + // Create enabled provider + OAuthProvider enabledProvider = new OAuthProvider(); + enabledProvider.data().put("enabled", Boolean.TRUE); + enabledProvider.setName("dynamic-provider"); + enabledProvider.setIssuer("https://dynamic.example.com"); + enabledProvider.setAudience("dynamic-audience"); + enabledProvider.setJwksUri("https://dynamic.example.com/jwks"); + enabledProvider.setAlgorithm("RS256"); + enabledProvider.setUserIdClaim("sub"); + providers.add(enabledProvider); + + // Create another enabled provider + OAuthProvider secondProvider = new OAuthProvider(); + secondProvider.data().put("enabled", Boolean.TRUE); + secondProvider.setName("second-provider"); + secondProvider.setIssuer("https://second.example.com"); + secondProvider.setAudience("second-audience"); + secondProvider.setJwksUri("https://second.example.com/jwks"); + secondProvider.setAlgorithm("RS256"); + secondProvider.setUserIdClaim("sub"); + providers.add(secondProvider); + + oauthSettings.setProviders(providers); + globalSettings.setOauth(oauthSettings); + + return globalSettings; + } +} diff --git a/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/JwksUriResolverTest.java b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/JwksUriResolverTest.java new file mode 100644 index 00000000000000..8fe207792562a4 --- /dev/null +++ b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/JwksUriResolverTest.java @@ -0,0 +1,288 @@ +package com.datahub.authentication.authenticator; + +import static org.testng.Assert.*; + +import org.testng.annotations.Test; + +public class JwksUriResolverTest { + + @Test + public void testDeriveJwksUriFallbackWithStandardFormat() { + // Arrange + String discoveryUri = "https://auth.example.com"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithTrailingSlash() { + // Arrange + String discoveryUri = "https://auth.example.com/"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithDiscoveryEndpoint() { + // Arrange + String discoveryUri = "https://auth.example.com/.well-known/openid-configuration"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithDiscoveryEndpointAndTrailingSlash() { + // Arrange + String discoveryUri = "https://auth.example.com/.well-known/openid-configuration/"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithPath() { + // Arrange + String discoveryUri = "https://auth.example.com/oauth2/v1"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/oauth2/v1/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithPathAndDiscoveryEndpoint() { + // Arrange + String discoveryUri = "https://auth.example.com/oauth2/v1/.well-known/openid-configuration"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/oauth2/v1/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithWhitespace() { + // Arrange + String discoveryUri = " https://auth.example.com "; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithLocalhost() { + // Arrange + String discoveryUri = "http://localhost:8080"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "http://localhost:8080/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithPort() { + // Arrange + String discoveryUri = "https://auth.example.com:443/oauth"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "https://auth.example.com:443/oauth/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFromDiscoveryUriWithUnreachableEndpoint() { + // Arrange + String discoveryUri = "https://nonexistent.example.com"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + + // Assert - Should fallback to standard pattern + assertEquals(jwksUri, "https://nonexistent.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFromDiscoveryUriWithInvalidUrl() { + // Arrange + String discoveryUri = "not-a-valid-url"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + + // Assert - Should fallback to standard pattern + assertEquals(jwksUri, "not-a-valid-url/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFromDiscoveryUriAppendsDiscoveryPath() { + // Test that when discovery document fetch fails, it tries the right URL format + // Arrange + String discoveryUri = "https://auth.example.com/oauth2"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + + // Assert - Should fallback to standard pattern since we can't reach the endpoint + assertEquals(jwksUri, "https://auth.example.com/oauth2/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFromDiscoveryUriWithFullDiscoveryPath() { + // Arrange + String discoveryUri = "https://nonexistent.example.com/.well-known/openid-configuration"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + + // Assert - Should fallback to standard pattern + assertEquals(jwksUri, "https://nonexistent.example.com/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFromDiscoveryUriWithEmptyString() { + // Arrange + String discoveryUri = ""; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + + // Assert - Should fallback gracefully + assertEquals(jwksUri, "/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFromDiscoveryUriWithNull() { + // Arrange + String discoveryUri = null; + + // Act & Assert - Should handle gracefully + assertThrows( + NullPointerException.class, + () -> { + JwksUriResolver.deriveJwksUriFromDiscoveryUri(discoveryUri); + }); + } + + @Test + public void testDeriveJwksUriFallbackWithNull() { + // Arrange + String discoveryUri = null; + + // Act & Assert - Should handle gracefully + assertThrows( + NullPointerException.class, + () -> { + JwksUriResolver.deriveJwksUriFallback(discoveryUri); + }); + } + + @Test + public void testDeriveJwksUriFallbackWithOnlySlash() { + // Arrange + String discoveryUri = "/"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert + assertEquals(jwksUri, "/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackCaseInsensitive() { + // Test case sensitivity - OIDC endpoints should maintain case + // Arrange + String discoveryUri = "https://Auth.Example.Com/OAuth2"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert - Should preserve case + assertEquals(jwksUri, "https://Auth.Example.Com/OAuth2/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithMultipleTrailingSlashes() { + // Arrange + String discoveryUri = "https://auth.example.com///"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert - Only removes one trailing slash + assertEquals(jwksUri, "https://auth.example.com///.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithQueryParams() { + // Arrange - discovery URI with query parameters + String discoveryUri = "https://auth.example.com?tenant=example"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert - Should preserve query parameters + assertEquals(jwksUri, "https://auth.example.com?tenant=example/.well-known/jwks.json"); + } + + @Test + public void testDeriveJwksUriFallbackWithFragment() { + // Arrange - discovery URI with fragment + String discoveryUri = "https://auth.example.com#section"; + + // Act + String jwksUri = JwksUriResolver.deriveJwksUriFallback(discoveryUri); + + // Assert - Should preserve fragment + assertEquals(jwksUri, "https://auth.example.com#section/.well-known/jwks.json"); + } + + // Integration test to verify the pattern matching works correctly + @Test + public void testDiscoveryEndpointDetectionPattern() { + // Test various formats to ensure the discovery endpoint pattern is detected correctly + String[] testCases = { + "https://example.com/.well-known/openid-configuration", + "https://example.com/oauth/.well-known/openid-configuration", + "https://example.com/auth/realms/master/.well-known/openid-configuration", + "http://localhost:8080/.well-known/openid-configuration" + }; + + String[] expectedResults = { + "https://example.com/.well-known/jwks.json", + "https://example.com/oauth/.well-known/jwks.json", + "https://example.com/auth/realms/master/.well-known/jwks.json", + "http://localhost:8080/.well-known/jwks.json" + }; + + for (int i = 0; i < testCases.length; i++) { + String result = JwksUriResolver.deriveJwksUriFallback(testCases[i]); + assertEquals(result, expectedResults[i], "Failed for input: " + testCases[i]); + } + } +} diff --git a/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/OAuthConfigurationFetcherTest.java b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/OAuthConfigurationFetcherTest.java new file mode 100644 index 00000000000000..07ca151a46610b --- /dev/null +++ b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/authenticator/OAuthConfigurationFetcherTest.java @@ -0,0 +1,452 @@ +package com.datahub.authentication.authenticator; + +import static com.linkedin.metadata.Constants.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.testng.Assert.*; + +import com.linkedin.metadata.entity.EntityService; +import com.linkedin.settings.global.GlobalSettingsInfo; +import com.linkedin.settings.global.OAuthProvider; +import com.linkedin.settings.global.OAuthProviderArray; +import com.linkedin.settings.global.OAuthSettings; +import io.datahubproject.metadata.context.OperationContext; +import io.datahubproject.test.metadata.context.TestOperationContexts; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class OAuthConfigurationFetcherTest { + + private static final String TEST_ISSUER = "https://auth.example.com"; + private static final String TEST_AUDIENCE = "test-client-id"; + private static final String TEST_JWKS_URI = "https://auth.example.com/.well-known/jwks.json"; + + @Mock private EntityService mockEntityService; + + private OAuthConfigurationFetcher fetcher; + private OperationContext operationContext; + + @BeforeMethod + public void setUp() { + MockitoAnnotations.openMocks(this); + fetcher = new OAuthConfigurationFetcher(); + operationContext = TestOperationContexts.systemContextNoSearchAuthorization(); + } + + @AfterMethod + public void tearDown() { + if (fetcher != null) { + fetcher.destroy(); + } + } + + @Test + public void testInitializeWithValidStaticConfig() { + // Arrange + Map config = createValidStaticConfig(); + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertTrue(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 1); + + OAuthProvider provider = providers.get(0); + assertEquals(provider.getIssuer(), TEST_ISSUER); + assertEquals(provider.getAudience(), TEST_AUDIENCE); + assertEquals(provider.getJwksUri(), TEST_JWKS_URI); + assertTrue(provider.getName().startsWith("static_")); + assertTrue(Boolean.TRUE.equals(provider.data().get("enabled"))); + } + + @Test + public void testInitializeWithMultipleStaticProviders() { + // Arrange + Map config = new HashMap<>(); + config.put("trustedIssuers", "issuer1,issuer2"); + config.put("allowedAudiences", "aud1,aud2"); + config.put("jwksUri", TEST_JWKS_URI); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertTrue(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 4); // 2 issuers × 2 audiences + + // Verify all combinations exist + boolean found11 = false, found12 = false, found21 = false, found22 = false; + for (OAuthProvider provider : providers) { + if ("issuer1".equals(provider.getIssuer()) && "aud1".equals(provider.getAudience())) { + found11 = true; + } else if ("issuer1".equals(provider.getIssuer()) && "aud2".equals(provider.getAudience())) { + found12 = true; + } else if ("issuer2".equals(provider.getIssuer()) && "aud1".equals(provider.getAudience())) { + found21 = true; + } else if ("issuer2".equals(provider.getIssuer()) && "aud2".equals(provider.getAudience())) { + found22 = true; + } + } + assertTrue(found11 && found12 && found21 && found22); + } + + @Test + public void testInitializeWithInvalidStaticConfig() { + // Arrange - Missing required fields + Map config = new HashMap<>(); + config.put("trustedIssuers", TEST_ISSUER); + // Missing allowedAudiences and jwksUri + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertFalse(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 0); + } + + @Test + public void testInitializeWithDynamicConfigOnly() { + // Arrange + Map config = new HashMap<>(); // No static config + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertTrue(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 2); // Two dynamic providers + + // Verify dynamic providers are present + boolean foundDynamic1 = false, foundDynamic2 = false; + for (OAuthProvider provider : providers) { + if ("dynamic-provider".equals(provider.getName())) { + foundDynamic1 = true; + assertEquals(provider.getIssuer(), "https://dynamic.example.com"); + } else if ("second-provider".equals(provider.getName())) { + foundDynamic2 = true; + assertEquals(provider.getIssuer(), "https://second.example.com"); + } + } + assertTrue(foundDynamic1 && foundDynamic2); + } + + @Test + public void testInitializeWithStaticAndDynamicConfig() { + // Arrange + Map config = createValidStaticConfig(); + GlobalSettingsInfo globalSettings = createGlobalSettingsWithOAuthProviders(); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertTrue(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 3); // 1 static + 2 dynamic + + // Verify both static and dynamic providers are present + boolean foundStatic = false, foundDynamic1 = false, foundDynamic2 = false; + for (OAuthProvider provider : providers) { + if (provider.getName().startsWith("static_")) { + foundStatic = true; + } else if ("dynamic-provider".equals(provider.getName())) { + foundDynamic1 = true; + } else if ("second-provider".equals(provider.getName())) { + foundDynamic2 = true; + } + } + assertTrue(foundStatic && foundDynamic1 && foundDynamic2); + } + + @Test + public void testInitializeWithDisabledDynamicProviders() { + // Arrange + Map config = new HashMap<>(); + GlobalSettingsInfo globalSettings = createGlobalSettingsWithDisabledProviders(); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(globalSettings); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertFalse(fetcher.isConfigured()); // Should be false since all providers are disabled + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 0); + } + + @Test + public void testFindMatchingProvider() { + // Arrange + Map config = createValidStaticConfig(); + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + fetcher.initialize(config, mockEntityService, operationContext); + + // Act & Assert - Matching provider + OAuthProvider matchingProvider = + fetcher.findMatchingProvider(TEST_ISSUER, Arrays.asList(TEST_AUDIENCE)); + assertNotNull(matchingProvider); + assertEquals(matchingProvider.getIssuer(), TEST_ISSUER); + assertEquals(matchingProvider.getAudience(), TEST_AUDIENCE); + + // Act & Assert - Non-matching issuer + OAuthProvider nonMatchingIssuer = + fetcher.findMatchingProvider("https://wrong.issuer.com", Arrays.asList(TEST_AUDIENCE)); + assertNull(nonMatchingIssuer); + + // Act & Assert - Non-matching audience + OAuthProvider nonMatchingAudience = + fetcher.findMatchingProvider(TEST_ISSUER, Arrays.asList("wrong-audience")); + assertNull(nonMatchingAudience); + + // Act & Assert - Multiple audiences with one match + OAuthProvider multipleAudiences = + fetcher.findMatchingProvider(TEST_ISSUER, Arrays.asList("wrong-audience", TEST_AUDIENCE)); + assertNotNull(multipleAudiences); + assertEquals(multipleAudiences.getAudience(), TEST_AUDIENCE); + } + + @Test + public void testForceRefreshConfiguration() { + // Arrange + Map config = createValidStaticConfig(); + GlobalSettingsInfo initialSettings = null; + GlobalSettingsInfo updatedSettings = createGlobalSettingsWithOAuthProviders(); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(initialSettings) + .thenReturn(updatedSettings); + + fetcher.initialize(config, mockEntityService, operationContext); + assertEquals(fetcher.getCachedConfiguration().size(), 1); // Only static + + // Act + List refreshedProviders = fetcher.forceRefreshConfiguration(); + + // Assert + assertEquals(refreshedProviders.size(), 3); // Static + 2 dynamic + verify(mockEntityService, times(2)) + .getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME)); + } + + @Test + public void testDynamicConfigurationErrorHandling() { + // Arrange + Map config = createValidStaticConfig(); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenThrow(new RuntimeException("GlobalSettings fetch failed")); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert - Should still be configured with static config despite dynamic error + assertTrue(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 1); // Only static provider + } + + @Test + public void testIsConfiguredReturnsFalseWhenNoProviders() { + // Arrange + Map config = new HashMap<>(); // Empty config + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertFalse(fetcher.isConfigured()); + } + + @Test + public void testGetCachedConfigurationReturnsImmutableCopy() { + // Arrange + Map config = createValidStaticConfig(); + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + fetcher.initialize(config, mockEntityService, operationContext); + + // Act + List providers1 = fetcher.getCachedConfiguration(); + List providers2 = fetcher.getCachedConfiguration(); + + // Assert - Should be different instances (defensive copies) + assertNotSame(providers1, providers2); + assertEquals(providers1.size(), providers2.size()); + } + + @Test + public void testDestroy() { + // Arrange + Map config = createValidStaticConfig(); + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + fetcher.initialize(config, mockEntityService, operationContext); + + // Act + fetcher.destroy(); + + // Assert - Should not throw exception, scheduler should be cleaned up + // Multiple destroy calls should be safe + fetcher.destroy(); + } + + @Test + public void testInitializeWithDiscoveryUri() { + // Arrange + Map config = new HashMap<>(); + config.put("trustedIssuers", TEST_ISSUER); + config.put("allowedAudiences", TEST_AUDIENCE); + config.put("discoveryUri", "https://auth.example.com/.well-known/openid-configuration"); + // No jwksUri - should be derived from discoveryUri + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + assertTrue(fetcher.isConfigured()); + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 1); + + OAuthProvider provider = providers.get(0); + assertNotNull(provider.getJwksUri()); + // JWKS URI should be derived by JwksUriResolver (actual value depends on network call or + // fallback) + } + + @Test + public void testStaticProviderNaming() { + // Arrange + Map config = new HashMap<>(); + config.put("trustedIssuers", "https://auth-server.example.com:443/oauth2"); + config.put("allowedAudiences", TEST_AUDIENCE); + config.put("jwksUri", TEST_JWKS_URI); + + when(mockEntityService.getLatestAspect( + eq(operationContext), eq(GLOBAL_SETTINGS_URN), eq(GLOBAL_SETTINGS_INFO_ASPECT_NAME))) + .thenReturn(null); + + // Act + fetcher.initialize(config, mockEntityService, operationContext); + + // Assert + List providers = fetcher.getCachedConfiguration(); + assertEquals(providers.size(), 1); + + OAuthProvider provider = providers.get(0); + assertTrue(provider.getName().startsWith("static_")); + // Name should have special characters replaced with underscores + assertTrue(provider.getName().contains("auth_server_example_com")); + } + + // Helper methods + private Map createValidStaticConfig() { + Map config = new HashMap<>(); + config.put("trustedIssuers", TEST_ISSUER); + config.put("allowedAudiences", TEST_AUDIENCE); + config.put("jwksUri", TEST_JWKS_URI); + return config; + } + + private GlobalSettingsInfo createGlobalSettingsWithOAuthProviders() { + GlobalSettingsInfo globalSettings = new GlobalSettingsInfo(); + OAuthSettings oauthSettings = new OAuthSettings(); + OAuthProviderArray providers = new OAuthProviderArray(); + + // Create first enabled provider + OAuthProvider provider1 = new OAuthProvider(); + provider1.data().put("enabled", Boolean.TRUE); + provider1.setName("dynamic-provider"); + provider1.setIssuer("https://dynamic.example.com"); + provider1.setAudience("dynamic-audience"); + provider1.setJwksUri("https://dynamic.example.com/jwks"); + providers.add(provider1); + + // Create second enabled provider + OAuthProvider provider2 = new OAuthProvider(); + provider2.data().put("enabled", Boolean.TRUE); + provider2.setName("second-provider"); + provider2.setIssuer("https://second.example.com"); + provider2.setAudience("second-audience"); + provider2.setJwksUri("https://second.example.com/jwks"); + providers.add(provider2); + + oauthSettings.setProviders(providers); + globalSettings.setOauth(oauthSettings); + + return globalSettings; + } + + private GlobalSettingsInfo createGlobalSettingsWithDisabledProviders() { + GlobalSettingsInfo globalSettings = new GlobalSettingsInfo(); + OAuthSettings oauthSettings = new OAuthSettings(); + OAuthProviderArray providers = new OAuthProviderArray(); + + // Create disabled provider + OAuthProvider disabledProvider = new OAuthProvider(); + disabledProvider.data().put("enabled", Boolean.FALSE); + disabledProvider.setName("disabled-provider"); + disabledProvider.setIssuer("https://disabled.example.com"); + disabledProvider.setAudience("disabled-audience"); + disabledProvider.setJwksUri("https://disabled.example.com/jwks"); + providers.add(disabledProvider); + + oauthSettings.setProviders(providers); + globalSettings.setOauth(oauthSettings); + + return globalSettings; + } +} diff --git a/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/DataHubOAuthSigningKeyResolverTest.java b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/DataHubOAuthSigningKeyResolverTest.java new file mode 100644 index 00000000000000..4f0e73454d23d7 --- /dev/null +++ b/metadata-service/auth-impl/src/test/java/com/datahub/authentication/token/DataHubOAuthSigningKeyResolverTest.java @@ -0,0 +1,348 @@ +package com.datahub.authentication.token; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.JwsHeader; +import java.math.BigInteger; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.security.Key; +import java.security.KeyFactory; +import java.security.PublicKey; +import java.security.spec.RSAPublicKeySpec; +import java.util.Base64; +import java.util.HashSet; +import org.json.JSONObject; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class DataHubOAuthSigningKeyResolverTest { + + private static final String TEST_ISSUER = "https://auth.example.com"; + private static final String TEST_JWKS_URI = "https://auth.example.com/.well-known/jwks.json"; + private static final String TEST_ALGORITHM = "RS256"; + private static final String TEST_KEY_ID = "test-key-id"; + + private DataHubOAuthSigningKeyResolver resolver; + private HttpClient mockHttpClient; + private HttpResponse mockHttpResponse; + private JwsHeader mockJwsHeader; + private Claims mockClaims; + + @BeforeMethod + public void setUp() { + HashSet trustedIssuers = new HashSet<>(); + trustedIssuers.add(TEST_ISSUER); + + mockHttpClient = mock(HttpClient.class); + mockHttpResponse = mock(HttpResponse.class); + mockJwsHeader = mock(JwsHeader.class); + mockClaims = mock(Claims.class); + + // Use the new constructor that accepts HttpClient for testing + resolver = + new DataHubOAuthSigningKeyResolver( + trustedIssuers, TEST_JWKS_URI, TEST_ALGORITHM, mockHttpClient); + } + + @Test + public void testResolveSigningKeySuccess() throws Exception { + // Arrange + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn(TEST_ALGORITHM); + + String jwksResponse = createValidJwksResponse(); + when(mockHttpResponse.body()).thenReturn(jwksResponse); + when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(mockHttpResponse); + + // Act + Key result = resolver.resolveSigningKey(mockJwsHeader, mockClaims); + + // Assert + assertNotNull(result); + assertEquals(result.getAlgorithm(), "RSA"); + } + + @Test + public void testResolveSigningKeyInvalidIssuer() { + // Arrange + when(mockClaims.getIssuer()).thenReturn("https://malicious.com"); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn(TEST_ALGORITHM); + + // Act & Assert + try { + resolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals( + e.getMessage(), "Unable to resolve signing key: Invalid issuer: https://malicious.com"); + } + } + + @Test + public void testResolveSigningKeyInvalidAlgorithm() { + // Arrange + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn("HS256"); // Wrong algorithm + + // Act & Assert + try { + resolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals( + e.getMessage(), + "Unable to resolve signing key: Invalid algorithm: expected RS256 but got HS256"); + } + } + + @Test + public void testResolveSigningKeyMismatchedAlgorithm() { + // Arrange + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn("RS512"); // Different RSA algorithm + + // Act & Assert + try { + resolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals( + e.getMessage(), + "Unable to resolve signing key: Invalid algorithm: expected RS256 but got RS512"); + } + } + + @Test + public void testResolveSigningKeyKeyNotFoundInJwks() throws Exception { + // Arrange + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn("missing_key_id"); + when(mockJwsHeader.getAlgorithm()).thenReturn(TEST_ALGORITHM); + + String jwksResponse = createValidJwksResponse(); + when(mockHttpResponse.body()).thenReturn(jwksResponse); + when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(mockHttpResponse); + + // Act & Assert + try { + resolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals( + e.getMessage(), + "Unable to resolve signing key: No matching key found in JWKS for kid=missing_key_id"); + } + } + + @Test + public void testResolveSigningKeyUnsupportedKeyType() throws Exception { + // Arrange + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn(TEST_ALGORITHM); + + String jwksResponse = createJwksResponseWithUnsupportedKeyType(); + when(mockHttpResponse.body()).thenReturn(jwksResponse); + when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(mockHttpResponse); + + // Act & Assert + try { + resolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals( + e.getMessage(), + "Unable to resolve signing key: Algorithm RS256 requires RSA key type, but got: EC"); + } + } + + @Test + public void testResolveSigningKeyWithECDSAAlgorithm() throws Exception { + // Arrange - Create resolver expecting ES256 algorithm + HashSet trustedIssuers = new HashSet<>(); + trustedIssuers.add(TEST_ISSUER); + DataHubOAuthSigningKeyResolver ecdsaResolver = + new DataHubOAuthSigningKeyResolver(trustedIssuers, TEST_JWKS_URI, "ES256", mockHttpClient); + + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn("ES256"); + + String jwksResponse = createJwksResponseWithUnsupportedKeyType(); + when(mockHttpResponse.body()).thenReturn(jwksResponse); + when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(mockHttpResponse); + + // Act & Assert + try { + ecdsaResolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals( + e.getMessage(), "Unable to resolve signing key: ECDSA algorithms not yet supported"); + } + } + + @Test + public void testResolveSigningKeyWithUnsupportedAlgorithm() throws Exception { + // Arrange - Create resolver with unsupported algorithm + HashSet trustedIssuers = new HashSet<>(); + trustedIssuers.add(TEST_ISSUER); + DataHubOAuthSigningKeyResolver hmacResolver = + new DataHubOAuthSigningKeyResolver(trustedIssuers, TEST_JWKS_URI, "HS256", mockHttpClient); + + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn("HS256"); + + String jwksResponse = createJwksResponseWithHMACKeyType(); + when(mockHttpResponse.body()).thenReturn(jwksResponse); + when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(mockHttpResponse); + + // Act & Assert + try { + hmacResolver.resolveSigningKey(mockJwsHeader, mockClaims); + assertNotNull(null, "Expected RuntimeException to be thrown"); + } catch (RuntimeException e) { + assertEquals(e.getMessage(), "Unable to resolve signing key: Unsupported algorithm: HS256"); + } + } + + @Test + public void testResolveSigningKeyWithPS256Algorithm() throws Exception { + // Arrange - Create resolver expecting PS256 algorithm (RSA-PSS) + HashSet trustedIssuers = new HashSet<>(); + trustedIssuers.add(TEST_ISSUER); + DataHubOAuthSigningKeyResolver ps256Resolver = + new DataHubOAuthSigningKeyResolver(trustedIssuers, TEST_JWKS_URI, "PS256", mockHttpClient); + + when(mockClaims.getIssuer()).thenReturn(TEST_ISSUER); + when(mockJwsHeader.getKeyId()).thenReturn(TEST_KEY_ID); + when(mockJwsHeader.getAlgorithm()).thenReturn("PS256"); + + String jwksResponse = createValidJwksResponse(); // RSA key works for PS256 + when(mockHttpResponse.body()).thenReturn(jwksResponse); + when(mockHttpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(mockHttpResponse); + + // Act + Key result = ps256Resolver.resolveSigningKey(mockJwsHeader, mockClaims); + + // Assert + assertNotNull(result); + assertEquals(result.getAlgorithm(), "RSA"); + } + + @Test + public void testConstructorWithValidParameters() { + // Arrange + HashSet issuers = new HashSet<>(); + issuers.add("https://example.com"); + + // Act + DataHubOAuthSigningKeyResolver testResolver = + new DataHubOAuthSigningKeyResolver(issuers, "https://example.com/jwks", "RS256"); + + // Assert + assertNotNull(testResolver); + } + + private String createValidJwksResponse() throws Exception { + // Create a valid JWKS response with a real RSA public key + JSONObject jwks = new JSONObject(); + JSONObject key = new JSONObject(); + + // Sample RSA public key components (these are safe test values) + String modulus = + "0vx7agoebGcQSuuPiLJXZptN9nndrQmbPFRP_gdHzfK3kczjmpsYRIFpqRYwtCAG3KOUKnp7EIbmgZN7I1l" + + "_jBmjmfsGZHqG6dMwL3EwwU7rEUGXZRe0YJ_GWZjEK1HXf3rPCNjkOBYKjSJPnFjDPpK1" + + "_XLIpLqYD8pj4Y-7E5uVa5E8kJvOPllGd4wGLJE6UjqQJ3NbPKHNYGZOdx9J9bL8YJbM" + + "YGJK3l3c6CmjnSjZRh"; + String exponent = "AQAB"; + + key.put("kty", "RSA"); + key.put("kid", TEST_KEY_ID); + key.put("use", "sig"); + key.put("alg", TEST_ALGORITHM); + key.put("n", modulus); + key.put("e", exponent); + + jwks.put("keys", new Object[] {key}); + + return jwks.toString(); + } + + private String createJwksResponseWithUnsupportedKeyType() { + // Create a JWKS response with an unsupported key type (EC instead of RSA) + JSONObject jwks = new JSONObject(); + JSONObject key = new JSONObject(); + + key.put("kty", "EC"); + key.put("kid", TEST_KEY_ID); + key.put("use", "sig"); + key.put("alg", "ES256"); + key.put("crv", "P-256"); + key.put("x", "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4"); + key.put("y", "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"); + + jwks.put("keys", new Object[] {key}); + + return jwks.toString(); + } + + private String createJwksResponseWithHMACKeyType() { + // Create a JWKS response with oct key type for HMAC algorithms + JSONObject jwks = new JSONObject(); + JSONObject key = new JSONObject(); + + key.put("kty", "oct"); + key.put("kid", TEST_KEY_ID); + key.put("use", "sig"); + key.put("alg", "HS256"); + key.put("k", "GawgguFyGrWKav7AX4VKUg"); + + jwks.put("keys", new Object[] {key}); + + return jwks.toString(); + } + + @Test + public void testRSAPublicKeyGeneration() throws Exception { + // Test helper method to generate RSA public key + String modulus = + "0vx7agoebGcQSuuPiLJXZptN9nndrQmbPFRP_gdHzfK3kczjmpsYRIFpqRYwtCAG3KOUKnp7EIbmgZN7I1l" + + "_jBmjmfsGZHqG6dMwL3EwwU7rEUGXZRe0YJ_GWZjEK1HXf3rPCNjkOBYKjSJPnFjDPpK1" + + "_XLIpLqYD8pj4Y-7E5uVa5E8kJvOPllGd4wGLJE6UjqQJ3NbPKHNYGZOdx9J9bL8YJbM" + + "YGJK3l3c6CmjnSjZRh"; + String exponent = "AQAB"; + + // Decode base64url + byte[] modulusBytes = Base64.getUrlDecoder().decode(modulus); + byte[] exponentBytes = Base64.getUrlDecoder().decode(exponent); + + BigInteger modulusBigInt = new BigInteger(1, modulusBytes); + BigInteger exponentBigInt = new BigInteger(1, exponentBytes); + + RSAPublicKeySpec keySpec = new RSAPublicKeySpec(modulusBigInt, exponentBigInt); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + PublicKey publicKey = keyFactory.generatePublic(keySpec); + + assertNotNull(publicKey); + assertEquals(publicKey.getAlgorithm(), "RSA"); + } +} diff --git a/metadata-service/configuration/src/main/resources/application.yaml b/metadata-service/configuration/src/main/resources/application.yaml index 49f906e752f796..98738a4f3f47a1 100644 --- a/metadata-service/configuration/src/main/resources/application.yaml +++ b/metadata-service/configuration/src/main/resources/application.yaml @@ -23,6 +23,30 @@ authentication: salt: ${DATAHUB_TOKEN_SERVICE_SALT:ohDVbJBvHHVJh9S/UA4BYF9COuNnqqVhr9MLKEGXk1O=} # Required for unauthenticated health check endpoints - best not to remove. - type: com.datahub.authentication.authenticator.HealthStatusAuthenticator + # OAuth/OIDC JWT token authenticator for service accounts + # Uses static configuration only (no dynamic GlobalSettings) + - type: com.datahub.authentication.authenticator.DataHubOAuthAuthenticator + configs: + # Enable/disable External OAuth authentication - Global disable switch. + enabled: ${EXTERNAL_OAUTH_ENABLED:false} + + # Trusted JWT issuers - must match the 'iss' claim in JWT tokens (comma-separated) + trustedIssuers: ${EXTERNAL_OAUTH_TRUSTED_ISSUERS:} + + # Allowed audiences - must match the 'aud' claim in JWT tokens (comma-separated) + allowedAudiences: ${EXTERNAL_OAUTH_ALLOWED_AUDIENCES:} + + # Option 1: Direct JWKS URI for fetching JWT signing keys + jwksUri: ${EXTERNAL_OAUTH_JWKS_URI:} + + # Option 2: Discovery URI to auto-derive JWKS URI (alternative to jwksUri) + discoveryUri: ${EXTERNAL_OAUTH_DISCOVERY_URI:} + + # JWT claim to use as user identifier (defaults to 'sub') + userIdClaim: ${EXTERNAL_OAUTH_USER_ID_CLAIM:sub} + + # JWT signing algorithm (defaults to 'RS256') + algorithm: ${EXTERNAL_OAUTH_ALGORITHM:RS256} - type: com.datahub.authentication.authenticator.DataHubGuestAuthenticator configs: guestUser: ${GUEST_AUTHENTICATION_USER:guest} @@ -364,7 +388,7 @@ kafka: topics: # Topic Dictionary Configuration - merged directly into topics section - # Each topic can be created by iterating through this dictionary + # Each topic can be created by iterating through this dictionary # The key name matches the programmatic identifier used in code metadataChangeProposal: name: ${METADATA_CHANGE_PROPOSAL_TOPIC_NAME:MetadataChangeProposal_v1}