Skip to content

Commit 7b2ce35

Browse files
committed
fix: update client references inside unified service manager
1 parent 16d5d50 commit 7b2ce35

File tree

2 files changed

+79
-76
lines changed

2 files changed

+79
-76
lines changed

server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.test.ts

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import {
2929
setTokenCredentialsForAmazonQServiceManagerFactory,
3030
setIamCredentialsForAmazonQServiceManagerFactory,
3131
} from '../testUtils'
32-
import { StreamingClientService } from '../streamingClientService'
32+
import { StreamingClientServiceToken, StreamingClientServiceIAM } from '../streamingClientService'
3333
import { generateSingletonInitializationTests } from './testUtils'
3434
import * as utils from '../utils'
3535

@@ -61,8 +61,8 @@ const TEST_ENDPOINT_US_EAST_1 = 'http://amazon-q-in-us-east-1-endpoint'
6161
const TEST_ENDPOINT_EU_CENTRAL_1 = 'http://amazon-q-in-eu-central-1-endpoint'
6262

6363
describe('Token', () => {
64-
let codewhispererServiceStub: StubbedInstance<CodeWhispererService>
65-
let codewhispererStubFactory: sinon.SinonStub<any[], StubbedInstance<CodeWhispererService>>
64+
let codewhispererServiceStub: StubbedInstance<CodeWhispererServiceToken>
65+
let codewhispererStubFactory: sinon.SinonStub<any[], StubbedInstance<CodeWhispererServiceToken>>
6666
let sdkInitializatorSpy: sinon.SinonSpy
6767
let getListAllAvailableProfilesHandlerStub: sinon.SinonStub
6868

@@ -94,7 +94,7 @@ describe('Token', () => {
9494
v2: sinon.spy(features.sdkInitializator.v2),
9595
})
9696

97-
codewhispererServiceStub = stubInterface<CodeWhispererService>()
97+
codewhispererServiceStub = stubInterface<CodeWhispererServiceToken>()
9898
// @ts-ignore
9999
codewhispererServiceStub.client = sinon.stub()
100100
codewhispererServiceStub.customizationArn = undefined
@@ -141,7 +141,7 @@ describe('Token', () => {
141141

142142
const setupServiceManagerWithProfile = async (
143143
profileArn = 'arn:aws:testprofilearn:us-east-1:11111111111111:profile/QQQQQQQQQQQQ'
144-
): Promise<CodeWhispererService> => {
144+
): Promise<CodeWhispererServiceToken> => {
145145
setupServiceManager(true)
146146
assert.strictEqual(amazonQServiceManager.getState(), 'PENDING_CONNECTION')
147147

@@ -161,7 +161,7 @@ describe('Token', () => {
161161
assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED')
162162
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter')
163163

164-
return service
164+
return service as CodeWhispererServiceToken
165165
}
166166

167167
describe('Initialization process', () => {
@@ -245,7 +245,7 @@ describe('Token', () => {
245245
assert.strictEqual(amazonQServiceManager.getState(), 'INITIALIZED')
246246
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId')
247247

248-
assert(streamingClient instanceof StreamingClientService)
248+
assert(streamingClient instanceof StreamingClientServiceToken)
249249
assert(codewhispererServiceStub.generateSuggestions.calledOnce)
250250
})
251251

@@ -318,7 +318,7 @@ describe('Token', () => {
318318
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter')
319319
assert(codewhispererServiceStub.generateSuggestions.calledOnce)
320320

321-
assert(streamingClient instanceof StreamingClientService)
321+
assert(streamingClient instanceof StreamingClientServiceToken)
322322
})
323323
})
324324

@@ -379,7 +379,7 @@ describe('Token', () => {
379379
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter')
380380
assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1))
381381

382-
assert(streamingClient instanceof StreamingClientService)
382+
assert(streamingClient instanceof StreamingClientServiceToken)
383383
assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1')
384384
})
385385

@@ -456,7 +456,7 @@ describe('Token', () => {
456456
)
457457

458458
assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1))
459-
assert(streamingClient1 instanceof StreamingClientService)
459+
assert(streamingClient1 instanceof StreamingClientServiceToken)
460460
assert.strictEqual(await streamingClient1.client.config.region(), 'us-east-1')
461461

462462
// Profile change
@@ -483,7 +483,7 @@ describe('Token', () => {
483483
// CodeWhisperer Service was not recreated
484484
assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1))
485485

486-
assert(streamingClient2 instanceof StreamingClientService)
486+
assert(streamingClient2 instanceof StreamingClientServiceToken)
487487
assert.strictEqual(streamingClient1, streamingClient2)
488488
assert.strictEqual(await streamingClient2.client.config.region(), 'us-east-1')
489489
})
@@ -516,7 +516,7 @@ describe('Token', () => {
516516
)
517517
assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1))
518518

519-
assert(streamingClient1 instanceof StreamingClientService)
519+
assert(streamingClient1 instanceof StreamingClientServiceToken)
520520
assert.strictEqual(await streamingClient1.client.config.region(), 'us-east-1')
521521

522522
// Profile change
@@ -548,7 +548,7 @@ describe('Token', () => {
548548
])
549549

550550
// Streaming Client was recreated
551-
assert(streamingClient2 instanceof StreamingClientService)
551+
assert(streamingClient2 instanceof StreamingClientServiceToken)
552552
assert.notStrictEqual(streamingClient1, streamingClient2)
553553
assert.strictEqual(await streamingClient2.client.config.region(), 'eu-central-1')
554554
})
@@ -582,7 +582,7 @@ describe('Token', () => {
582582
)
583583
assert(codewhispererStubFactory.calledOnceWithExactly('us-east-1', TEST_ENDPOINT_US_EAST_1))
584584

585-
assert(streamingClient instanceof StreamingClientService)
585+
assert(streamingClient instanceof StreamingClientServiceToken)
586586
assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1')
587587

588588
// Profile change to invalid profile
@@ -691,7 +691,7 @@ describe('Token', () => {
691691
TEST_ENDPOINT_EU_CENTRAL_1,
692692
])
693693

694-
assert(streamingClient instanceof StreamingClientService)
694+
assert(streamingClient instanceof StreamingClientServiceToken)
695695
assert.strictEqual(await streamingClient.client.config.region(), 'eu-central-1')
696696
})
697697

@@ -726,7 +726,7 @@ describe('Token', () => {
726726
)
727727
assert.deepStrictEqual(codewhispererStubFactory.lastCall.args, ['us-east-1', TEST_ENDPOINT_US_EAST_1])
728728

729-
assert(streamingClient instanceof StreamingClientService)
729+
assert(streamingClient instanceof StreamingClientServiceToken)
730730
assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1')
731731

732732
// Updaing profile
@@ -907,7 +907,7 @@ describe('Token', () => {
907907
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId')
908908
assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined)
909909

910-
assert(streamingClient instanceof StreamingClientService)
910+
assert(streamingClient instanceof StreamingClientServiceToken)
911911
assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1')
912912

913913
setCredentials('identityCenter')
@@ -921,7 +921,7 @@ describe('Token', () => {
921921
assert(codewhispererStubFactory.calledTwice)
922922
assert(codewhispererStubFactory.calledWithExactly(DEFAULT_AWS_Q_REGION, DEFAULT_AWS_Q_ENDPOINT_URL))
923923

924-
assert(streamingClient2 instanceof StreamingClientService)
924+
assert(streamingClient2 instanceof StreamingClientServiceToken)
925925
assert.strictEqual(await streamingClient2.client.config.region(), DEFAULT_AWS_Q_REGION)
926926
})
927927

@@ -937,7 +937,7 @@ describe('Token', () => {
937937
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'builderId')
938938
assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined)
939939

940-
assert(streamingClient instanceof StreamingClientService)
940+
assert(streamingClient instanceof StreamingClientServiceToken)
941941
assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1')
942942

943943
setCredentials('identityCenter')
@@ -967,7 +967,7 @@ describe('Token', () => {
967967
assert.strictEqual(amazonQServiceManager.getConnectionType(), 'identityCenter')
968968
assert.strictEqual(amazonQServiceManager.getActiveProfileArn(), undefined)
969969

970-
assert(streamingClient instanceof StreamingClientService)
970+
assert(streamingClient instanceof StreamingClientServiceToken)
971971
assert.strictEqual(await streamingClient.client.config.region(), 'us-east-1')
972972

973973
setCredentials('builderId')
@@ -981,7 +981,7 @@ describe('Token', () => {
981981
assert(codewhispererStubFactory.calledTwice)
982982
assert(codewhispererStubFactory.calledWithExactly(DEFAULT_AWS_Q_REGION, DEFAULT_AWS_Q_ENDPOINT_URL))
983983

984-
assert(streamingClient2 instanceof StreamingClientService)
984+
assert(streamingClient2 instanceof StreamingClientServiceToken)
985985
assert.strictEqual(await streamingClient2.client.config.region(), 'us-east-1')
986986
})
987987
})
@@ -1081,19 +1081,12 @@ describe('IAM', () => {
10811081
})
10821082

10831083
it('should initialize the streaming client only once', () => {
1084+
// Mock the credentials provider to return credentials when requested
10841085
setCredentials()
1085-
// Mock getIAMCredentialsFromProvider to return dummy credentials
1086-
const getIAMCredentialsStub = sinon.stub(utils, 'getIAMCredentialsFromProvider').returns({
1087-
accessKeyId: 'dummy-access-key',
1088-
secretAccessKey: 'dummy-secret-key',
1089-
sessionToken: 'dummy-session-token',
1090-
})
1091-
10921086
const streamingClient = serviceManager.getStreamingClient()
10931087

1088+
// Verify that getting the client again returns the same instance
10941089
assert.deepStrictEqual(serviceManager.getStreamingClient(), streamingClient)
1095-
1096-
getIAMCredentialsStub.restore()
10971090
})
10981091
})
10991092
})

server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQServiceManager.ts

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ import { AmazonQDeveloperProfile, signalsAWSQDeveloperProfilesEnabled } from './
3232
import { isStringOrNull } from '../utils'
3333
import { getAmazonQRegionAndEndpoint } from './configurationUtils'
3434
import { getUserAgent } from '../telemetryUtils'
35-
import { StreamingClientServiceToken, StreamingClientServiceIAM } from '../streamingClientService'
35+
import {
36+
StreamingClientServiceToken,
37+
StreamingClientServiceIAM,
38+
StreamingClientServiceBase,
39+
} from '../streamingClientService'
3640
import { parse } from '@aws-sdk/util-arn-parser'
41+
import { CodeWhispererServiceBase } from '../codeWhispererService/codeWhispererServiceBase'
3742

3843
/**
3944
* AmazonQServiceManager manages state and provides centralized access to
@@ -60,7 +65,10 @@ import { parse } from '@aws-sdk/util-arn-parser'
6065
* the LSP hand shake.
6166
*
6267
*/
63-
export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhispererService, StreamingClientService> {
68+
export class AmazonQServiceManager extends BaseAmazonQServiceManager<
69+
CodeWhispererServiceBase,
70+
StreamingClientServiceBase
71+
> {
6472
private static instance: AmazonQServiceManager | null = null
6573
private enableDeveloperProfileSupport?: boolean
6674
private activeIdcProfile?: AmazonQDeveloperProfile
@@ -185,33 +193,6 @@ export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhisper
185193
}
186194
}
187195

188-
private handleConnectionChange() {
189-
if (this.features.credentialsProvider.hasCredentials('iam')) {
190-
if (!this.cachedCodewhispererService) {
191-
const amazonQRegionAndEndpoint = getAmazonQRegionAndEndpoint(
192-
this.features.runtime,
193-
this.features.logging
194-
)
195-
this.region = amazonQRegionAndEndpoint.region
196-
this.endpoint = amazonQRegionAndEndpoint.endpoint
197-
this.cachedCodewhispererService = new CodeWhispererService(
198-
this.features.credentialsProvider,
199-
this.features.workspace,
200-
this.features.logging,
201-
this.region,
202-
this.endpoint,
203-
this.features.sdkInitializator
204-
)
205-
this.updateCachedServiceConfig()
206-
}
207-
this.state = 'INITIALIZED'
208-
return
209-
} else {
210-
this.handleSsoConnectionChange()
211-
return
212-
}
213-
}
214-
215196
/**
216197
* Validate if Bearer Token Connection type has changed mid-session.
217198
* When connection type change is detected: reinitialize CodeWhispererService class with current connection type.
@@ -435,13 +416,34 @@ export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhisper
435416
return
436417
}
437418

438-
public getCodewhispererService(): CodeWhispererService {
419+
public getCodewhispererService(): CodeWhispererServiceBase {
439420
// Prevent initiating requests while profile change is in progress.
440421
if (this.state === 'PENDING_Q_PROFILE_UPDATE') {
441422
throw new AmazonQServicePendingProfileUpdateError()
442423
}
443424

444-
this.handleConnectionChange()
425+
if (this.features.credentialsProvider.hasCredentials('iam')) {
426+
if (!this.cachedCodewhispererService) {
427+
const amazonQRegionAndEndpoint = getAmazonQRegionAndEndpoint(
428+
this.features.runtime,
429+
this.features.logging
430+
)
431+
this.region = amazonQRegionAndEndpoint.region
432+
this.endpoint = amazonQRegionAndEndpoint.endpoint
433+
this.cachedCodewhispererService = new CodeWhispererServiceIAM(
434+
this.features.credentialsProvider,
435+
this.features.workspace,
436+
this.features.logging,
437+
this.region,
438+
this.endpoint,
439+
this.features.sdkInitializator
440+
)
441+
this.updateCachedServiceConfig()
442+
}
443+
this.state = 'INITIALIZED'
444+
} else {
445+
this.handleSsoConnectionChange()
446+
}
445447

446448
if (this.state === 'INITIALIZED' && this.cachedCodewhispererService) {
447449
return this.cachedCodewhispererService
@@ -458,7 +460,7 @@ export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhisper
458460
throw new AmazonQServiceNotInitializedError()
459461
}
460462

461-
public getStreamingClient() {
463+
public getStreamingClient(): StreamingClientServiceBase {
462464
this.log('Getting instance of CodeWhispererStreaming client')
463465

464466
// Trigger checks in token service
@@ -517,8 +519,8 @@ export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhisper
517519
return getUserAgent(initializeParams as InitializeParams, this.features.runtime.serverInfo)
518520
}
519521

520-
private serviceFactory(region: string, endpoint: string): CodeWhispererService {
521-
const service = new CodeWhispererService(
522+
private serviceFactory(region: string, endpoint: string): CodeWhispererServiceToken {
523+
const service = new CodeWhispererServiceToken(
522524
this.features.credentialsProvider,
523525
this.features.workspace,
524526
this.features.logging,
@@ -545,17 +547,25 @@ export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhisper
545547
return service
546548
}
547549

548-
private streamingClientFactory(region: string, endpoint: string): StreamingClientService {
549-
const streamingClient = new StreamingClientService(
550-
this.features.credentialsProvider,
551-
this.features.sdkInitializator,
552-
this.features.logging,
553-
region,
554-
endpoint,
555-
this.getCustomUserAgent()
556-
)
557-
558-
if (this.features.credentialsProvider.hasCredentials('bearer')) {
550+
private streamingClientFactory(region: string, endpoint: string): StreamingClientServiceBase {
551+
let streamingClient: StreamingClientServiceBase
552+
if (this.features.credentialsProvider.hasCredentials('iam')) {
553+
streamingClient = new StreamingClientServiceIAM(
554+
this.features.credentialsProvider,
555+
this.features.sdkInitializator,
556+
this.features.logging,
557+
region,
558+
endpoint
559+
)
560+
} else {
561+
streamingClient = new StreamingClientServiceToken(
562+
this.features.credentialsProvider,
563+
this.features.sdkInitializator,
564+
this.features.logging,
565+
region,
566+
endpoint,
567+
this.getCustomUserAgent()
568+
)
559569
streamingClient.profileArn = this.activeIdcProfile?.arn
560570
}
561571

@@ -598,7 +608,7 @@ export class AmazonQServiceManager extends BaseAmazonQServiceManager<CodeWhisper
598608
return this.activeIdcProfile?.arn
599609
}
600610

601-
public setServiceFactory(factory: (region: string, endpoint: string) => CodeWhispererService) {
611+
public setServiceFactory(factory: (region: string, endpoint: string) => CodeWhispererServiceToken) {
602612
this.serviceFactory = factory.bind(this)
603613
}
604614

0 commit comments

Comments
 (0)