diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index dc38c59721c1..31295387ed25 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -327,16 +327,23 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { const revocationEndpointAuthMethodsSupported = serverConfig.oauth?.revocation_endpoint_auth_methods_supported ?? clientMetadata.revocation_endpoint_auth_methods_supported; + const oauthHeaders = serverConfig.oauth_headers ?? {}; if (tokens?.access_token) { try { - await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', { - serverUrl: serverConfig.url, - clientId: clientInfo.client_id, - clientSecret: clientInfo.client_secret ?? '', - revocationEndpoint, - revocationEndpointAuthMethodsSupported, - }); + await MCPOAuthHandler.revokeOAuthToken( + serverName, + tokens.access_token, + 'access', + { + serverUrl: serverConfig.url, + clientId: clientInfo.client_id, + clientSecret: clientInfo.client_secret ?? '', + revocationEndpoint, + revocationEndpointAuthMethodsSupported, + }, + oauthHeaders, + ); } catch (error) { logger.error(`Error revoking OAuth access token for ${serverName}:`, error); } @@ -344,13 +351,19 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { if (tokens?.refresh_token) { try { - await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', { - serverUrl: serverConfig.url, - clientId: clientInfo.client_id, - clientSecret: clientInfo.client_secret ?? '', - revocationEndpoint, - revocationEndpointAuthMethodsSupported, - }); + await MCPOAuthHandler.revokeOAuthToken( + serverName, + tokens.refresh_token, + 'refresh', + { + serverUrl: serverConfig.url, + clientId: clientInfo.client_id, + clientSecret: clientInfo.client_secret ?? '', + revocationEndpoint, + revocationEndpointAuthMethodsSupported, + }, + oauthHeaders, + ); } catch (error) { logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error); } diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 0df28d7b10c8..64c95c58eead 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -127,8 +127,13 @@ describe('MCP Routes', () => { }), }; + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue({}), + }; + getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', @@ -146,6 +151,7 @@ describe('MCP Routes', () => { 'test-server', 'https://test-server.com', 'test-user-id', + {}, { clientId: 'test-client-id' }, ); }); @@ -314,6 +320,7 @@ describe('MCP Routes', () => { }; const mockMcpManager = { getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), + getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -336,6 +343,7 @@ describe('MCP Routes', () => { 'test-flow-id', 'test-auth-code', mockFlowManager, + {}, ); expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith( expect.objectContaining({ @@ -392,6 +400,11 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); + const mockMcpManager = { + getRawConfig: jest.fn().mockReturnValue({}), + }; + require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ code: 'test-auth-code', state: 'test-flow-id', @@ -427,6 +440,7 @@ describe('MCP Routes', () => { const mockMcpManager = { getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')), + getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -1234,6 +1248,7 @@ describe('MCP Routes', () => { getUserConnection: jest.fn().mockResolvedValue({ fetchTools: jest.fn().mockResolvedValue([]), }), + getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -1281,6 +1296,7 @@ describe('MCP Routes', () => { .fn() .mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]), }), + getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index b1022136e33a..e8415fd801dc 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -65,6 +65,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { serverName, serverUrl, userId, + getOAuthHeaders(serverName), oauthConfig, ); @@ -132,7 +133,12 @@ router.get('/:serverName/oauth/callback', async (req, res) => { }); logger.debug('[MCP OAuth] Completing OAuth flow'); - const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager); + const tokens = await MCPOAuthHandler.completeOAuthFlow( + flowId, + code, + flowManager, + getOAuthHeaders(serverName), + ); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); /** Persist tokens immediately so reconnection uses fresh credentials */ @@ -538,4 +544,10 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { } }); +function getOAuthHeaders(serverName) { + const mcpManager = getMCPManager(); + const serverConfig = mcpManager.getRawConfig(serverName); + return serverConfig?.oauth_headers ?? {}; +} + module.exports = router; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 6785de748f07..5f4447b2bdff 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -142,6 +142,7 @@ export class MCPConnectionFactory { serverName: metadata.serverName, clientInfo: metadata.clientInfo, }, + this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, ); }; @@ -161,6 +162,7 @@ export class MCPConnectionFactory { this.serverName, data.serverUrl || '', this.userId!, + config?.oauth_headers ?? {}, config?.oauth, ); @@ -358,6 +360,7 @@ export class MCPConnectionFactory { this.serverName, serverUrl, this.userId!, + this.serverConfig.oauth_headers ?? {}, this.serverConfig.oauth, ); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index e96d207f294e..9ee05dfb274d 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -255,6 +255,7 @@ describe('MCPConnectionFactory', () => { 'test-server', 'https://api.example.com', 'user123', + {}, undefined, ); expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com'); diff --git a/packages/api/src/mcp/__tests__/handler.test.ts b/packages/api/src/mcp/__tests__/handler.test.ts index 01794fe4db0f..24e8c5ddb4ba 100644 --- a/packages/api/src/mcp/__tests__/handler.test.ts +++ b/packages/api/src/mcp/__tests__/handler.test.ts @@ -1,6 +1,6 @@ import type { MCPOptions } from 'librechat-data-provider'; import type { AuthorizationServerMetadata } from '@modelcontextprotocol/sdk/shared/auth.js'; -import { MCPOAuthHandler } from '~/mcp/oauth'; +import { MCPOAuthFlowMetadata, MCPOAuthHandler, MCPOAuthTokens } from '~/mcp/oauth'; jest.mock('@librechat/data-schemas', () => ({ logger: { @@ -14,18 +14,33 @@ jest.mock('@librechat/data-schemas', () => ({ jest.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({ startAuthorization: jest.fn(), discoverAuthorizationServerMetadata: jest.fn(), + discoverOAuthProtectedResourceMetadata: jest.fn(), + registerClient: jest.fn(), + exchangeAuthorization: jest.fn(), })); import { startAuthorization, discoverAuthorizationServerMetadata, + discoverOAuthProtectedResourceMetadata, + registerClient, + exchangeAuthorization, } from '@modelcontextprotocol/sdk/client/auth.js'; +import { FlowStateManager } from '../../flow/manager'; const mockStartAuthorization = startAuthorization as jest.MockedFunction; const mockDiscoverAuthorizationServerMetadata = discoverAuthorizationServerMetadata as jest.MockedFunction< typeof discoverAuthorizationServerMetadata >; +const mockDiscoverOAuthProtectedResourceMetadata = + discoverOAuthProtectedResourceMetadata as jest.MockedFunction< + typeof discoverOAuthProtectedResourceMetadata + >; +const mockRegisterClient = registerClient as jest.MockedFunction; +const mockExchangeAuthorization = exchangeAuthorization as jest.MockedFunction< + typeof exchangeAuthorization +>; describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { const mockServerName = 'test-server'; @@ -60,6 +75,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { mockServerName, mockServerUrl, mockUserId, + {}, baseConfig, ); @@ -82,7 +98,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { grant_types_supported: ['authorization_code'], }; - await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config); + await MCPOAuthHandler.initiateOAuthFlow( + mockServerName, + mockServerUrl, + mockUserId, + {}, + config, + ); expect(mockStartAuthorization).toHaveBeenCalledWith( mockServerUrl, @@ -100,7 +122,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { token_endpoint_auth_methods_supported: ['client_secret_post'], }; - await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config); + await MCPOAuthHandler.initiateOAuthFlow( + mockServerName, + mockServerUrl, + mockUserId, + {}, + config, + ); expect(mockStartAuthorization).toHaveBeenCalledWith( mockServerUrl, @@ -118,7 +146,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { response_types_supported: ['code', 'token'], }; - await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config); + await MCPOAuthHandler.initiateOAuthFlow( + mockServerName, + mockServerUrl, + mockUserId, + {}, + config, + ); expect(mockStartAuthorization).toHaveBeenCalledWith( mockServerUrl, @@ -136,7 +170,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { code_challenge_methods_supported: ['S256'], }; - await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config); + await MCPOAuthHandler.initiateOAuthFlow( + mockServerName, + mockServerUrl, + mockUserId, + {}, + config, + ); expect(mockStartAuthorization).toHaveBeenCalledWith( mockServerUrl, @@ -157,7 +197,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { code_challenge_methods_supported: ['S256'], }; - await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config); + await MCPOAuthHandler.initiateOAuthFlow( + mockServerName, + mockServerUrl, + mockUserId, + {}, + config, + ); expect(mockStartAuthorization).toHaveBeenCalledWith( mockServerUrl, @@ -181,7 +227,13 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { code_challenge_methods_supported: [], }; - await MCPOAuthHandler.initiateOAuthFlow(mockServerName, mockServerUrl, mockUserId, config); + await MCPOAuthHandler.initiateOAuthFlow( + mockServerName, + mockServerUrl, + mockUserId, + {}, + config, + ); expect(mockStartAuthorization).toHaveBeenCalledWith( mockServerUrl, @@ -251,7 +303,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }), } as Response); - const result = await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata); + const result = await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}); // Verify the call was made without Authorization header expect(mockFetch).toHaveBeenCalledWith( @@ -314,7 +366,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }), } as Response); - await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata); + await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}); const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`; expect(mockFetch).toHaveBeenCalledWith( @@ -363,7 +415,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }), } as Response); - await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata); + await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}); const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`; expect(mockFetch).toHaveBeenCalledWith( @@ -410,7 +462,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }), } as Response); - await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata); + await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}); const expectedAuth = `Basic ${Buffer.from('test-client-id:test-client-secret').toString('base64')}`; expect(mockFetch).toHaveBeenCalledWith( @@ -457,7 +509,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { }), } as Response); - await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata); + await MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}); // Verify the call was made without Authorization header expect(mockFetch).toHaveBeenCalledWith( @@ -498,6 +550,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { await MCPOAuthHandler.refreshOAuthTokens( mockRefreshToken, { serverName: 'test-server' }, + {}, config, ); @@ -539,6 +592,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { await MCPOAuthHandler.refreshOAuthTokens( mockRefreshToken, { serverName: 'test-server' }, + {}, config, ); @@ -575,6 +629,7 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { await MCPOAuthHandler.refreshOAuthTokens( mockRefreshToken, { serverName: 'test-server' }, + {}, config, ); @@ -617,7 +672,9 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { '{"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}', } as Response); - await expect(MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata)).rejects.toThrow( + await expect( + MCPOAuthHandler.refreshOAuthTokens(mockRefreshToken, metadata, {}, {}), + ).rejects.toThrow( 'Token refresh failed: 400 Bad Request - {"error":"invalid_request","error_description":"refresh_token.client_id: Field required"}', ); }); @@ -813,4 +870,126 @@ describe('MCPOAuthHandler - Configurable OAuth Metadata', () => { ); }); }); + + describe('Custom OAuth Headers', () => { + const originalFetch = global.fetch; + const mockFetch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + global.fetch = mockFetch as unknown as typeof fetch; + mockFetch.mockResolvedValue({ ok: true, json: async () => ({}) } as Response); + mockDiscoverAuthorizationServerMetadata.mockResolvedValue({ + issuer: 'http://example.com', + authorization_endpoint: 'http://example.com/auth', + token_endpoint: 'http://example.com/token', + response_types_supported: ['code'], + } as AuthorizationServerMetadata); + mockStartAuthorization.mockResolvedValue({ + authorizationUrl: new URL('http://example.com/auth'), + codeVerifier: 'test-verifier', + }); + }); + + afterAll(() => { + global.fetch = originalFetch; + }); + + it('passes headers to client registration', async () => { + mockRegisterClient.mockImplementation(async (_, options) => { + await options.fetchFn?.('http://example.com/register', {}); + return { client_id: 'test', redirect_uris: [] }; + }); + + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'http://example.com', + 'user-123', + { foo: 'bar' }, + {}, + ); + + const headers = mockFetch.mock.calls[0][1]?.headers as Headers; + expect(headers.get('foo')).toBe('bar'); + }); + + it('passes headers to discovery operations', async () => { + mockDiscoverOAuthProtectedResourceMetadata.mockImplementation(async (_, __, fetchFn) => { + await fetchFn?.('http://example.com/.well-known/oauth-protected-resource', {}); + return { + resource: 'http://example.com', + authorization_servers: ['http://auth.example.com'], + }; + }); + + await MCPOAuthHandler.initiateOAuthFlow( + 'test-server', + 'http://example.com', + 'user-123', + { foo: 'bar' }, + {}, + ); + + const allHaveHeader = mockFetch.mock.calls.every((call) => { + const headers = call[1]?.headers as Headers; + return headers?.get('foo') === 'bar'; + }); + expect(allHaveHeader).toBe(true); + }); + + it('passes headers to token exchange', async () => { + const mockFlowManager = { + getFlowState: jest.fn().mockResolvedValue({ + status: 'PENDING', + metadata: { + serverName: 'test-server', + codeVerifier: 'test-verifier', + clientInfo: {}, + metadata: {}, + } as MCPOAuthFlowMetadata, + }), + completeFlow: jest.fn(), + } as unknown as FlowStateManager; + + mockExchangeAuthorization.mockImplementation(async (_, options) => { + await options.fetchFn?.('http://example.com/token', {}); + return { access_token: 'test-token', token_type: 'Bearer', expires_in: 3600 }; + }); + + await MCPOAuthHandler.completeOAuthFlow('test-flow-id', 'test-auth-code', mockFlowManager, { + foo: 'bar', + }); + + const headers = mockFetch.mock.calls[0][1]?.headers as Headers; + expect(headers.get('foo')).toBe('bar'); + }); + + it('passes headers to token refresh', async () => { + mockDiscoverAuthorizationServerMetadata.mockImplementation(async (_, options) => { + await options?.fetchFn?.('http://example.com/.well-known/oauth-authorization-server', {}); + return { + issuer: 'http://example.com', + token_endpoint: 'http://example.com/token', + } as AuthorizationServerMetadata; + }); + + await MCPOAuthHandler.refreshOAuthTokens( + 'test-refresh-token', + { + serverName: 'test-server', + serverUrl: 'http://example.com', + clientInfo: { client_id: 'test-client', client_secret: 'test-secret' }, + }, + { foo: 'bar' }, + {}, + ); + + const discoveryCall = mockFetch.mock.calls.find((call) => + call[0].toString().includes('.well-known'), + ); + expect(discoveryCall).toBeDefined(); + const headers = discoveryCall![1]?.headers as Headers; + expect(headers.get('foo')).toBe('bar'); + }); + }); }); diff --git a/packages/api/src/mcp/oauth/handler.ts b/packages/api/src/mcp/oauth/handler.ts index a96dae844229..896d199b6dd0 100644 --- a/packages/api/src/mcp/oauth/handler.ts +++ b/packages/api/src/mcp/oauth/handler.ts @@ -18,6 +18,7 @@ import type { OAuthMetadata, } from './types'; import { sanitizeUrlForLogging } from '~/mcp/utils'; +import { FetchLike } from '@modelcontextprotocol/sdk/shared/transport'; /** Type for the OAuth metadata from the SDK */ type SDKOAuthMetadata = Parameters[1]['metadata']; @@ -26,10 +27,29 @@ export class MCPOAuthHandler { private static readonly FLOW_TYPE = 'mcp_oauth'; private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes + /** + * Creates a fetch function with custom headers injected + */ + private static createOAuthFetch(headers: Record): FetchLike { + return async (url: string | URL, init?: RequestInit): Promise => { + const newHeaders = new Headers(init?.headers ?? {}); + for (const [key, value] of Object.entries(headers)) { + newHeaders.set(key, value); + } + return fetch(url, { + ...init, + headers: newHeaders, + }); + }; + } + /** * Discovers OAuth metadata from the server */ - private static async discoverMetadata(serverUrl: string): Promise<{ + private static async discoverMetadata( + serverUrl: string, + oauthHeaders: Record, + ): Promise<{ metadata: OAuthMetadata; resourceMetadata?: OAuthProtectedResourceMetadata; authServerUrl: URL; @@ -41,12 +61,14 @@ export class MCPOAuthHandler { let authServerUrl = new URL(serverUrl); let resourceMetadata: OAuthProtectedResourceMetadata | undefined; + const fetchFn = this.createOAuthFetch(oauthHeaders); + try { // Try to discover resource metadata first logger.debug( `[MCPOAuth] Attempting to discover protected resource metadata from ${serverUrl}`, ); - resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl); + resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {}, fetchFn); if (resourceMetadata?.authorization_servers?.length) { authServerUrl = new URL(resourceMetadata.authorization_servers[0]); @@ -66,7 +88,9 @@ export class MCPOAuthHandler { logger.debug( `[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`, ); - const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl); + const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl, { + fetchFn, + }); if (!rawMetadata) { logger.error( @@ -92,6 +116,7 @@ export class MCPOAuthHandler { private static async registerOAuthClient( serverUrl: string, metadata: OAuthMetadata, + oauthHeaders: Record, resourceMetadata?: OAuthProtectedResourceMetadata, redirectUri?: string, ): Promise { @@ -159,6 +184,7 @@ export class MCPOAuthHandler { const clientInfo = await registerClient(serverUrl, { metadata: metadata as unknown as SDKOAuthMetadata, clientMetadata, + fetchFn: this.createOAuthFetch(oauthHeaders), }); logger.debug( @@ -181,7 +207,8 @@ export class MCPOAuthHandler { serverName: string, serverUrl: string, userId: string, - config: MCPOptions['oauth'] | undefined, + oauthHeaders: Record, + config?: MCPOptions['oauth'], ): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> { logger.debug( `[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`, @@ -259,7 +286,10 @@ export class MCPOAuthHandler { logger.debug( `[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`, ); - const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl); + const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata( + serverUrl, + oauthHeaders, + ); logger.debug( `[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`, @@ -272,6 +302,7 @@ export class MCPOAuthHandler { const clientInfo = await this.registerOAuthClient( authServerUrl.toString(), metadata, + oauthHeaders, resourceMetadata, redirectUri, ); @@ -365,6 +396,7 @@ export class MCPOAuthHandler { flowId: string, authorizationCode: string, flowManager: FlowStateManager, + oauthHeaders: Record, ): Promise { try { /** Flow state which contains our metadata */ @@ -404,6 +436,7 @@ export class MCPOAuthHandler { codeVerifier: metadata.codeVerifier, authorizationCode, resource, + fetchFn: this.createOAuthFetch(oauthHeaders), }); logger.debug('[MCPOAuth] Raw tokens from exchange:', { @@ -476,6 +509,7 @@ export class MCPOAuthHandler { static async refreshOAuthTokens( refreshToken: string, metadata: { serverName: string; serverUrl?: string; clientInfo?: OAuthClientInformation }, + oauthHeaders: Record, config?: MCPOptions['oauth'], ): Promise { logger.debug(`[MCPOAuth] Refreshing tokens for ${metadata.serverName}`); @@ -509,7 +543,9 @@ export class MCPOAuthHandler { throw new Error('No token URL available for refresh'); } else { /** Auto-discover OAuth configuration for refresh */ - const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl); + const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, { + fetchFn: this.createOAuthFetch(oauthHeaders), + }); if (!oauthMetadata) { throw new Error('Failed to discover OAuth metadata for token refresh'); } @@ -533,6 +569,7 @@ export class MCPOAuthHandler { const headers: HeadersInit = { 'Content-Type': 'application/x-www-form-urlencoded', Accept: 'application/json', + ...oauthHeaders, }; /** Handle authentication based on server's advertised methods */ @@ -613,6 +650,7 @@ export class MCPOAuthHandler { const headers: HeadersInit = { 'Content-Type': 'application/x-www-form-urlencoded', Accept: 'application/json', + ...oauthHeaders, }; /** Handle authentication based on configured methods */ @@ -684,7 +722,9 @@ export class MCPOAuthHandler { } /** Auto-discover OAuth configuration for refresh */ - const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl); + const oauthMetadata = await discoverAuthorizationServerMetadata(metadata.serverUrl, { + fetchFn: this.createOAuthFetch(oauthHeaders), + }); if (!oauthMetadata?.token_endpoint) { throw new Error('No token endpoint found in OAuth metadata'); @@ -700,6 +740,7 @@ export class MCPOAuthHandler { const headers: HeadersInit = { 'Content-Type': 'application/x-www-form-urlencoded', Accept: 'application/json', + ...oauthHeaders, }; const response = await fetch(tokenUrl, { @@ -742,6 +783,7 @@ export class MCPOAuthHandler { revocationEndpoint?: string; revocationEndpointAuthMethodsSupported?: string[]; }, + oauthHeaders: Record = {}, ): Promise { // build the revoke URL, falling back to the server URL + /revoke if no revocation endpoint is provided const revokeUrl: URL = @@ -759,6 +801,7 @@ export class MCPOAuthHandler { // init the request headers const headers: Record = { 'Content-Type': 'application/x-www-form-urlencoded', + ...oauthHeaders, }; // init the request body diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index 58d70ac118e7..72299e96a5e7 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -62,6 +62,8 @@ const BaseOptionsSchema = z.object({ revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), }) .optional(), + /** Custom headers to send with OAuth requests (registration, discovery, token exchange, etc.) */ + oauth_headers: z.record(z.string(), z.string()).optional(), customUserVars: z .record( z.string(),