diff --git a/README.md b/README.md index 6c166cc4..e095099b 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ A [Trino](https://trino.io) client for [Node.js](https://nodejs.org/). ## Features - Connections over HTTP or HTTPS -- Supports HTTP Basic Authentication +- Supports Basic and OAuth2 authentication types - Per-query user information for access control ## Requirements @@ -67,6 +67,17 @@ const data: QueryData[] = await iter .fold([], (row, acc) => [...acc, ...row]); ``` +### Using OAuth2 Authentication + +```typescript +const trino: Trino = Trino.create({ + server: 'http://localhost:8080', + catalog: 'tpcds', + schema: 'sf100000', + auth: new OAuth2Auth('token', 'clientId', 'clientSecret', 'refreshToken', 'tokenEndpoint'), +}); +``` + ## Examples More usage examples can be found in the diff --git a/src/index.ts b/src/index.ts index 1928aaf3..6babf5b4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -35,6 +35,22 @@ export class BasicAuth implements Auth { constructor(readonly username: string, readonly password?: string) {} } +export class OAuth2Auth implements Auth { + readonly type: AuthType = 'oauth2'; + constructor( + readonly token: string, + readonly clientId?: string, + readonly clientSecret?: string, + readonly refreshToken?: string, + readonly tokenEndpoint?: string, + readonly scopes?: string[], + readonly tokenType?: string, + readonly expiresIn?: number, + readonly redirectUri?: string, + readonly grantType?: string + ) {} +} + export type Session = {[key: string]: string}; export type ExtraCredential = {[key: string]: string}; @@ -143,16 +159,6 @@ export type QueryInfo = { failureInfo?: QueryFailureInfo; }; -export type Query = { - query: string; - catalog?: string; - schema?: string; - user?: string; - session?: Session; - extraCredential?: ExtraCredential; - extraHeaders?: RequestHeaders; -}; - /** * It takes a Headers object and returns a new object with the same keys, but only the values that are * truthy @@ -196,14 +202,50 @@ class Client { ...(options.extraHeaders ?? {}), }; - if (options.auth && options.auth.type === 'basic') { - const basic: BasicAuth = options.auth; - clientConfig.auth = { - username: basic.username, - password: basic.password ?? '', - }; - - headers[TRINO_USER_HEADER] = basic.username; + if (options.auth) { + switch (options.auth.type) { + case 'basic': + const basic: BasicAuth = options.auth; + clientConfig.auth = { + username: basic.username, + password: basic.password ?? '', + }; + headers[TRINO_USER_HEADER] = basic.username; + break; + case 'oauth2': + const oauth2: OAuth2Auth = options.auth; + headers['Authorization'] = `Bearer ${oauth2.token}`; + if (oauth2.clientId) { + headers['Client-Id'] = oauth2.clientId; + } + if (oauth2.clientSecret) { + headers['Client-Secret'] = oauth2.clientSecret; + } + if (oauth2.refreshToken) { + headers['Refresh-Token'] = oauth2.refreshToken; + } + if (oauth2.tokenEndpoint) { + headers['Token-Endpoint'] = oauth2.tokenEndpoint; + } + if (oauth2.scopes) { + headers['Scopes'] = oauth2.scopes.join(' '); + } + if (oauth2.tokenType) { + headers['Token-Type'] = oauth2.tokenType; + } + if (oauth2.expiresIn) { + headers['Expires-In'] = oauth2.expiresIn.toString(); + } + if (oauth2.redirectUri) { + headers['Redirect-Uri'] = oauth2.redirectUri; + } + if (oauth2.grantType) { + headers['Grant-Type'] = oauth2.grantType; + } + break; + default: + throw new Error(`Unsupported auth type: ${options.auth.type}`); + } } clientConfig.headers = cleanHeaders(headers); diff --git a/tests/it/client.spec.ts b/tests/it/client.spec.ts index 1510d032..cef7b1c1 100644 --- a/tests/it/client.spec.ts +++ b/tests/it/client.spec.ts @@ -1,4 +1,4 @@ -import {BasicAuth, QueryData, Trino} from '../../src'; +import {BasicAuth, OAuth2Auth, QueryData, Trino} from '../../src'; const allCustomerQuery = 'select * from customer'; const limit = 1; @@ -175,4 +175,19 @@ describe('trino', () => { ]); expect(sales).toHaveLength(limit); }); + + test.concurrent('oauth2 auth', async () => { + const trino = Trino.create({ + catalog: 'tpcds', + schema: 'sf100000', + auth: new OAuth2Auth('token'), + }); + + const iter = await trino.query(singleCustomerQuery); + const data = await iter + .map(r => r.data ?? []) + .fold([], (row, acc) => [...acc, ...row]); + + expect(data).toHaveLength(limit); + }); });