diff --git a/server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.test.ts b/server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.test.ts index 3ee23089b2..bb9ced1528 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.test.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/agenticChat/agenticChatController.test.ts @@ -35,7 +35,11 @@ import { } from '@aws/language-server-runtimes/server-interface' import { TestFeatures } from '@aws/language-server-runtimes/testing' import * as assert from 'assert' -import { createIterableResponse, setCredentialsForAmazonQTokenServiceManagerFactory } from '../../shared/testUtils' +import { + createIterableResponse, + setCredentialsForAmazonQTokenServiceManagerFactory, + setIamCredentialsForAmazonQServiceManagerFactory, +} from '../../shared/testUtils' import sinon from 'ts-sinon' import { AgenticChatController } from './agenticChatController' import { ChatSessionManagementService } from '../chat/chatSessionManagementService' @@ -179,7 +183,8 @@ describe('AgenticChatController', () => { let getMessagesStub: sinon.SinonStub let addMessageStub: sinon.SinonStub - const setCredentials = setCredentialsForAmazonQTokenServiceManagerFactory(() => testFeatures) + const setSsoCredentials = setCredentialsForAmazonQTokenServiceManagerFactory(() => testFeatures) + const setIamCredentials = setIamCredentialsForAmazonQServiceManagerFactory(() => testFeatures) beforeEach(() => { // Override the response timeout for tests to avoid long waits @@ -272,7 +277,7 @@ describe('AgenticChatController', () => { } testFeatures.lsp.window.showDocument = sinon.stub() testFeatures.setClientParams(cachedInitializeParams) - setCredentials('builderId') + setSsoCredentials('builderId') activeTabSpy = sinon.spy(ChatTelemetryController.prototype, 'activeTabId', ['get', 'set']) removeConversationSpy = sinon.spy(ChatTelemetryController.prototype, 'removeConversation') @@ -3144,6 +3149,9 @@ ${' '.repeat(8)}} // Reset the singleton instance ChatSessionManagementService.reset() + // Store IAM credentials + setIamCredentials() + // Create IAM service manager AmazonQIAMServiceManager.resetInstance() iamServiceManager = AmazonQIAMServiceManager.initInstance(testFeatures) diff --git a/server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.test.ts b/server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.test.ts new file mode 100644 index 0000000000..f8da4e8c32 --- /dev/null +++ b/server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.test.ts @@ -0,0 +1,1092 @@ +import * as assert from 'assert' +import sinon, { StubbedInstance, stubInterface } from 'ts-sinon' +import { AmazonQServiceManager } from './AmazonQServiceManager' +import { TestFeatures } from '@aws/language-server-runtimes/testing' +import { GenerateSuggestionsRequest } from '../codeWhispererService/codeWhispererServiceBase' +import { CodeWhispererServiceToken } from '../codeWhispererService/codeWhispererServiceToken' +import { CodeWhispererServiceIAM } from '../codeWhispererService/codeWhispererServiceIAM' +import { + AmazonQServiceInitializationError, + AmazonQServicePendingProfileError, + AmazonQServicePendingProfileUpdateError, + AmazonQServicePendingSigninError, +} from './errors' +import { + CancellationToken, + InitializeParams, + LSPErrorCodes, + ResponseError, +} from '@aws/language-server-runtimes/protocol' +import { + AWS_Q_ENDPOINT_URL_ENV_VAR, + AWS_Q_ENDPOINTS, + AWS_Q_REGION_ENV_VAR, + DEFAULT_AWS_Q_ENDPOINT_URL, + DEFAULT_AWS_Q_REGION, +} from '../constants' +import * as qDeveloperProfilesFetcherModule from './qDeveloperProfiles' +import { + setTokenCredentialsForAmazonQServiceManagerFactory, + setIamCredentialsForAmazonQServiceManagerFactory, +} from '../testUtils' +import { StreamingClientServiceToken, StreamingClientServiceIAM } from '../streamingClientService' +import { generateSingletonInitializationTests } from './testUtils' +import * as utils from '../utils' + +export const mockedProfiles: qDeveloperProfilesFetcherModule.AmazonQDeveloperProfile[] = [ + { + arn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + name: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + identityDetails: { + region: 'us-east-1', + }, + }, + { + arn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ-2', + name: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ-2', + identityDetails: { + region: 'us-east-1', + }, + }, + { + arn: 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ', + name: 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ', + identityDetails: { + region: 'eu-central-1', + }, + }, +] + +const TEST_ENDPOINT_US_EAST_1 = 'http://amazon-q-in-us-east-1-endpoint' +const TEST_ENDPOINT_EU_CENTRAL_1 = 'http://amazon-q-in-eu-central-1-endpoint' + +describe('Token', () => { + let codewhispererServiceStub: StubbedInstance + let codewhispererStubFactory: sinon.SinonStub> + let sdkInitializatorSpy: sinon.SinonSpy + let getListAllAvailableProfilesHandlerStub: sinon.SinonStub + + let amazonQServiceManager: AmazonQServiceManager + let features: TestFeatures + + beforeEach(() => { + // Override endpoints for testing + AWS_Q_ENDPOINTS.set('us-east-1', TEST_ENDPOINT_US_EAST_1) + AWS_Q_ENDPOINTS.set('eu-central-1', TEST_ENDPOINT_EU_CENTRAL_1) + + getListAllAvailableProfilesHandlerStub = sinon + .stub() + .resolves( + Promise.resolve(mockedProfiles).then(() => + new Promise(resolve => setTimeout(resolve, 1)).then(() => mockedProfiles) + ) + ) + + sinon + .stub(qDeveloperProfilesFetcherModule, 'getListAllAvailableProfilesHandler') + .returns(getListAllAvailableProfilesHandlerStub) + + AmazonQServiceManager.resetInstance() + + features = new TestFeatures() + + sdkInitializatorSpy = Object.assign(sinon.spy(features.sdkInitializator), { + v2: sinon.spy(features.sdkInitializator.v2), + }) + + codewhispererServiceStub = stubInterface() + // @ts-ignore + codewhispererServiceStub.client = sinon.stub() + codewhispererServiceStub.customizationArn = undefined + codewhispererServiceStub.shareCodeWhispererContentWithAWS = false + codewhispererServiceStub.profileArn = undefined + + // Initialize the class with mocked dependencies + codewhispererStubFactory = sinon.stub().returns(codewhispererServiceStub) + }) + + afterEach(() => { + AmazonQServiceManager.resetInstance() + features.dispose() + sinon.restore() + }) + + const setupServiceManager = (enableProfiles = false) => { + // @ts-ignore + const cachedInitializeParams: InitializeParams = { + initializationOptions: { + aws: { + awsClientCapabilities: { + q: { + developerProfiles: enableProfiles, + }, + }, + }, + }, + } + features.setClientParams(cachedInitializeParams) + + AmazonQServiceManager.initInstance(features) + amazonQServiceManager = AmazonQServiceManager.getInstance() + amazonQServiceManager.setServiceFactory(codewhispererStubFactory) + } + + const setCredentials = setTokenCredentialsForAmazonQServiceManagerFactory(() => features) + + const clearCredentials = () => { + features.credentialsProvider.hasCredentials.returns(false) + features.credentialsProvider.getCredentials.returns(undefined) + features.credentialsProvider.getConnectionType.returns('none') + } + + const setupServiceManagerWithProfile = async ( + profileArn = 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ' + ): Promise => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: profileArn, + }, + }, + {} as CancellationToken + ) + + const service = amazonQServiceManager.getCodewhispererService() + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + + return service as CodeWhispererServiceToken + } + + describe('Initialization process', () => { + generateSingletonInitializationTests(AmazonQServiceManager) + }) + + describe('Client is not connected', () => { + it('should be in PENDING_CONNECTION state when bearer token is not set', () => { + setupServiceManager() + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + clearCredentials() + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingSigninError) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'none') + }) + }) + + describe('Clear state upon bearer token deletion', () => { + let cancelActiveProfileChangeTokenSpy: sinon.SinonSpy + + beforeEach(() => { + setupServiceManager() + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + cancelActiveProfileChangeTokenSpy = sinon.spy( + amazonQServiceManager as any, + 'cancelActiveProfileChangeToken' + ) + + setCredentials('builderId') + }) + + it('should clear local state variables on receiving bearer token deletion event', () => { + amazonQServiceManager.getCodewhispererService() + + amazonQServiceManager.handleOnCredentialsDeleted('bearer') + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'none') + assert.strictEqual((amazonQServiceManager as any)['cachedCodewhispererService'], undefined) + assert.strictEqual((amazonQServiceManager as any)['cachedStreamingClient'], undefined) + assert.strictEqual((amazonQServiceManager as any)['activeIdcProfile'], undefined) + sinon.assert.calledOnce(cancelActiveProfileChangeTokenSpy) + }) + + it('should not clear local state variables on receiving iam token deletion event', () => { + amazonQServiceManager.getCodewhispererService() + + amazonQServiceManager.handleOnCredentialsDeleted('iam') + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId') + assert(!(amazonQServiceManager['cachedCodewhispererService'] === undefined)) + assert.strictEqual((amazonQServiceManager as any)['activeIdcProfile'], undefined) + sinon.assert.notCalled(cancelActiveProfileChangeTokenSpy) + }) + }) + + describe('BuilderId support', () => { + const testRegion = 'some-region' + const testEndpoint = 'http://some-endpoint-in-some-region' + + beforeEach(() => { + setupServiceManager() + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('builderId') + + AWS_Q_ENDPOINTS.set(testRegion, testEndpoint) + + features.lsp.getClientInitializeParams.reset() + }) + + it('should be INITIALIZED with BuilderId Connection', async () => { + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId') + + assert(streamingClient instanceof StreamingClientServiceToken) + assert(codewhispererServiceStub.generateSuggestions.calledOnce) + }) + + it('should initialize service with region set by client', async () => { + features.setClientParams({ + processId: 0, + rootUri: 'some-root-uri', + capabilities: {}, + initializationOptions: { + aws: { + region: testRegion, + }, + }, + }) + + amazonQServiceManager.getCodewhispererService() + assert(codewhispererStubFactory.calledOnceWithExactly(testRegion, testEndpoint)) + + const streamingClient = amazonQServiceManager.getStreamingClient() + assert.strictEqual(await streamingClient.client.config.region(), testRegion) + assert.strictEqual( + (await streamingClient.client.config.endpoint()).hostname, + 'some-endpoint-in-some-region' + ) + }) + + it('should initialize service with region set by runtime if not set by client', async () => { + features.runtime.getConfiguration.withArgs(AWS_Q_REGION_ENV_VAR).returns('eu-central-1') + features.runtime.getConfiguration.withArgs(AWS_Q_ENDPOINT_URL_ENV_VAR).returns(TEST_ENDPOINT_EU_CENTRAL_1) + + amazonQServiceManager.getCodewhispererService() + assert(codewhispererStubFactory.calledOnceWithExactly('eu-central-1', TEST_ENDPOINT_EU_CENTRAL_1)) + + const streamingClient = amazonQServiceManager.getStreamingClient() + assert.strictEqual(await streamingClient.client.config.region(), 'eu-central-1') + assert.strictEqual( + (await streamingClient.client.config.endpoint()).hostname, + 'amazon-q-in-eu-central-1-endpoint' + ) + }) + + it('should initialize service with default region if not set by client and runtime', async () => { + amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + + assert(codewhispererStubFactory.calledOnceWithExactly(DEFAULT_AWS_Q_REGION, DEFAULT_AWS_Q_ENDPOINT_URL)) + + assert.strictEqual(await streamingClient.client.config.region(), DEFAULT_AWS_Q_REGION) + assert.strictEqual( + (await streamingClient.client.config.endpoint()).hostname, + 'codewhisperer.us-east-1.amazonaws.com' + ) + }) + }) + + describe('IdentityCenter support', () => { + describe('Developer Profiles Support is disabled', () => { + it('should be INITIALIZED with IdentityCenter Connection', async () => { + setupServiceManager() + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert(codewhispererServiceStub.generateSuggestions.calledOnce) + + assert(streamingClient instanceof StreamingClientServiceToken) + }) + }) + + describe('Developer Profiles Support is enabled', () => { + it('should not throw when receiving null profile arn in PENDING_CONNECTION state', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + await assert.doesNotReject( + amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: null, + }, + }, + {} as CancellationToken + ) + ) + + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + }) + + it('should initialize to PENDING_Q_PROFILE state when IdentityCenter Connection is set', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingProfileError) + assert.throws(() => amazonQServiceManager.getStreamingClient(), AmazonQServicePendingProfileError) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + }) + + it('handles Profile configuration request for valid profile and initializes to INITIALIZED state', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1)) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1') + }) + + it('handles Profile configuration request for valid profile & cancels the old in-flight update request', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + assert.strictEqual((amazonQServiceManager as any)['profileChangeTokenSource'], undefined) + + let firstRequestStarted = false + const originalHandleProfileChange = amazonQServiceManager['handleProfileChange'] + amazonQServiceManager['handleProfileChange'] = async (...args) => { + firstRequestStarted = true + return originalHandleProfileChange.apply(amazonQServiceManager, args) + } + const firstUpdate = amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + while (!firstRequestStarted) { + await new Promise(resolve => setTimeout(resolve, 1)) + } + const secondUpdate = amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + const results = await Promise.allSettled([firstUpdate, secondUpdate]) + + assert.strictEqual((amazonQServiceManager as any)['profileChangeTokenSource'], undefined) + const service = amazonQServiceManager.getCodewhispererService() + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + + assert.strictEqual(results[0].status, 'fulfilled') + assert.strictEqual(results[1].status, 'fulfilled') + }) + + it('handles Profile configuration change to valid profile in same region', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient1 = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ' + ) + + assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1)) + assert(streamingClient1 instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient1.client.config.region(), 'us-east-1') + + // Profile change + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ-2', + }, + }, + {} as CancellationToken + ) + await service.generateSuggestions({} as GenerateSuggestionsRequest) + const streamingClient2 = amazonQServiceManager.getStreamingClient() + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ-2' + ) + + // CodeWhisperer Service was not recreated + assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1)) + + assert(streamingClient2 instanceof StreamingClientServiceToken) + assert.strictEqual(streamingClient1, streamingClient2) + assert.strictEqual(await streamingClient2.client.config.region(), 'us-east-1') + }) + + it('handles Profile configuration change to valid profile in different region', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient1 = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ' + ) + assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1)) + + assert(streamingClient1 instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient1.client.config.region(), 'us-east-1') + + // Profile change + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + await service.generateSuggestions({} as GenerateSuggestionsRequest) + const streamingClient2 = amazonQServiceManager.getStreamingClient() + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ' + ) + + // CodeWhisperer Service was recreated + assert(codewhispererStubFactory.calledTwice) + assert.deepStrictEqual(codewhispererStubFactory.lastCall.args, [ + 'eu-central-1', + TEST_ENDPOINT_EU_CENTRAL_1, + ]) + + // Streaming Client was recreated + assert(streamingClient2 instanceof StreamingClientServiceToken) + assert.notStrictEqual(streamingClient1, streamingClient2) + assert.strictEqual(await streamingClient2.client.config.region(), 'eu-central-1') + }) + + // As we're not validating profile at this moment, there is no "invalid" profile + it.skip('handles Profile configuration change from valid to invalid profile', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + let service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ' + ) + assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1)) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1') + + // Profile change to invalid profile + + await assert.rejects( + amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/invalid-profile-arn', + }, + }, + {} as CancellationToken + ), + new ResponseError(LSPErrorCodes.RequestFailed, 'Requested Amazon Q Profile does not exist', { + awsErrorCode: 'E_AMAZON_Q_INVALID_PROFILE', + }) + ) + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingProfileError) + assert.throws(() => amazonQServiceManager.getStreamingClient(), AmazonQServicePendingProfileError) + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + // CodeWhisperer Service was not recreated + assert(codewhispererStubFactory.calledOnce) + assert.deepStrictEqual(codewhispererStubFactory.lastCall.args, ['us-east-1', TEST_ENDPOINT_US_EAST_1]) + }) + + // As we're not validating profile at this moment, there is no "non-existing" profile + it.skip('handles non-existing profile selection', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await assert.rejects( + amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/invalid-profile-arn', + }, + }, + {} as CancellationToken + ), + new ResponseError(LSPErrorCodes.RequestFailed, 'Requested Amazon Q Profile does not exist', { + awsErrorCode: 'E_AMAZON_Q_INVALID_PROFILE', + }) + ) + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingProfileError) + assert.throws(() => amazonQServiceManager.getStreamingClient(), AmazonQServicePendingProfileError) + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(codewhispererStubFactory.notCalled) + }) + + it('prevents service usage while profile change is inflight when profile was not set', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingProfileError) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + + amazonQServiceManager.setState('PENDING_Q_PROFILE_UPDATE') + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE_UPDATE') + + assert.throws( + () => amazonQServiceManager.getCodewhispererService(), + AmazonQServicePendingProfileUpdateError + ) + assert.throws(() => amazonQServiceManager.getStreamingClient(), AmazonQServicePendingProfileUpdateError) + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:eu-central-1:11111111111111:profile/QQQQQQQQQQQQ' + ) + assert.deepStrictEqual(codewhispererStubFactory.lastCall.args, [ + 'eu-central-1', + TEST_ENDPOINT_EU_CENTRAL_1, + ]) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'eu-central-1') + }) + + it('prevents service usage while profile change is inflight when profile was set before', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingProfileError) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + const service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual( + amazonQServiceManager.getActiveProfileArn(), + 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ' + ) + assert.deepStrictEqual(codewhispererStubFactory.lastCall.args, ['us-east-1', TEST_ENDPOINT_US_EAST_1]) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1') + + // Updaing profile + amazonQServiceManager.setState('PENDING_Q_PROFILE_UPDATE') + assert.throws( + () => amazonQServiceManager.getCodewhispererService(), + AmazonQServicePendingProfileUpdateError + ) + assert.throws(() => amazonQServiceManager.getStreamingClient(), AmazonQServicePendingProfileUpdateError) + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE_UPDATE') + }) + + it('resets to PENDING_PROFILE from INITIALIZED when receiving null profileArn', async () => { + await setupServiceManagerWithProfile() + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: null, + }, + }, + {} as CancellationToken + ) + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + sinon.assert.calledOnce(codewhispererServiceStub.abortInflightRequests) + }) + + it('resets to PENDING_Q_PROFILE from PENDING_Q_PROFILE_UPDATE when receiving null profileArn', async () => { + await setupServiceManagerWithProfile() + + amazonQServiceManager.setState('PENDING_Q_PROFILE_UPDATE') + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE_UPDATE') + + // Null profile arn + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: null, + }, + }, + {} as CancellationToken + ) + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + sinon.assert.calledOnce(codewhispererServiceStub.abortInflightRequests) + assert.throws(() => amazonQServiceManager.getCodewhispererService()) + }) + + it('cancels on-going profile update when credentials are deleted', async () => { + await setupServiceManagerWithProfile() + + amazonQServiceManager.setState('PENDING_Q_PROFILE_UPDATE') + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE_UPDATE') + + amazonQServiceManager.handleOnCredentialsDeleted('bearer') + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + sinon.assert.calledOnce(codewhispererServiceStub.abortInflightRequests) + assert.throws(() => amazonQServiceManager.getCodewhispererService()) + }) + + // Due to service limitation, validation was removed for the sake of recovering API availability + // When service is ready to take more tps, revert https://github.com/aws/language-servers/pull/1329 to add profile validation + it('should not call service to validate profile and always assume its validness', async () => { + setupServiceManager(true) + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + + setCredentials('identityCenter') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + sinon.assert.notCalled(getListAllAvailableProfilesHandlerStub) + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + }) + }) + }) + + describe('Connection types with no Developer Profiles support', () => { + it('handles reauthentication scenario when connection type is none but profile ARN is provided', async () => { + setupServiceManager(true) + clearCredentials() + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'none') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ) + + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + }) + + it('ignores null profile when connection type is none', async () => { + setupServiceManager(true) + clearCredentials() + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'none') + + await amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: null, + }, + }, + {} as CancellationToken + ) + + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'none') + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION') + }) + + it('returns error when profile update is requested and connection type is builderId', async () => { + setupServiceManager(true) + setCredentials('builderId') + + await assert.rejects( + amazonQServiceManager.handleOnUpdateConfiguration( + { + section: 'aws.q', + settings: { + profileArn: 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ', + }, + }, + {} as CancellationToken + ), + new ResponseError( + LSPErrorCodes.RequestFailed, + 'Connection type builderId does not support Developer Profiles feature.', + { + awsErrorCode: 'E_AMAZON_Q_CONNECTION_NO_PROFILE_SUPPORT', + } + ) + ) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId') + }) + }) + + describe('Handle connection type changes', () => { + describe('connection changes from BuilderId to IdentityCenter', () => { + it('should initialize service with default region when profile support is disabled', async () => { + setupServiceManager(false) + setCredentials('builderId') + + let service1 = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service1.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1') + + setCredentials('identityCenter') + let service2 = amazonQServiceManager.getCodewhispererService() + const streamingClient2 = amazonQServiceManager.getStreamingClient() + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(codewhispererStubFactory.calledTwice) + assert(codewhispererStubFactory.calledWithExactly(DEFAULT_AWS_Q_REGION, DEFAULT_AWS_Q_ENDPOINT_URL)) + + assert(streamingClient2 instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient2.client.config.region(), DEFAULT_AWS_Q_REGION) + }) + + it('should initialize service to PENDING_Q_PROFILE state when profile support is enabled', async () => { + setupServiceManager(true) + setCredentials('builderId') + + let service = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1') + + setCredentials('identityCenter') + + assert.throws(() => amazonQServiceManager.getCodewhispererService(), AmazonQServicePendingProfileError) + assert.throws(() => amazonQServiceManager.getStreamingClient(), AmazonQServicePendingProfileError) + + assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_Q_PROFILE') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(codewhispererStubFactory.calledOnce) + assert(codewhispererStubFactory.calledWithExactly(DEFAULT_AWS_Q_REGION, DEFAULT_AWS_Q_ENDPOINT_URL)) + }) + }) + + describe('connection changes from IdentityCenter to BuilderId', () => { + it('should initialize service in default IAD region', async () => { + setupServiceManager(false) + setCredentials('identityCenter') + + let service1 = amazonQServiceManager.getCodewhispererService() + const streamingClient = amazonQServiceManager.getStreamingClient() + await service1.generateSuggestions({} as GenerateSuggestionsRequest) + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(streamingClient instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1') + + setCredentials('builderId') + let service2 = amazonQServiceManager.getCodewhispererService() + const streamingClient2 = amazonQServiceManager.getStreamingClient() + + assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED') + assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId') + assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined) + + assert(codewhispererStubFactory.calledTwice) + assert(codewhispererStubFactory.calledWithExactly(DEFAULT_AWS_Q_REGION, DEFAULT_AWS_Q_ENDPOINT_URL)) + + assert(streamingClient2 instanceof StreamingClientServiceToken) + assert.strictEqual(await streamingClient2.client.config.region(), 'us-east-1') + }) + }) + }) + + describe('handle LSP Configuration settings', () => { + it('should initialize codewhisperer service with default configurations when not set by client', async () => { + setupServiceManager() + setCredentials('identityCenter') + + await amazonQServiceManager.handleDidChangeConfiguration() + + const service = amazonQServiceManager.getCodewhispererService() + + assert.strictEqual(service.customizationArn, undefined) + assert.strictEqual(service.shareCodeWhispererContentWithAWS, false) + }) + + it('should returned configured codewhispererService with expected configuration values', async () => { + const getConfigStub = features.lsp.workspace.getConfiguration + getConfigStub.withArgs('aws.q').resolves({ + customization: 'test-customization-arn', + optOutTelemetryPreference: true, + }) + getConfigStub.withArgs('aws.codeWhisperer').resolves({ + includeSuggestionsWithCodeReferences: true, + shareCodeWhispererContentWithAWS: true, + }) + + // Initialize mock server + setupServiceManager() + setCredentials('identityCenter') + + amazonQServiceManager = AmazonQServiceManager.getInstance() + const service = amazonQServiceManager.getCodewhispererService() + + assert.strictEqual(service.customizationArn, undefined) + assert.strictEqual(service.shareCodeWhispererContentWithAWS, false) + + await amazonQServiceManager.handleDidChangeConfiguration() + + // Force next tick to allow async work inside handleDidChangeConfiguration to complete + await Promise.resolve() + + assert.strictEqual(service.customizationArn, 'test-customization-arn') + assert.strictEqual(service.shareCodeWhispererContentWithAWS, true) + }) + }) + + describe('Initialize', () => { + it('should throw when initialize is called before LSP has been initialized with InitializeParams', () => { + features.resetClientParams() + + assert.throws(() => AmazonQServiceManager.initInstance(features), AmazonQServiceInitializationError) + }) + }) +}) + +describe('IAM', () => { + describe('Initialization process', () => { + generateSingletonInitializationTests(AmazonQServiceManager) + }) + + describe('Service caching', () => { + let serviceManager: AmazonQServiceManager + let features: TestFeatures + let updateCachedServiceConfigSpy: sinon.SinonSpy + + const setCredentials = setIamCredentialsForAmazonQServiceManagerFactory(() => features) + + beforeEach(() => { + features = new TestFeatures() + features.lsp.getClientInitializeParams.resolves({}) + + updateCachedServiceConfigSpy = sinon.spy( + AmazonQServiceManager.prototype, + 'updateCachedServiceConfig' as keyof AmazonQServiceManager + ) + + AmazonQServiceManager.resetInstance() + serviceManager = AmazonQServiceManager.initInstance(features) + }) + + afterEach(() => { + AmazonQServiceManager.resetInstance() + features.dispose() + sinon.restore() + }) + + it('should initialize the CodeWhisperer service only once', () => { + setCredentials() + const service = serviceManager.getCodewhispererService() + sinon.assert.calledOnce(updateCachedServiceConfigSpy) + + assert.deepStrictEqual(serviceManager.getCodewhispererService(), service) + sinon.assert.calledOnce(updateCachedServiceConfigSpy) + }) + + it('should initialize the streaming client only once', () => { + // Mock the credentials provider to return credentials when requested + setCredentials() + const streamingClient = serviceManager.getStreamingClient() + + // Verify that getting the client again returns the same instance + assert.deepStrictEqual(serviceManager.getStreamingClient(), streamingClient) + }) + }) +}) diff --git a/server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.ts b/server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.ts new file mode 100644 index 0000000000..6d7ad4d76c --- /dev/null +++ b/server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.ts @@ -0,0 +1,695 @@ +import { + UpdateConfigurationParams, + ResponseError, + LSPErrorCodes, + SsoConnectionType, + CancellationToken, + CredentialsType, + InitializeParams, + CancellationTokenSource, +} from '@aws/language-server-runtimes/server-interface' +import { CodeWhispererServiceToken } from '../codeWhispererService/codeWhispererServiceToken' +import { CodeWhispererServiceIAM } from '../codeWhispererService/codeWhispererServiceIAM' +import { + AmazonQError, + AmazonQServiceAlreadyInitializedError, + AmazonQServiceInitializationError, + AmazonQServiceInvalidProfileError, + AmazonQServiceNoProfileSupportError, + AmazonQServiceNotInitializedError, + AmazonQServicePendingProfileError, + AmazonQServicePendingProfileUpdateError, + AmazonQServicePendingSigninError, + AmazonQServiceProfileUpdateCancelled, +} from './errors' +import { + AmazonQBaseServiceManager, + BaseAmazonQServiceManager, + QServiceManagerFeatures, +} from './BaseAmazonQServiceManager' +import { AWS_Q_ENDPOINTS, Q_CONFIGURATION_SECTION } from '../constants' +import { AmazonQDeveloperProfile, signalsAWSQDeveloperProfilesEnabled } from './qDeveloperProfiles' +import { isStringOrNull } from '../utils' +import { getAmazonQRegionAndEndpoint } from './configurationUtils' +import { getUserAgent } from '../telemetryUtils' +import { + StreamingClientServiceToken, + StreamingClientServiceIAM, + StreamingClientServiceBase, +} from '../streamingClientService' +import { parse } from '@aws-sdk/util-arn-parser' +import { CodeWhispererServiceBase } from '../codeWhispererService/codeWhispererServiceBase' + +/** + * AmazonQServiceManager manages state and provides centralized access to + * instance of CodeWhispererService SDK client to any consuming code. + * It ensures that CodeWhispererService is configured to always access correct regionalized Amazon Q Developer API endpoint. + * Regional endppoint is selected based on: + * 1) current SSO auth connection type (BuilderId or IDC). + * 2) selected Amazon Q Developer profile (only for IDC connection type). + * + * @states + * - PENDING_CONNECTION: Initial state when no bearer token is set + * - PENDING_Q_PROFILE: When using Identity Center and waiting for profile selection + * - PENDING_Q_PROFILE_UPDATE: During profile update operation + * - INITIALIZED: Service is ready to handle requests + * + * @connectionTypes + * - none: No active connection + * - builderId: Connected via Builder ID + * - identityCenter: Connected via Identity Center + * + * AmazonQServiceManager is a singleton class, which must be instantiated with Language Server runtimes [Features](https://github.com/aws/language-server-runtimes/blob/21d5d1dc7c73499475b7c88c98d2ce760e5d26c8/runtimes/server-interface/server.ts#L31-L42) + * in the `AmazonQServiceServer` via the `initBaseServiceManager` factory. Dependencies of this class can access the singleton via + * the `getOrThrowBaseServiceManager` factory or `getInstance()` method after the initialized notification has been received during + * the LSP hand shake. + * + */ +export class AmazonQServiceManager extends BaseAmazonQServiceManager< + CodeWhispererServiceBase, + StreamingClientServiceBase +> { + private static instance: AmazonQServiceManager | null = null + private enableDeveloperProfileSupport?: boolean + private activeIdcProfile?: AmazonQDeveloperProfile + private connectionType?: SsoConnectionType + private profileChangeTokenSource: CancellationTokenSource | undefined + private region?: string + private endpoint?: string + private regionChangeListeners: Array<(region: string) => void> = [] + /** + * Internal state of Service connection, based on status of bearer token and Amazon Q Developer profile selection. + * Supported states: + * PENDING_CONNECTION - Waiting for (Bearer Token and StartURL) or (Access Key and Secret Key) to be passed + * PENDING_Q_PROFILE - (only for identityCenter connection) waiting for setting Developer Profile + * PENDING_Q_PROFILE_UPDATE (only for identityCenter connection) waiting for Developer Profile to complete + * INITIALIZED - Service is initialized + */ + private state: 'PENDING_CONNECTION' | 'PENDING_Q_PROFILE' | 'PENDING_Q_PROFILE_UPDATE' | 'INITIALIZED' = + 'PENDING_CONNECTION' + + private constructor(features: QServiceManagerFeatures) { + super(features) + } + + // @VisibleForTesting, please DO NOT use in production + setState(state: 'PENDING_CONNECTION' | 'PENDING_Q_PROFILE' | 'PENDING_Q_PROFILE_UPDATE' | 'INITIALIZED') { + this.state = state + } + + endpointOverride(): string | undefined { + return this.features.lsp.getClientInitializeParams()?.initializationOptions?.aws?.awsClientCapabilities + ?.textDocument?.inlineCompletionWithReferences?.endpointOverride + } + + public static initInstance(features: QServiceManagerFeatures): AmazonQServiceManager { + if (!AmazonQServiceManager.instance) { + AmazonQServiceManager.instance = new AmazonQServiceManager(features) + AmazonQServiceManager.instance.initialize() + + return AmazonQServiceManager.instance + } + + throw new AmazonQServiceAlreadyInitializedError() + } + + public static getInstance(): AmazonQServiceManager { + if (!AmazonQServiceManager.instance) { + throw new AmazonQServiceInitializationError( + 'Amazon Q service has not been initialized yet. Make sure the Amazon Q server is present and properly initialized.' + ) + } + + return AmazonQServiceManager.instance + } + + private initialize(): void { + if (!this.features.lsp.getClientInitializeParams()) { + this.log('AmazonQServiceManager initialized before LSP connection was initialized.') + throw new AmazonQServiceInitializationError( + 'AmazonQServiceManager initialized before LSP connection was initialized.' + ) + } + + // Bind methods that are passed by reference to some handlers to maintain proper scope. + this.serviceFactory = this.serviceFactory.bind(this) + + this.log('Reading enableDeveloperProfileSupport setting from AWSInitializationOptions') + if (this.features.lsp.getClientInitializeParams()?.initializationOptions?.aws) { + const awsOptions = this.features.lsp.getClientInitializeParams()?.initializationOptions?.aws || {} + this.enableDeveloperProfileSupport = signalsAWSQDeveloperProfilesEnabled(awsOptions) + + this.log(`Enabled Q Developer Profile support: ${this.enableDeveloperProfileSupport}`) + } + + this.connectionType = 'none' + this.state = 'PENDING_CONNECTION' + + this.log('Manager instance is initialize') + } + + public handleOnCredentialsDeleted(type: CredentialsType): void { + this.log(`Received credentials delete event for type: ${type}`) + if (type === 'iam') { + return + } + this.cancelActiveProfileChangeToken() + + this.resetCodewhispererService() + this.connectionType = 'none' + this.state = 'PENDING_CONNECTION' + } + + public async handleOnUpdateConfiguration(params: UpdateConfigurationParams, _token: CancellationToken) { + try { + if (params.section === Q_CONFIGURATION_SECTION && params.settings.profileArn !== undefined) { + const profileArn = params.settings.profileArn + const region = params.settings.region + + if (!isStringOrNull(profileArn)) { + throw new Error('Expected params.settings.profileArn to be of either type string or null') + } + + this.log(`Profile update is requested for profile ${profileArn}`) + this.cancelActiveProfileChangeToken() + this.profileChangeTokenSource = new CancellationTokenSource() + + await this.handleProfileChange(profileArn, this.profileChangeTokenSource.token) + } + } catch (error) { + this.log('Error updating profiles: ' + error) + if (error instanceof AmazonQServiceProfileUpdateCancelled) { + throw new ResponseError(LSPErrorCodes.ServerCancelled, error.message, { + awsErrorCode: error.code, + }) + } + if (error instanceof AmazonQError) { + throw new ResponseError(LSPErrorCodes.RequestFailed, error.message, { + awsErrorCode: error.code, + }) + } + + throw new ResponseError(LSPErrorCodes.RequestFailed, 'Failed to update configuration') + } finally { + if (this.profileChangeTokenSource) { + this.profileChangeTokenSource.dispose() + this.profileChangeTokenSource = undefined + } + } + } + + /** + * Validate if Bearer Token Connection type has changed mid-session. + * When connection type change is detected: reinitialize CodeWhispererService class with current connection type. + */ + private handleSsoConnectionChange() { + const newConnectionType = this.features.credentialsProvider.getConnectionType() + + this.logServiceState('Validate State of SSO Connection') + + const noCreds = !this.features.credentialsProvider.hasCredentials('bearer') + const noConnectionType = newConnectionType === 'none' + if (noCreds || noConnectionType) { + // Connection was reset, wait for SSO connection token from client + this.log( + `No active SSO connection is detected: no ${noCreds ? 'credentials' : 'connection type'} provided. Resetting the client` + ) + this.resetCodewhispererService() + this.connectionType = 'none' + this.state = 'PENDING_CONNECTION' + + return + } + + // Connection type hasn't change. + + if (newConnectionType === this.connectionType) { + this.logging.debug(`Connection type did not change: ${this.connectionType}`) + + return + } + + const endpointOverride = + this.features.lsp.getClientInitializeParams()?.initializationOptions?.aws?.awsClientCapabilities + ?.textDocument?.inlineCompletionWithReferences?.endpointOverride + + // Connection type changed to 'builderId' + + if (newConnectionType === 'builderId') { + this.log('Detected New connection type: builderId') + this.resetCodewhispererService() + + // For the builderId connection type regional endpoint discovery chain is: + // region set by client -> runtime region -> default region + const clientParams = this.features.lsp.getClientInitializeParams() + + this.createCodewhispererServiceInstances( + 'builderId', + clientParams?.initializationOptions?.aws?.region, + endpointOverride + ) + this.state = 'INITIALIZED' + this.log('Initialized Amazon Q service with builderId connection') + + return + } + + // Connection type changed to 'identityCenter' + + if (newConnectionType === 'identityCenter') { + this.log('Detected New connection type: identityCenter') + + this.resetCodewhispererService() + + if (this.enableDeveloperProfileSupport) { + this.connectionType = 'identityCenter' + this.state = 'PENDING_Q_PROFILE' + this.logServiceState('Pending profile selection for IDC connection') + + return + } + + this.createCodewhispererServiceInstances('identityCenter', undefined, endpointOverride) + this.state = 'INITIALIZED' + this.log('Initialized Amazon Q service with identityCenter connection') + + return + } + + this.logServiceState('Unknown Connection state') + } + + private cancelActiveProfileChangeToken() { + this.profileChangeTokenSource?.cancel() + this.profileChangeTokenSource?.dispose() + this.profileChangeTokenSource = undefined + } + + private handleTokenCancellationRequest(token: CancellationToken) { + if (token.isCancellationRequested) { + this.logServiceState('Handling CancellationToken cancellation request') + throw new AmazonQServiceProfileUpdateCancelled('Requested profile update got cancelled') + } + } + + private async handleProfileChange(newProfileArn: string | null, token: CancellationToken): Promise { + if (!this.enableDeveloperProfileSupport) { + this.log('Developer Profiles Support is not enabled') + return + } + + if (typeof newProfileArn === 'string' && newProfileArn.length === 0) { + throw new Error('Received invalid Profile ARN (empty string)') + } + + this.logServiceState('UpdateProfile is requested') + + // Test if connection type changed + this.handleSsoConnectionChange() + + if (this.connectionType === 'none') { + if (newProfileArn !== null) { + // During reauthentication, connection might be temporarily 'none' but user is providing a profile + // Set connection type to identityCenter to proceed with profile setting + this.connectionType = 'identityCenter' + this.state = 'PENDING_Q_PROFILE_UPDATE' + } else { + this.logServiceState('Received null profile while not connected, ignoring request') + return + } + } + + if (this.connectionType !== 'identityCenter') { + this.logServiceState('Q Profile can not be set') + throw new AmazonQServiceNoProfileSupportError( + `Connection type ${this.connectionType} does not support Developer Profiles feature.` + ) + } + + if ((this.state === 'INITIALIZED' && this.activeIdcProfile) || this.state === 'PENDING_Q_PROFILE') { + // Change status to pending to prevent API calls until profile is updated. + // Because `listAvailableProfiles` below can take few seconds to complete, + // there is possibility that client could send requests while profile is changing. + this.state = 'PENDING_Q_PROFILE_UPDATE' + } + + // Client sent an explicit null, indicating they want to reset the assigned profile (if any) + if (newProfileArn === null) { + this.logServiceState('Received null profile, resetting to PENDING_Q_PROFILE state') + this.resetCodewhispererService() + this.state = 'PENDING_Q_PROFILE' + + return + } + + const parsedArn = parse(newProfileArn) + const region = parsedArn.region + const endpoint = AWS_Q_ENDPOINTS.get(region) + if (!endpoint) { + throw new Error('Requested profileArn region is not supported') + } + + // Hack to inject a dummy profile name as it's not used by client IDE for now, if client IDE starts consuming name field then we should also pass both profile name and arn from the IDE + // When service is ready to take more tps, revert https://github.com/aws/language-servers/pull/1329 to add profile validation + const newProfile: AmazonQDeveloperProfile = { + arn: newProfileArn, + name: 'Client provided profile', + identityDetails: { + region: parsedArn.region, + }, + } + + if (!newProfile || !newProfile.identityDetails?.region) { + this.log(`Amazon Q Profile ${newProfileArn} is not valid`) + this.resetCodewhispererService() + this.state = 'PENDING_Q_PROFILE' + + throw new AmazonQServiceInvalidProfileError('Requested Amazon Q Profile does not exist') + } + + this.handleTokenCancellationRequest(token) + + if (!this.activeIdcProfile) { + this.activeIdcProfile = newProfile + this.createCodewhispererServiceInstances( + 'identityCenter', + newProfile.identityDetails.region, + this.endpointOverride() + ) + this.state = 'INITIALIZED' + this.log( + `Initialized identityCenter connection to region ${newProfile.identityDetails.region} for profile ${newProfile.arn}` + ) + + return + } + + // Profile didn't change + if (this.activeIdcProfile && this.activeIdcProfile.arn === newProfile.arn) { + // Update cached profile fields, keep existing client + this.log(`Profile selection did not change, active profile is ${this.activeIdcProfile.arn}`) + this.activeIdcProfile = newProfile + this.state = 'INITIALIZED' + + return + } + + this.handleTokenCancellationRequest(token) + + // At this point new valid profile is selected. + + const oldRegion = this.activeIdcProfile.identityDetails?.region + const newRegion = newProfile.identityDetails.region + if (oldRegion === newRegion) { + this.log(`New profile is in the same region as old one, keeping exising service.`) + this.log(`New active profile is ${this.activeIdcProfile.arn}, region ${oldRegion}`) + this.activeIdcProfile = newProfile + this.state = 'INITIALIZED' + + if (this.cachedCodewhispererService) { + this.cachedCodewhispererService.profileArn = newProfile.arn + } + + if (this.cachedStreamingClient) { + this.cachedStreamingClient.profileArn = newProfile.arn + } + + return + } + + this.log(`Switching service client region from ${oldRegion} to ${newRegion}`) + this.notifyRegionChangeListeners(newRegion) + + this.handleTokenCancellationRequest(token) + + // Selected new profile is in different region. Re-initialize service + this.resetCodewhispererService() + + this.activeIdcProfile = newProfile + + this.createCodewhispererServiceInstances( + 'identityCenter', + newProfile.identityDetails.region, + this.endpointOverride() + ) + this.state = 'INITIALIZED' + + return + } + + public getCodewhispererService(): CodeWhispererServiceBase { + // Prevent initiating requests while profile change is in progress. + if (this.state === 'PENDING_Q_PROFILE_UPDATE') { + throw new AmazonQServicePendingProfileUpdateError() + } + + if (this.features.credentialsProvider.hasCredentials('iam')) { + if (!this.cachedCodewhispererService) { + const amazonQRegionAndEndpoint = getAmazonQRegionAndEndpoint( + this.features.runtime, + this.features.logging + ) + this.region = amazonQRegionAndEndpoint.region + this.endpoint = amazonQRegionAndEndpoint.endpoint + this.cachedCodewhispererService = this.serviceFactory(this.region, this.endpoint) + this.updateCachedServiceConfig() + } + this.state = 'INITIALIZED' + } else { + this.handleSsoConnectionChange() + } + + if (this.state === 'INITIALIZED' && this.cachedCodewhispererService) { + return this.cachedCodewhispererService + } + + if (this.state === 'PENDING_CONNECTION') { + throw new AmazonQServicePendingSigninError() + } + + if (this.state === 'PENDING_Q_PROFILE') { + throw new AmazonQServicePendingProfileError() + } + + throw new AmazonQServiceNotInitializedError() + } + + public getStreamingClient(): StreamingClientServiceBase { + this.log('Getting instance of CodeWhispererStreaming client') + + // Trigger checks in token service + const service = this.getCodewhispererService() + + if (!service || !this.region || !this.endpoint) { + throw new AmazonQServiceNotInitializedError() + } + + if (!this.cachedStreamingClient) { + this.cachedStreamingClient = this.streamingClientFactory(this.region, this.endpoint) + } + + return this.cachedStreamingClient + } + + private resetCodewhispererService() { + this.cachedCodewhispererService?.abortInflightRequests() + this.cachedCodewhispererService = undefined + this.cachedStreamingClient?.abortInflightRequests() + this.cachedStreamingClient = undefined + this.activeIdcProfile = undefined + this.region = undefined + this.endpoint = undefined + } + + private createCodewhispererServiceInstances( + connectionType: 'builderId' | 'identityCenter', + clientOrProfileRegion: string | undefined, + endpointOverride: string | undefined + ) { + this.logServiceState('Initializing CodewhispererService') + + const { region, endpoint } = getAmazonQRegionAndEndpoint( + this.features.runtime, + this.features.logging, + clientOrProfileRegion + ) + + // Cache active region and endpoint selection + this.connectionType = connectionType + this.region = region + this.endpoint = endpoint + + if (endpointOverride) { + this.endpoint = endpointOverride + } + + this.cachedCodewhispererService = this.serviceFactory(region, this.endpoint) + this.log(`CodeWhispererToken service for connection type ${connectionType} was initialized, region=${region}`) + + this.cachedStreamingClient = this.streamingClientFactory(region, this.endpoint) + this.log(`StreamingClient service for connection type ${connectionType} was initialized, region=${region}`) + + this.logServiceState('CodewhispererService and StreamingClient Initialization finished') + } + + private getCustomUserAgent() { + const initializeParams = this.features.lsp.getClientInitializeParams() || {} + + return getUserAgent(initializeParams as InitializeParams, this.features.runtime.serverInfo) + } + + private serviceFactory(region: string, endpoint: string): CodeWhispererServiceBase { + let service: CodeWhispererServiceBase + if (this.features.credentialsProvider.hasCredentials('iam')) { + service = new CodeWhispererServiceIAM( + this.features.credentialsProvider, + this.features.workspace, + this.features.logging, + region, + endpoint, + this.features.sdkInitializator + ) + } else { + service = new CodeWhispererServiceToken( + this.features.credentialsProvider, + this.features.workspace, + this.features.logging, + region, + endpoint, + this.features.sdkInitializator + ) + + const customUserAgent = this.getCustomUserAgent() + service.updateClientConfig({ + customUserAgent: customUserAgent, + }) + service.customizationArn = this.configurationCache.getProperty('customizationArn') + service.profileArn = this.activeIdcProfile?.arn + service.shareCodeWhispererContentWithAWS = this.configurationCache.getProperty( + 'shareCodeWhispererContentWithAWS' + ) + + this.log('Configured CodeWhispererService instance settings:') + this.log( + `customUserAgent=${customUserAgent}, customizationArn=${service.customizationArn}, shareCodeWhispererContentWithAWS=${service.shareCodeWhispererContentWithAWS}` + ) + } + + return service + } + + private streamingClientFactory(region: string, endpoint: string): StreamingClientServiceBase { + let streamingClient: StreamingClientServiceBase + if (this.features.credentialsProvider.hasCredentials('iam')) { + streamingClient = new StreamingClientServiceIAM( + this.features.credentialsProvider, + this.features.sdkInitializator, + this.features.logging, + region, + endpoint + ) + } else { + streamingClient = new StreamingClientServiceToken( + this.features.credentialsProvider, + this.features.sdkInitializator, + this.features.logging, + region, + endpoint, + this.getCustomUserAgent() + ) + streamingClient.profileArn = this.activeIdcProfile?.arn + } + + this.logging.debug(`Created streaming client instance region=${region}, endpoint=${endpoint}`) + return streamingClient + } + + private log(message: string): void { + const prefix = 'Amazon Q Token Service Manager' + this.logging?.log(`${prefix}: ${message}`) + } + + private logServiceState(context: string): void { + this.logging?.debug( + JSON.stringify({ + context, + state: { + serviceStatus: this.state, + connectionType: this.connectionType, + activeIdcProfile: this.activeIdcProfile, + }, + }) + ) + } + + // For Unit Tests + public static resetInstance(): void { + AmazonQServiceManager.instance = null + } + + public getState() { + return this.state + } + + public getConnectionType() { + return this.connectionType + } + + public override getActiveProfileArn() { + return this.activeIdcProfile?.arn + } + + public setServiceFactory(factory: (region: string, endpoint: string) => CodeWhispererServiceToken) { + this.serviceFactory = factory.bind(this) + } + + public getServiceFactory() { + return this.serviceFactory + } + + public getEnableDeveloperProfileSupport(): boolean { + return this.enableDeveloperProfileSupport === undefined ? false : this.enableDeveloperProfileSupport + } + + /** + * Registers a listener that will be called when the region changes + * @param listener Function that will be called with the new region + * @returns Function to unregister the listener + */ + public override onRegionChange(listener: (region: string) => void): () => void { + this.regionChangeListeners.push(listener) + // If we already have a region, notify the listener immediately + if (this.region) { + try { + listener(this.region) + } catch (error) { + this.logging.error(`Error in region change listener: ${error}`) + } + } + return () => { + this.regionChangeListeners = this.regionChangeListeners.filter(l => l !== listener) + } + } + + private notifyRegionChangeListeners(region: string): void { + this.logging.debug( + `Notifying ${this.regionChangeListeners.length} region change listeners of region: ${region}` + ) + this.regionChangeListeners.forEach(listener => { + try { + listener(region) + } catch (error) { + this.logging.error(`Error in region change listener: ${error}`) + } + }) + } + + public getRegion(): string | undefined { + return this.region + } +} + +export const initBaseServiceManager = (features: QServiceManagerFeatures) => + AmazonQServiceManager.initInstance(features) + +export const getOrThrowBaseServiceManager = (): AmazonQBaseServiceManager => AmazonQServiceManager.getInstance() diff --git a/server/aws-lsp-codewhisperer/src/shared/testUtils.ts b/server/aws-lsp-codewhisperer/src/shared/testUtils.ts index 79596d8cc6..e444809cc5 100644 --- a/server/aws-lsp-codewhisperer/src/shared/testUtils.ts +++ b/server/aws-lsp-codewhisperer/src/shared/testUtils.ts @@ -319,10 +319,10 @@ export function shuffleList(list: T[]): T[] { return shuffledList } -export const setCredentialsForAmazonQTokenServiceManagerFactory = (getFeatures: () => TestFeatures) => { +export const setTokenCredentialsForAmazonQServiceManagerFactory = (getFeatures: () => TestFeatures) => { return (connectionType: SsoConnectionType) => { const features = getFeatures() - features.credentialsProvider.hasCredentials.returns(true) + features.credentialsProvider.hasCredentials.withArgs('bearer').returns(true) features.credentialsProvider.getConnectionType.returns(connectionType) features.credentialsProvider.getCredentials.returns({ token: 'test-token', @@ -330,6 +330,22 @@ export const setCredentialsForAmazonQTokenServiceManagerFactory = (getFeatures: } } +// TODO: remove this when changing references +export const setCredentialsForAmazonQTokenServiceManagerFactory = setTokenCredentialsForAmazonQServiceManagerFactory + +export const setIamCredentialsForAmazonQServiceManagerFactory = (getFeatures: () => TestFeatures) => { + return () => { + const features = getFeatures() + features.credentialsProvider.hasCredentials.withArgs('iam').returns(true) + features.credentialsProvider.getConnectionType.returns('none') + features.credentialsProvider.getCredentials.returns({ + accessKeyId: 'test-access-key', + secretAccessKey: 'test-secret-key', + sessionToken: 'test-session-token', + }) + } +} + export const stubCodeWhispererService = () => { const service = stubInterface()