Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import { decodeJWT } from '@aws-amplify/core/internals/utils';

import { refreshAuthTokens } from '../../../src/providers/cognito/utils/refreshAuthTokens';
Expand Down Expand Up @@ -60,6 +63,7 @@ describe('refreshToken', () => {
});

it('should refresh token', async () => {
const clientMetadata = { 'app-version': '1.0.0' };
const expectedOutput = {
accessToken: decodeJWT(mockAccessToken),
idToken: decodeJWT(mockAccessToken),
Expand All @@ -82,6 +86,7 @@ describe('refreshToken', () => {
},
},
username: mockedUsername,
clientMetadata,
});

// stringify and re-parse for JWT equality
Expand All @@ -93,6 +98,7 @@ describe('refreshToken', () => {
expect.objectContaining({
ClientId: 'aaaaaaaaaaaa',
RefreshToken: mockedRefreshToken,
ClientMetadata: clientMetadata,
}),
);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,40 @@ describe('TokenOrchestrator', () => {
expect(tokens?.accessToken).toEqual(validAuthTokens.accessToken);
});
});

describe('setClientMetadataProvider', () => {
it('should use clientMetadataProvider for token refresh', async () => {
const clientMetadata = { 'app-version': '1.0.0' };
const clientMetadataProvider = () => Promise.resolve(clientMetadata);

mockTokenRefresher.mockResolvedValue({
accessToken: { payload: {} },
idToken: { payload: {} },
clockDrift: 0,
refreshToken: 'newRefreshToken',
username: 'testuser',
});

tokenOrchestrator.setTokenRefresher(mockTokenRefresher);
tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore);
tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);

mockAuthTokenStore.loadTokens.mockResolvedValue({
accessToken: { payload: { exp: 1 } },
idToken: { payload: { exp: 1 } },
clockDrift: 0,
refreshToken: 'refreshToken',
username: 'testuser',
});
mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser');

await tokenOrchestrator.getTokens({ forceRefresh: true });

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
clientMetadata,
}),
);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
FetchAuthSessionOptions,
KeyValueStorageInterface,
defaultStorage,
Expand Down Expand Up @@ -38,6 +39,12 @@ export class CognitoUserPoolsTokenProvider
this.authTokenStore.setKeyValueStorage(keyValueStorage);
}

setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void {
this.tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);
}

setAuthConfig(authConfig: AuthConfig) {
this.authTokenStore.setAuthConfig(authConfig);
this.tokenOrchestrator.setAuthConfig(authConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
CognitoUserPoolConfig,
FetchAuthSessionOptions,
Hub,
Expand All @@ -19,7 +20,7 @@ import { assertServiceError } from '../../../errors/utils/assertServiceError';
import { AuthError } from '../../../errors/AuthError';
import { oAuthStore } from '../utils/oauth/oAuthStore';
import { addInflightPromise } from '../utils/oauth/inflightPromise';
import { CognitoAuthSignInDetails } from '../types';
import { ClientMetadata, CognitoAuthSignInDetails } from '../types';

import {
AuthTokenOrchestrator,
Expand All @@ -32,6 +33,7 @@ import {

export class TokenOrchestrator implements AuthTokenOrchestrator {
private authConfig?: AuthConfig;
clientMetadataProvider?: ClientMetadataProvider;
tokenStore?: AuthTokenStore;
tokenRefresher?: TokenRefresher;
inflightPromise: Promise<void> | undefined;
Expand Down Expand Up @@ -94,6 +96,12 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
return this.tokenRefresher;
}

setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void {
this.clientMetadataProvider = clientMetadataProvider;
}

async getTokens(
options?: FetchAuthSessionOptions,
): Promise<
Expand Down Expand Up @@ -130,6 +138,8 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
tokens = await this.refreshTokens({
tokens,
username,
clientMetadata:
options?.clientMetadata ?? (await this.clientMetadataProvider?.()),
});

if (tokens === null) {
Expand All @@ -147,16 +157,19 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
private async refreshTokens({
tokens,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
username: string;
clientMetadata?: ClientMetadata;
}): Promise<CognitoAuthTokens | null> {
try {
const { signInDetails } = tokens;
const newTokens = await this.getTokenRefresher()({
tokens,
authConfig: this.authConfig,
username,
clientMetadata,
});
newTokens.signInDetails = signInDetails;
await this.setTokens({ tokens: newTokens });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
FetchAuthSessionOptions,
KeyValueStorageInterface,
TokenProvider,
} from '@aws-amplify/core';

import { CognitoAuthSignInDetails } from '../types';
import { ClientMetadata, CognitoAuthSignInDetails } from '../types';

export type TokenRefresher = ({
tokens,
authConfig,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
authConfig?: AuthConfig;
username: string;
clientMetadata?: ClientMetadata;
}) => Promise<CognitoAuthTokens>;

export type AuthKeys<AuthKey extends string> = Record<AuthKey, string>;
Expand Down Expand Up @@ -66,6 +69,9 @@ export interface AuthTokenOrchestrator {
export interface CognitoUserPoolTokenProviderType extends TokenProvider {
setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void;
setAuthConfig(authConfig: AuthConfig): void;
setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void;
}

export type CognitoAuthTokens = AuthTokens & {
Expand Down
2 changes: 1 addition & 1 deletion packages/auth/src/providers/cognito/types/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export const cognitoHostedUIIdentityProviderMap: Record<AuthProvider, string> =
/**
* Arbitrary key/value pairs that may be passed as part of certain Cognito requests
*/
export type ClientMetadata = Record<string, string>;
export type { ClientMetadata } from '@aws-amplify/core';

/**
* Allowed values for preferredChallenge
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@ import { assertAuthTokensWithRefreshToken } from '../utils/types';
import { AuthError } from '../../../errors/AuthError';
import { createCognitoUserPoolEndpointResolver } from '../factories';
import { createGetTokensFromRefreshTokenClient } from '../../../foundation/factories/serviceClients/cognitoIdentityProvider';
import { ClientMetadata } from '../types';

const refreshAuthTokensFunction: TokenRefresher = async ({
tokens,
authConfig,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
authConfig?: AuthConfig;
username: string;
clientMetadata?: ClientMetadata;
}): Promise<CognitoAuthTokens> => {
assertTokenProviderConfig(authConfig?.Cognito);
const { userPoolId, userPoolClientId, userPoolEndpoint } = authConfig.Cognito;
Expand All @@ -41,6 +44,7 @@ const refreshAuthTokensFunction: TokenRefresher = async ({
ClientId: userPoolClientId,
RefreshToken: tokens.refreshToken,
DeviceKey: tokens.deviceMetadata?.deviceKey,
ClientMetadata: clientMetadata,
},
);

Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export {
OAuthConfig,
CognitoUserPoolConfig,
JWT,
ClientMetadata,
ClientMetadataProvider,
} from './singleton/Auth/types';
export { decodeJWT } from './singleton/Auth/utils';
export {
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/singleton/Auth/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import { StrictUnion } from '../../types';
import { AtLeastOne } from '../types';

/**
* Arbitrary key/value pairs that may be passed as part of certain Cognito requests
*/
export type ClientMetadata = Record<string, string>;

/**
* Function type for providing client metadata for Cognito operations
*/
export type ClientMetadataProvider = () => Promise<ClientMetadata>;

// From https://github.com/awslabs/aws-jwt-verify/blob/main/src/safe-json-parse.ts
// From https://github.com/awslabs/aws-jwt-verify/blob/main/src/jwt-model.ts
interface JwtPayloadStandardFields {
Expand Down Expand Up @@ -66,6 +76,7 @@ export interface TokenProvider {

export interface FetchAuthSessionOptions {
forceRefresh?: boolean;
clientMetadata?: ClientMetadata;
}

export interface AuthTokens {
Expand Down
Loading