diff --git a/packages/auth/__tests__/providers/cognito/refreshToken.test.ts b/packages/auth/__tests__/providers/cognito/refreshToken.test.ts index a298e1aa377..c84dfefe0d7 100644 --- a/packages/auth/__tests__/providers/cognito/refreshToken.test.ts +++ b/packages/auth/__tests__/providers/cognito/refreshToken.test.ts @@ -1,3 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + import { decodeJWT } from '@aws-amplify/core/internals/utils'; import { refreshAuthTokens } from '../../../src/providers/cognito/utils/refreshAuthTokens'; @@ -60,6 +63,7 @@ describe('refreshToken', () => { }); it('should refresh token', async () => { + const clientMetadata = { 'app-version': '1.0.0' }; const expectedOutput = { accessToken: decodeJWT(mockAccessToken), idToken: decodeJWT(mockAccessToken), @@ -82,6 +86,7 @@ describe('refreshToken', () => { }, }, username: mockedUsername, + clientMetadata, }); // stringify and re-parse for JWT equality @@ -93,6 +98,7 @@ describe('refreshToken', () => { expect.objectContaining({ ClientId: 'aaaaaaaaaaaa', RefreshToken: mockedRefreshToken, + ClientMetadata: clientMetadata, }), ); }); diff --git a/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts b/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts index 8906d8d7eed..de276fafd5a 100644 --- a/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts +++ b/packages/auth/__tests__/providers/cognito/tokenOrchestrator.test.ts @@ -151,4 +151,40 @@ describe('TokenOrchestrator', () => { expect(tokens?.accessToken).toEqual(validAuthTokens.accessToken); }); }); + + describe('setClientMetadataProvider', () => { + it('should use clientMetadataProvider for token refresh', async () => { + const clientMetadata = { 'app-version': '1.0.0' }; + const clientMetadataProvider = () => Promise.resolve(clientMetadata); + + mockTokenRefresher.mockResolvedValue({ + accessToken: { payload: {} }, + idToken: { payload: {} }, + clockDrift: 0, + refreshToken: 'newRefreshToken', + username: 'testuser', + }); + + tokenOrchestrator.setTokenRefresher(mockTokenRefresher); + tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore); + tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider); + + mockAuthTokenStore.loadTokens.mockResolvedValue({ + accessToken: { payload: { exp: 1 } }, + idToken: { payload: { exp: 1 } }, + clockDrift: 0, + refreshToken: 'refreshToken', + username: 'testuser', + }); + mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser'); + + await tokenOrchestrator.getTokens({ forceRefresh: true }); + + expect(mockTokenRefresher).toHaveBeenCalledWith( + expect.objectContaining({ + clientMetadata, + }), + ); + }); + }); }); diff --git a/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts b/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts index 43f0f8a2d8c..ca4a300d598 100644 --- a/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts +++ b/packages/auth/src/providers/cognito/tokenProvider/CognitoUserPoolsTokenProvider.ts @@ -4,6 +4,7 @@ import { AuthConfig, AuthTokens, + ClientMetadataProvider, FetchAuthSessionOptions, KeyValueStorageInterface, defaultStorage, @@ -38,6 +39,12 @@ export class CognitoUserPoolsTokenProvider this.authTokenStore.setKeyValueStorage(keyValueStorage); } + setClientMetadataProvider( + clientMetadataProvider: ClientMetadataProvider, + ): void { + this.tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider); + } + setAuthConfig(authConfig: AuthConfig) { this.authTokenStore.setAuthConfig(authConfig); this.tokenOrchestrator.setAuthConfig(authConfig); diff --git a/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts b/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts index 3f8027d2596..851db00846f 100644 --- a/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts +++ b/packages/auth/src/providers/cognito/tokenProvider/TokenOrchestrator.ts @@ -3,6 +3,7 @@ import { AuthConfig, AuthTokens, + ClientMetadataProvider, CognitoUserPoolConfig, FetchAuthSessionOptions, Hub, @@ -19,7 +20,7 @@ import { assertServiceError } from '../../../errors/utils/assertServiceError'; import { AuthError } from '../../../errors/AuthError'; import { oAuthStore } from '../utils/oauth/oAuthStore'; import { addInflightPromise } from '../utils/oauth/inflightPromise'; -import { CognitoAuthSignInDetails } from '../types'; +import { ClientMetadata, CognitoAuthSignInDetails } from '../types'; import { AuthTokenOrchestrator, @@ -32,6 +33,7 @@ import { export class TokenOrchestrator implements AuthTokenOrchestrator { private authConfig?: AuthConfig; + clientMetadataProvider?: ClientMetadataProvider; tokenStore?: AuthTokenStore; tokenRefresher?: TokenRefresher; inflightPromise: Promise | undefined; @@ -94,6 +96,12 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { return this.tokenRefresher; } + setClientMetadataProvider( + clientMetadataProvider: ClientMetadataProvider, + ): void { + this.clientMetadataProvider = clientMetadataProvider; + } + async getTokens( options?: FetchAuthSessionOptions, ): Promise< @@ -130,6 +138,8 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { tokens = await this.refreshTokens({ tokens, username, + clientMetadata: + options?.clientMetadata ?? (await this.clientMetadataProvider?.()), }); if (tokens === null) { @@ -147,9 +157,11 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { private async refreshTokens({ tokens, username, + clientMetadata, }: { tokens: CognitoAuthTokens; username: string; + clientMetadata?: ClientMetadata; }): Promise { try { const { signInDetails } = tokens; @@ -157,6 +169,7 @@ export class TokenOrchestrator implements AuthTokenOrchestrator { tokens, authConfig: this.authConfig, username, + clientMetadata, }); newTokens.signInDetails = signInDetails; await this.setTokens({ tokens: newTokens }); diff --git a/packages/auth/src/providers/cognito/tokenProvider/types.ts b/packages/auth/src/providers/cognito/tokenProvider/types.ts index 5f381b42016..4ac6973b60e 100644 --- a/packages/auth/src/providers/cognito/tokenProvider/types.ts +++ b/packages/auth/src/providers/cognito/tokenProvider/types.ts @@ -3,21 +3,24 @@ import { AuthConfig, AuthTokens, + ClientMetadataProvider, FetchAuthSessionOptions, KeyValueStorageInterface, TokenProvider, } from '@aws-amplify/core'; -import { CognitoAuthSignInDetails } from '../types'; +import { ClientMetadata, CognitoAuthSignInDetails } from '../types'; export type TokenRefresher = ({ tokens, authConfig, username, + clientMetadata, }: { tokens: CognitoAuthTokens; authConfig?: AuthConfig; username: string; + clientMetadata?: ClientMetadata; }) => Promise; export type AuthKeys = Record; @@ -66,6 +69,9 @@ export interface AuthTokenOrchestrator { export interface CognitoUserPoolTokenProviderType extends TokenProvider { setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void; setAuthConfig(authConfig: AuthConfig): void; + setClientMetadataProvider( + clientMetadataProvider: ClientMetadataProvider, + ): void; } export type CognitoAuthTokens = AuthTokens & { diff --git a/packages/auth/src/providers/cognito/types/models.ts b/packages/auth/src/providers/cognito/types/models.ts index 1b113ef1720..8bd212d117a 100644 --- a/packages/auth/src/providers/cognito/types/models.ts +++ b/packages/auth/src/providers/cognito/types/models.ts @@ -38,7 +38,7 @@ export const cognitoHostedUIIdentityProviderMap: Record = /** * Arbitrary key/value pairs that may be passed as part of certain Cognito requests */ -export type ClientMetadata = Record; +export type { ClientMetadata } from '@aws-amplify/core'; /** * Allowed values for preferredChallenge diff --git a/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts b/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts index e29bd460b19..54f30f7a796 100644 --- a/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts +++ b/packages/auth/src/providers/cognito/utils/refreshAuthTokens.ts @@ -14,15 +14,18 @@ import { assertAuthTokensWithRefreshToken } from '../utils/types'; import { AuthError } from '../../../errors/AuthError'; import { createCognitoUserPoolEndpointResolver } from '../factories'; import { createGetTokensFromRefreshTokenClient } from '../../../foundation/factories/serviceClients/cognitoIdentityProvider'; +import { ClientMetadata } from '../types'; const refreshAuthTokensFunction: TokenRefresher = async ({ tokens, authConfig, username, + clientMetadata, }: { tokens: CognitoAuthTokens; authConfig?: AuthConfig; username: string; + clientMetadata?: ClientMetadata; }): Promise => { assertTokenProviderConfig(authConfig?.Cognito); const { userPoolId, userPoolClientId, userPoolEndpoint } = authConfig.Cognito; @@ -41,6 +44,7 @@ const refreshAuthTokensFunction: TokenRefresher = async ({ ClientId: userPoolClientId, RefreshToken: tokens.refreshToken, DeviceKey: tokens.deviceMetadata?.deviceKey, + ClientMetadata: clientMetadata, }, ); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index b37829169b3..d65eac33003 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,6 +20,8 @@ export { OAuthConfig, CognitoUserPoolConfig, JWT, + ClientMetadata, + ClientMetadataProvider, } from './singleton/Auth/types'; export { decodeJWT } from './singleton/Auth/utils'; export { diff --git a/packages/core/src/singleton/Auth/types.ts b/packages/core/src/singleton/Auth/types.ts index 8fa811251b1..18dafd9761b 100644 --- a/packages/core/src/singleton/Auth/types.ts +++ b/packages/core/src/singleton/Auth/types.ts @@ -4,6 +4,16 @@ import { StrictUnion } from '../../types'; import { AtLeastOne } from '../types'; +/** + * Arbitrary key/value pairs that may be passed as part of certain Cognito requests + */ +export type ClientMetadata = Record; + +/** + * Function type for providing client metadata for Cognito operations + */ +export type ClientMetadataProvider = () => Promise; + // From https://github.com/awslabs/aws-jwt-verify/blob/main/src/safe-json-parse.ts // From https://github.com/awslabs/aws-jwt-verify/blob/main/src/jwt-model.ts interface JwtPayloadStandardFields { @@ -66,6 +76,7 @@ export interface TokenProvider { export interface FetchAuthSessionOptions { forceRefresh?: boolean; + clientMetadata?: ClientMetadata; } export interface AuthTokens {