diff --git a/api_version.lock b/api_version.lock index 627ab5f24..7728a9f96 100644 --- a/api_version.lock +++ b/api_version.lock @@ -1 +1 @@ -v0.1.483 +v0.1.484 diff --git a/examples/teleop-react/.gitignore b/examples/teleop-react/.gitignore new file mode 100644 index 000000000..2bd180399 --- /dev/null +++ b/examples/teleop-react/.gitignore @@ -0,0 +1,26 @@ +.env + +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/teleop-react/src/components/connect-form.tsx b/examples/teleop-react/src/components/connect-form.tsx index 699b9aeff..82fd65cf9 100644 --- a/examples/teleop-react/src/components/connect-form.tsx +++ b/examples/teleop-react/src/components/connect-form.tsx @@ -52,7 +52,7 @@ export const ConnectForm = (props: ConnectFormProps): JSX.Element => { setApiKeyId(event.target.value); }; const handleApiKey: ChangeEventHandler = (event) => { - setApiKeyId(event.target.value); + setApiKey(event.target.value); }; const handleSubmit: FormEventHandler = (event) => { onSubmit({ hostname, apiKeyId, apiKey }); diff --git a/package-lock.json b/package-lock.json index 31eb98cbc..11b9a307d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,7 +17,7 @@ }, "devDependencies": { "@bufbuild/buf": "^1.15.0-1", - "@playwright/test": "1.45.3", + "@playwright/test": "1.56.1", "@types/node": "^20.11.10", "@typescript-eslint/eslint-plugin": "^6.17.0", "@typescript-eslint/parser": "^6.17.0", @@ -31,7 +31,7 @@ "eslint-plugin-tsdoc": "^0.2.17", "eslint-plugin-vitest": "^0.3.20", "grpc-web": "^1.4.2", - "happy-dom": "^15.10.1", + "happy-dom": "^20.0.8", "npm-check-updates": "^17.1.11", "prettier": "^3.1.1", "prettier-plugin-jsdoc": "^1.1.1", @@ -883,13 +883,13 @@ } }, "node_modules/@playwright/test": { - "version": "1.45.3", - "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.45.3.tgz", - "integrity": "sha512-UKF4XsBfy+u3MFWEH44hva1Q8Da28G6RFtR2+5saw+jgAFQV5yYnB1fu68Mz7fO+5GJF3wgwAIs0UelU8TxFrA==", + "version": "1.56.1", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.56.1.tgz", + "integrity": "sha512-vSMYtL/zOcFpvJCW71Q/OEGQb7KYBPAdKh35WNSkaZA75JlAO8ED8UN6GUNTm3drWomcbcqRPFqQbLae8yBTdg==", "dev": true, "license": "Apache-2.0", "dependencies": { - "playwright": "1.45.3" + "playwright": "1.56.1" }, "bin": { "playwright": "cli.js" @@ -1283,6 +1283,13 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/whatwg-mimetype": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/whatwg-mimetype/-/whatwg-mimetype-3.0.2.tgz", + "integrity": "sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==", + "dev": true, + "license": "MIT" + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "6.21.0", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-6.21.0.tgz", @@ -3184,18 +3191,18 @@ "license": "Apache-2.0" }, "node_modules/happy-dom": { - "version": "15.11.7", - "resolved": "https://registry.npmjs.org/happy-dom/-/happy-dom-15.11.7.tgz", - "integrity": "sha512-KyrFvnl+J9US63TEzwoiJOQzZBJY7KgBushJA8X61DMbNsH+2ONkDuLDnCnwUiPTF42tLoEmrPyoqbenVA5zrg==", + "version": "20.0.8", + "resolved": "https://registry.npmjs.org/happy-dom/-/happy-dom-20.0.8.tgz", + "integrity": "sha512-TlYaNQNtzsZ97rNMBAm8U+e2cUQXNithgfCizkDgc11lgmN4j9CKMhO3FPGKWQYPwwkFcPpoXYF/CqEPLgzfOg==", "dev": true, "license": "MIT", "dependencies": { - "entities": "^4.5.0", - "webidl-conversions": "^7.0.0", + "@types/node": "^20.0.0", + "@types/whatwg-mimetype": "^3.0.2", "whatwg-mimetype": "^3.0.0" }, "engines": { - "node": ">=18.0.0" + "node": ">=20.0.0" } }, "node_modules/has-flag": { @@ -4576,13 +4583,13 @@ } }, "node_modules/playwright": { - "version": "1.45.3", - "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.45.3.tgz", - "integrity": "sha512-QhVaS+lpluxCaioejDZ95l4Y4jSFCsBvl2UZkpeXlzxmqS+aABr5c82YmfMHrL6x27nvrvykJAFpkzT2eWdJww==", + "version": "1.56.1", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.56.1.tgz", + "integrity": "sha512-aFi5B0WovBHTEvpM3DzXTUaeN6eN0qWnTkKx4NQaH4Wvcmc153PdaY2UBdSYKaGYw+UyWXSVyxDUg5DoPEttjw==", "dev": true, "license": "Apache-2.0", "dependencies": { - "playwright-core": "1.45.3" + "playwright-core": "1.56.1" }, "bin": { "playwright": "cli.js" @@ -4595,9 +4602,9 @@ } }, "node_modules/playwright-core": { - "version": "1.45.3", - "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.45.3.tgz", - "integrity": "sha512-+ym0jNbcjikaOwwSZycFbwkWgfruWvYlJfThKYAlImbxUgdWFO2oW70ojPm4OpE4t6TAo2FY/smM+hpVTtkhDA==", + "version": "1.56.1", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.56.1.tgz", + "integrity": "sha512-hutraynyn31F+Bifme+Ps9Vq59hKuUCz7H1kDOcBs+2oGguKkWTU50bBWrtz34OUWmIwpBTWDxaRPXrIXkgvmQ==", "dev": true, "license": "Apache-2.0", "bin": { @@ -5673,9 +5680,9 @@ } }, "node_modules/vite": { - "version": "5.4.20", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.20.tgz", - "integrity": "sha512-j3lYzGC3P+B5Yfy/pfKNgVEg4+UtcIJcVRt2cDjIOmhLourAqPqf8P7acgxeiSgUB7E3p2P8/3gNIgDLpwzs4g==", + "version": "5.4.21", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.21.tgz", + "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", "dependencies": { @@ -5836,16 +5843,6 @@ } } }, - "node_modules/webidl-conversions": { - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-7.0.0.tgz", - "integrity": "sha512-VwddBukDzu71offAQR975unBIGqfKZpM+8ZX6ySk8nYhVoo5CYaZyzt3YBvYtRtO+aoGlqxPg/B87NGVZ/fu6g==", - "dev": true, - "license": "BSD-2-Clause", - "engines": { - "node": ">=12" - } - }, "node_modules/whatwg-mimetype": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-3.0.0.tgz", diff --git a/package.json b/package.json index 4335b7521..d66460c4d 100644 --- a/package.json +++ b/package.json @@ -60,7 +60,7 @@ }, "devDependencies": { "@bufbuild/buf": "^1.15.0-1", - "@playwright/test": "1.45.3", + "@playwright/test": "1.56.1", "@types/node": "^20.11.10", "@typescript-eslint/eslint-plugin": "^6.17.0", "@typescript-eslint/parser": "^6.17.0", @@ -74,7 +74,7 @@ "eslint-plugin-tsdoc": "^0.2.17", "eslint-plugin-vitest": "^0.3.20", "grpc-web": "^1.4.2", - "happy-dom": "^15.10.1", + "happy-dom": "^20.0.8", "npm-check-updates": "^17.1.11", "prettier": "^3.1.1", "prettier-plugin-jsdoc": "^1.1.1", diff --git a/src/app/viam-client.spec.ts b/src/app/viam-client.spec.ts index d3aae6eed..21bd9f506 100644 --- a/src/app/viam-client.spec.ts +++ b/src/app/viam-client.spec.ts @@ -1,7 +1,10 @@ // @vitest-environment happy-dom import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { Location, RobotPart, SharedSecret_State } from '../gen/app/v1/app_pb'; +import { + GetRobotPartByNameAndLocationResponse, + RobotPart, +} from '../gen/app/v1/app_pb'; import { createRobotClient } from '../robot/dial'; import { AppClient } from './app-client'; import { BillingClient } from './billing-client'; @@ -156,98 +159,64 @@ describe('ViamClient', () => { ).rejects.toThrowError('not provided and could not be obtained'); }); - it('gets location secret if credential is access token -- host', async () => { + it('gets robot secret if credential is access token -- host', async () => { options = { credentials: testAccessToken }; const client = await subject(); - const location = new Location({ - auth: { - secrets: [ - { - id: '0', - state: SharedSecret_State.DISABLED, // eslint-disable-line camelcase - secret: 'disabled secret', - }, - { - id: '1', - state: SharedSecret_State.UNSPECIFIED, // eslint-disable-line camelcase - secret: 'unspecified secret', - }, - { - id: '2', - state: SharedSecret_State.ENABLED, // eslint-disable-line camelcase - secret: 'enabled secret', - }, - ], - locationId: 'location', - secret: 'secret', - }, + const MAIN_PART = new RobotPart({ + mainPart: true, + name: 'main-part', + secret: 'fake-robot-secret', }); - const getLocationMock = vi.fn().mockImplementation(() => location); - AppClient.prototype.getLocation = getLocationMock; + const partByNameAndLocationResponse = + new GetRobotPartByNameAndLocationResponse({ + part: MAIN_PART, + }); + const getRobotPartByNameAndLocationMock = vi + .fn() + .mockImplementation(() => partByNameAndLocationResponse); + AppClient.prototype.getRobotPartByNameAndLocation = + getRobotPartByNameAndLocationMock; await client.connectToMachine({ host: 'main-part.location.viam.cloud', }); - expect(getLocationMock).toHaveBeenCalledWith('location'); + expect(getRobotPartByNameAndLocationMock).toHaveBeenCalledWith( + 'main-part', + 'location' + ); expect(createRobotClient).toHaveBeenCalledWith( expect.objectContaining({ credentials: expect.objectContaining({ - type: 'robot-location-secret', - payload: 'enabled secret', + type: 'robot-secret', + payload: 'fake-robot-secret', }), }) ); }); - it('gets location secret if credential is access token -- id', async () => { + it('gets robot secret if credential is access token -- id', async () => { options = { credentials: testAccessToken }; const client = await subject(); const MAIN_PART = new RobotPart({ mainPart: true, - locationId: 'location-id', fqdn: 'main-part.fqdn', + secret: 'fake-robot-secret', }); const robotParts = [MAIN_PART]; const getRobotPartsMock = vi.fn().mockImplementation(() => robotParts); AppClient.prototype.getRobotParts = getRobotPartsMock; - const location = new Location({ - auth: { - secrets: [ - { - id: '0', - state: SharedSecret_State.DISABLED, // eslint-disable-line camelcase - secret: 'disabled secret', - }, - { - id: '1', - state: SharedSecret_State.UNSPECIFIED, // eslint-disable-line camelcase - secret: 'unspecified secret', - }, - { - id: '2', - state: SharedSecret_State.ENABLED, // eslint-disable-line camelcase - secret: 'enabled secret', - }, - ], - locationId: 'location', - secret: 'secret', - }, - }); - const getLocationMock = vi.fn().mockImplementation(() => location); - AppClient.prototype.getLocation = getLocationMock; - await client.connectToMachine({ id: 'machine-uuid', }); - expect(getLocationMock).toHaveBeenCalledWith('location-id'); + expect(getRobotPartsMock).toHaveBeenCalledWith('machine-uuid'); expect(createRobotClient).toHaveBeenCalledWith( expect.objectContaining({ credentials: expect.objectContaining({ - type: 'robot-location-secret', - payload: 'enabled secret', + type: 'robot-secret', + payload: 'fake-robot-secret', }), }) ); diff --git a/src/app/viam-client.ts b/src/app/viam-client.ts index 1b4a9ae4d..581b4c184 100644 --- a/src/app/viam-client.ts +++ b/src/app/viam-client.ts @@ -1,5 +1,4 @@ import type { Transport } from '@connectrpc/connect'; -import { SharedSecret_State } from '../gen/app/v1/app_pb'; import { createRobotClient } from '../robot/dial'; import { AppClient } from './app-client'; import { BillingClient } from './billing-client'; @@ -54,6 +53,26 @@ export class ViamClient { this.billingClient = new BillingClient(this.transport); } + async getRobotSecretFromHost(host: string): Promise { + const firstHalf = host.split('.viam.'); + const locationSplit = firstHalf[0]?.split('.'); + if (locationSplit !== undefined) { + const locationId = locationSplit.at(-1); + if (locationId === undefined) { + return undefined; + } + const name = host.split('.').at(0); + if (name !== undefined) { + const resp = await this.appClient.getRobotPartByNameAndLocation( + name, + locationId + ); + return resp.part?.secret; + } + } + return undefined; + } + public async connectToMachine({ host = undefined, id = undefined, @@ -62,7 +81,7 @@ export class ViamClient { throw new Error('Either a machine address or ID must be provided'); } let address = host; - let locationId: string | undefined = undefined; + let robotSecret: string | undefined = undefined; // Get address if only ID was provided if (id !== undefined && host === undefined) { @@ -74,7 +93,7 @@ export class ViamClient { ); } address = mainPart.fqdn; - locationId = mainPart.locationId; + robotSecret = mainPart.secret; } if (address === undefined || address === '') { @@ -83,31 +102,20 @@ export class ViamClient { ); } - // If credentials is AccessToken, then attempt to get the robot location secret + // If credentials is AccessToken, then attempt to use the robot part secret let creds = this.credentials; if (!isCredential(creds)) { - if (locationId === undefined) { - // If we don't have a location, try to get it from the address - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const firstHalf = address.split('.viam.'); - const locationSplit = firstHalf[0]?.split('.'); - if (locationSplit !== undefined) { - locationId = locationSplit.at(-1); - } - } - if (locationId !== undefined) { - // If we found the location, then attempt to get its secret - const location = await this.appClient.getLocation(locationId); - const secret = location?.auth?.secrets.find( - // eslint-disable-next-line camelcase - (sec) => sec.state === SharedSecret_State.ENABLED - ); - creds = { - type: 'robot-location-secret', - payload: secret?.secret, - authEntity: address, - } as Credential; + if (robotSecret === undefined) { + robotSecret = await this.getRobotSecretFromHost(address); } + creds = + robotSecret === undefined + ? creds + : ({ + type: 'robot-secret', + payload: robotSecret, + authEntity: address, + } as Credential); } return createRobotClient({ diff --git a/src/app/viam-transport.ts b/src/app/viam-transport.ts index 725b94c79..e86bd7d83 100644 --- a/src/app/viam-transport.ts +++ b/src/app/viam-transport.ts @@ -16,10 +16,7 @@ export interface Credential { payload: string; } -export type CredentialType = - | 'robot-location-secret' - | 'api-key' - | 'robot-secret'; +export type CredentialType = 'api-key' | 'robot-secret'; /** An access token used to access protected resources. */ export interface AccessToken { diff --git a/src/components/arm/arm.ts b/src/components/arm/arm.ts index f6f2682b9..b3e54f710 100644 --- a/src/components/arm/arm.ts +++ b/src/components/arm/arm.ts @@ -1,8 +1,9 @@ import type { PlainMessage, Struct } from '@bufbuild/protobuf'; -import type { Pose, Resource } from '../../types'; +import type { Pose, Resource, Vector3 } from '../../types'; import * as armApi from '../../gen/component/arm/v1/arm_pb'; import type { Geometry } from '../../gen/common/v1/common_pb'; +import type { Frame } from '../../gen/app/v1/robot_pb'; export type ArmJointPositions = PlainMessage; @@ -41,6 +42,34 @@ export interface Arm extends Resource { */ getGeometries: (extra?: Struct) => Promise; + /** + * Get the kinematics information associated with the arm. + * + * @example + * + * ```ts + * const arm = new VIAM.ArmClient(machine, 'my_arm'); + * const kinematics = await arm.getKinematics(); + * console.log(kinematics); + * + * For more information, see [Arm + * API](https://docs.viam.com/dev/reference/apis/components/arm/#getkinematics). + * ``` + */ + getKinematics: (extra?: Struct) => Promise<{ + name: string; + kinematic_param_type: 'SVA' | 'URDF' | 'UNSPECIFIED'; + joints: { + id: string; + type: string; + parent: string; + axis: Vector3; + max: number; + min: number; + }[]; + links: Frame[]; + }>; + /** * Move the end of the arm to the pose. * diff --git a/src/components/arm/client.ts b/src/components/arm/client.ts index be2651067..ebf3872e7 100644 --- a/src/components/arm/client.ts +++ b/src/components/arm/client.ts @@ -14,7 +14,10 @@ import type { RobotClient } from '../../robot'; import type { Options, Pose } from '../../types'; import { doCommandFromClient } from '../../utils'; import type { Arm } from './arm'; -import { GetGeometriesRequest } from '../../gen/common/v1/common_pb'; +import { + GetGeometriesRequest, + GetKinematicsRequest, +} from '../../gen/common/v1/common_pb'; /** * A gRPC-web client for the Arm component. @@ -59,6 +62,20 @@ export class ArmClient implements Arm { return response.geometries; } + async getKinematics(extra = {}, callOptions = this.callOptions) { + const request = new GetKinematicsRequest({ + name: this.name, + extra: Struct.fromJson(extra), + }); + + const response = await this.client.getKinematics(request, callOptions); + + const decoder = new TextDecoder('utf8'); + const jsonString = decoder.decode(response.kinematicsData); + + return JSON.parse(jsonString) as ReturnType; + } + async moveToPosition(pose: Pose, extra = {}, callOptions = this.callOptions) { const request = new MoveToPositionRequest({ name: this.name, diff --git a/src/robot/client.spec.ts b/src/robot/client.spec.ts new file mode 100644 index 000000000..4ed02991e --- /dev/null +++ b/src/robot/client.spec.ts @@ -0,0 +1,393 @@ +// @vitest-environment happy-dom + +import { + beforeEach, + afterEach, + describe, + expect, + it, + vi, + type MockInstance, +} from 'vitest'; +import type { Transport } from '@connectrpc/connect'; +import { createRouterTransport } from '@connectrpc/connect'; +import { RobotService } from '../gen/robot/v1/robot_connect'; +import { RobotClient } from './client'; +import * as rpcModule from '../rpc'; + +vi.mock('../rpc', async () => { + const actual = await vi.importActual('../rpc'); + return { + ...actual, + dialWebRTC: vi.fn(), + dialDirect: vi.fn(), + }; +}); + +describe('RobotClient', () => { + let mockTransport: Transport; + let mockPeerConnection: RTCPeerConnection; + let mockDataChannel: RTCDataChannel; + let client: RobotClient; + + beforeEach(() => { + mockTransport = createRouterTransport(({ service }) => { + service(RobotService, { + resourceNames: () => ({ resources: [] }), + getOperations: () => ({ operations: [] }), + }); + }); + + mockPeerConnection = { + close: vi.fn(), + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + iceConnectionState: 'connected', + } as unknown as RTCPeerConnection; + + mockDataChannel = { + close: vi.fn(), + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + readyState: 'open', + } as unknown as RTCDataChannel; + + vi.mocked(rpcModule.dialWebRTC).mockResolvedValue({ + transport: mockTransport, + peerConnection: mockPeerConnection, + dataChannel: mockDataChannel, + }); + + client = new RobotClient(); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + describe('event listeners', () => { + let pcAddEventListenerSpy: ReturnType; + let pcRemoveEventListenerSpy: ReturnType; + + let dcAddEventListenerSpy: ReturnType; + let dcRemoveEventListenerSpy: ReturnType; + + beforeEach(() => { + pcAddEventListenerSpy = vi.fn(); + pcRemoveEventListenerSpy = vi.fn(); + dcAddEventListenerSpy = vi.fn(); + dcRemoveEventListenerSpy = vi.fn(); + + mockPeerConnection = { + close: vi.fn(), + addEventListener: pcAddEventListenerSpy, + removeEventListener: pcRemoveEventListenerSpy, + iceConnectionState: 'connected', + } as unknown as RTCPeerConnection; + + mockDataChannel = { + close: vi.fn(), + addEventListener: dcAddEventListenerSpy, + removeEventListener: dcRemoveEventListenerSpy, + readyState: 'open', + } as unknown as RTCDataChannel; + + vi.mocked(rpcModule.dialWebRTC).mockResolvedValue({ + transport: mockTransport, + peerConnection: mockPeerConnection, + dataChannel: mockDataChannel, + }); + }); + + it.each([ + { + eventType: 'iceconnectionstatechange', + addSpy: () => pcAddEventListenerSpy, + removeSpy: () => pcRemoveEventListenerSpy, + description: 'peer connection iceconnectionstatechange', + }, + { + eventType: 'close', + addSpy: () => dcAddEventListenerSpy, + removeSpy: () => dcRemoveEventListenerSpy, + description: 'data channel close', + }, + { + eventType: 'track', + addSpy: () => pcAddEventListenerSpy, + removeSpy: () => pcRemoveEventListenerSpy, + description: 'peer connection track', + }, + ])( + 'should remove old $description handler before adding new one', + async ({ eventType, addSpy, removeSpy }) => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + disableSessions: true, + noReconnect: true, + }); + + const firstCallArgs = addSpy().mock.calls.find( + (call) => call[0] === eventType + ); + + expect(firstCallArgs).toBeDefined(); + + const firstHandler = firstCallArgs?.[1]; + + addSpy().mockClear(); + removeSpy().mockClear(); + + // simulate reconnection + await client.connect(); + + const removeCallArgs = removeSpy().mock.calls.find( + (call) => call[0] === eventType + ); + + const secondCallArgs = addSpy().mock.calls.find( + (call) => call[0] === eventType + ); + + expect(removeCallArgs).toBeDefined(); + expect(removeCallArgs?.[1]).toBe(firstHandler); + expect(secondCallArgs).toBeDefined(); + } + ); + + it.each([ + { + eventType: 'iceconnectionstatechange', + addSpy: () => pcAddEventListenerSpy, + removeSpy: () => pcRemoveEventListenerSpy, + description: 'iceconnectionstatechange', + }, + { + eventType: 'close', + addSpy: () => dcAddEventListenerSpy, + removeSpy: () => dcRemoveEventListenerSpy, + description: 'data channel close', + }, + { + eventType: 'track', + addSpy: () => pcAddEventListenerSpy, + removeSpy: () => pcRemoveEventListenerSpy, + description: 'track', + }, + ])( + 'should only have one $description handler at a time', + async ({ eventType, addSpy, removeSpy }) => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + disableSessions: true, + noReconnect: true, + }); + + const firstConnectionCalls = addSpy().mock.calls.filter( + (call) => call[0] === eventType + ); + + expect(firstConnectionCalls).toHaveLength(1); + + // simulate reconnection + await client.connect(); + + const totalCalls = addSpy().mock.calls.filter( + (call) => call[0] === eventType + ); + const removeCalls = removeSpy().mock.calls.filter( + (call) => call[0] === eventType + ); + + expect(totalCalls).toHaveLength(2); + expect(removeCalls).toHaveLength(1); + } + ); + + it('should not accumulate handlers over multiple reconnections', async () => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + disableSessions: true, + noReconnect: true, + }); + + for (let i = 0; i < 5; i += 1) { + // eslint-disable-next-line no-await-in-loop + await client.connect(); + } + + const iceAddCalls = pcAddEventListenerSpy.mock.calls.filter( + (call) => call[0] === 'iceconnectionstatechange' + ); + const iceRemoveCalls = pcRemoveEventListenerSpy.mock.calls.filter( + (call) => call[0] === 'iceconnectionstatechange' + ); + + expect(iceAddCalls).toHaveLength(6); + expect(iceRemoveCalls).toHaveLength(5); + expect(iceAddCalls.length - iceRemoveCalls.length).toBe(1); + }); + + it('should clean up all event handlers when disconnecting', async () => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + disableSessions: true, + noReconnect: true, + }); + + pcRemoveEventListenerSpy.mockClear(); + dcRemoveEventListenerSpy.mockClear(); + + await client.disconnect(); + + const iceRemoveCalls = pcRemoveEventListenerSpy.mock.calls.filter( + (call) => call[0] === 'iceconnectionstatechange' + ); + const trackRemoveCalls = pcRemoveEventListenerSpy.mock.calls.filter( + (call) => call[0] === 'track' + ); + + const dcRemoveCalls = dcRemoveEventListenerSpy.mock.calls.filter( + (call) => call[0] === 'close' + ); + + expect(iceRemoveCalls.length).toBeGreaterThanOrEqual(1); + expect(trackRemoveCalls.length).toBeGreaterThanOrEqual(1); + expect(dcRemoveCalls.length).toBeGreaterThanOrEqual(1); + }); + }); + + describe('session management on reconnection', () => { + let mockResetFn: MockInstance<[], void>; + + const testCredential = { + authEntity: 'test-entity', + type: 'api-key' as const, + payload: 'test-payload', + }; + + const differentCredential = { + authEntity: 'different-entity', + type: 'api-key' as const, + payload: 'different-payload', + }; + + const accessToken = { + type: 'access-token' as const, + payload: 'test-access-token', + }; + + const differentAccessToken = { + type: 'access-token' as const, + payload: 'different-access-token', + }; + + beforeEach(() => { + // Spy on the SessionManager's reset method to verify conditional reset behavior + // eslint-disable-next-line vitest/no-restricted-vi-methods, @typescript-eslint/dot-notation + mockResetFn = vi.spyOn(client['sessionManager'], 'reset'); + }); + + afterEach(() => { + mockResetFn.mockRestore(); + }); + + it('should reset session when connecting for the first time', async () => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + credentials: testCredential, + disableSessions: false, + noReconnect: true, + }); + + expect(mockResetFn).toHaveBeenCalledTimes(1); + }); + + it.each([ + { + description: + 'should reset session when credentials change during reconnection', + initialCreds: testCredential, + disableSessions: false, + reconnectCreds: differentCredential, + }, + { + description: 'should reset session when sessions are disabled', + initialCreds: testCredential, + disableSessions: true, + reconnectCreds: testCredential, + }, + { + description: + 'should reset session when reconnecting with no saved credentials', + initialCreds: undefined, + disableSessions: false, + reconnectCreds: undefined, + }, + { + description: + 'should reset session when access token changes during reconnection', + initialCreds: accessToken, + disableSessions: false, + reconnectCreds: differentAccessToken, + }, + ])( + '$description', + async ({ initialCreds, disableSessions, reconnectCreds }) => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + credentials: initialCreds, + disableSessions, + noReconnect: true, + }); + + mockResetFn.mockClear(); + + await client.connect({ creds: reconnectCreds }); + + expect(mockResetFn).toHaveBeenCalledTimes(1); + } + ); + + it.each([ + { + description: + 'should NOT reset session when reconnecting with same credentials', + initialCreds: testCredential, + reconnectCreds: testCredential, + }, + { + description: + 'should NOT reset session when reconnecting without explicitly passing creds (uses savedCreds)', + initialCreds: testCredential, + reconnectCreds: undefined, + }, + { + description: + 'should NOT reset session when using access token and reconnecting with same token', + initialCreds: accessToken, + reconnectCreds: accessToken, + }, + ])('$description', async ({ initialCreds, reconnectCreds }) => { + await client.dial({ + host: 'test-host', + signalingAddress: 'https://test.local', + credentials: initialCreds, + disableSessions: false, + noReconnect: true, + }); + + mockResetFn.mockClear(); + + await client.connect({ creds: reconnectCreds }); + + expect(mockResetFn).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/src/robot/client.ts b/src/robot/client.ts index 9e9246251..97ed36656 100644 --- a/src/robot/client.ts +++ b/src/robot/client.ts @@ -184,6 +184,7 @@ export class RobotClient extends EventDispatcher implements Robot { private sessionManager: SessionManager; private peerConn: RTCPeerConnection | undefined; + private dataChannel: RTCDataChannel | undefined; private transport: Transport | undefined; @@ -245,6 +246,10 @@ export class RobotClient extends EventDispatcher implements Robot { private currentRetryAttempt = 0; + private onICEConnectionStateChange?: () => void; + private onDataChannelClose?: (event: Event) => void; + private onTrack?: (event: RTCTrackEvent) => void; + constructor( serviceHost?: string, webrtcOptions?: WebRTCOptions, @@ -306,6 +311,27 @@ export class RobotClient extends EventDispatcher implements Robot { this.closed = false; } + private cleanupEventListeners() { + if (this.peerConn && this.onICEConnectionStateChange) { + this.peerConn.removeEventListener( + 'iceconnectionstatechange', + this.onICEConnectionStateChange + ); + + this.onICEConnectionStateChange = undefined; + } + + if (this.peerConn && this.onTrack) { + this.peerConn.removeEventListener('track', this.onTrack); + this.onTrack = undefined; + } + + if (this.dataChannel && this.onDataChannelClose) { + this.dataChannel.removeEventListener('close', this.onDataChannelClose); + this.onDataChannelClose = undefined; + } + } + private onDisconnect(event?: Event) { this.emit(MachineConnectionEvent.DISCONNECTED, event ?? {}); @@ -638,10 +664,18 @@ export class RobotClient extends EventDispatcher implements Robot { await this.connecting; } + this.cleanupEventListeners(); + if (this.peerConn) { this.peerConn.close(); this.peerConn = undefined; } + + if (this.dataChannel) { + this.dataChannel.close(); + this.dataChannel = undefined; + } + this.sessionManager.reset(); this.closed = true; this.emit(MachineConnectionEvent.DISCONNECTED, {}); @@ -680,11 +714,18 @@ export class RobotClient extends EventDispatcher implements Robot { this.peerConn = undefined; } + if (this.dataChannel) { + this.dataChannel.close(); + this.dataChannel = undefined; + } + /* - * TODO(RSDK-887): no longer reset if we are reusing authentication material; otherwise our session - * and authentication context will no longer match. + * Only reset session if credentials have changed or if explicitly required; + * otherwise our session and authentication context will no longer match. */ - this.sessionManager.reset(); + if (!creds || creds !== this.savedCreds || this.sessionOptions.disabled) { + this.sessionManager.reset(); + } try { const opts: DialOptions = { @@ -727,12 +768,12 @@ export class RobotClient extends EventDispatcher implements Robot { this.serviceHost !== '' && signalingAddress !== this.serviceHost ); - /* - * Lint disabled because we know that we are the only code to - * read and then write to 'peerConn', even after we have awaited/paused. - */ - this.peerConn = webRTCConn.peerConnection; // eslint-disable-line require-atomic-updates - this.peerConn.addEventListener('iceconnectionstatechange', () => { + this.peerConn = webRTCConn.peerConnection; + this.dataChannel = webRTCConn.dataChannel; + + this.cleanupEventListeners(); + + this.onICEConnectionStateChange = () => { /* * TODO: are there any disconnection scenarios where we can reuse the * same connection and restart ice? @@ -746,17 +787,22 @@ export class RobotClient extends EventDispatcher implements Robot { } else if (this.peerConn?.iceConnectionState === 'closed') { this.onDisconnect(); } - }); + }; + + this.peerConn.addEventListener( + 'iceconnectionstatechange', + this.onICEConnectionStateChange + ); + // There is not an iceconnectionstatechange nor connectionstatechange // event when the peerConn closes. Instead, listen to the data channel // closing and emit disconnect when that occurs. - webRTCConn.dataChannel.addEventListener('close', (event) => { - this.onDisconnect(event); - }); + this.onDataChannelClose = (event: Event) => this.onDisconnect(event); + this.dataChannel.addEventListener('close', this.onDataChannelClose); this.transport = webRTCConn.transport; - webRTCConn.peerConnection.addEventListener('track', (event) => { + this.onTrack = (event: RTCTrackEvent) => { const [eventStream] = event.streams; if (!eventStream) { this.emit('track', event); @@ -773,7 +819,9 @@ export class RobotClient extends EventDispatcher implements Robot { value: resName, }); this.emit('track', event); - }); + }; + + this.peerConn.addEventListener('track', this.onTrack); } else { this.transport = await dialDirect(this.serviceHost, opts); await this.gRPCConnectionManager.start(); @@ -795,12 +843,7 @@ export class RobotClient extends EventDispatcher implements Robot { } finally { this.connectResolve?.(); this.connectResolve = undefined; - - /* - * Lint disabled because we know that we are the only code to - * read and then write to 'connecting', even after we have awaited/paused. - */ - this.connecting = undefined; // eslint-disable-line require-atomic-updates + this.connecting = undefined; } } diff --git a/src/rpc/base-channel.ts b/src/rpc/base-channel.ts index dec7aa9f8..58005f7f6 100644 --- a/src/rpc/base-channel.ts +++ b/src/rpc/base-channel.ts @@ -56,6 +56,7 @@ export class BaseChannel { this.closedReason = err; this.pReject?.(err); this.peerConn.close(); + this.dataChannel.close(); } private onChannelOpen() { diff --git a/src/rpc/dial.spec.ts b/src/rpc/dial.spec.ts index 20955a51e..e2b9ff243 100644 --- a/src/rpc/dial.spec.ts +++ b/src/rpc/dial.spec.ts @@ -37,6 +37,7 @@ describe('dialWebRTC', () => { let mockCreateClient: ReturnType; let mockPeerConnectionClose: ReturnType; + let mockDataChannelClose: ReturnType; let mockExchangeDoExchange: ReturnType; let mockExchangeTerminate: ReturnType; @@ -52,9 +53,11 @@ describe('dialWebRTC', () => { removeEventListener: removeEventListenerFn, } as unknown as RTCPeerConnection; + mockDataChannelClose = vi.fn(); const dcAddEventListenerFn = vi.fn(); const dcRemoveEventListenerFn = vi.fn(); mockDataChannel = { + close: mockDataChannelClose, addEventListener: dcAddEventListenerFn, removeEventListener: dcRemoveEventListenerFn, } as unknown as RTCDataChannel; diff --git a/src/rpc/dial.ts b/src/rpc/dial.ts index fa076f31e..d1f0bd2be 100644 --- a/src/rpc/dial.ts +++ b/src/rpc/dial.ts @@ -394,6 +394,7 @@ export const dialWebRTC = async ( ); } catch (error) { pc.close(); + dc.close(); throw error; } @@ -446,6 +447,7 @@ export const dialWebRTC = async ( if (!successful) { pc.close(); + dc.close(); } } };