Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions api/server/controllers/UserController.js
Original file line number Diff line number Diff line change
Expand Up @@ -327,30 +327,43 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
const revocationEndpointAuthMethodsSupported =
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
clientMetadata.revocation_endpoint_auth_methods_supported;
const oauthHeaders = serverConfig.oauth_headers ?? {};

if (tokens?.access_token) {
try {
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
});
await MCPOAuthHandler.revokeOAuthToken(
serverName,
tokens.access_token,
'access',
{
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
},
oauthHeaders,
);
} catch (error) {
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
}
}

if (tokens?.refresh_token) {
try {
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
});
await MCPOAuthHandler.revokeOAuthToken(
serverName,
tokens.refresh_token,
'refresh',
{
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
},
oauthHeaders,
);
} catch (error) {
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
}
Expand Down
16 changes: 16 additions & 0 deletions api/server/routes/__tests__/mcp.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,13 @@ describe('MCP Routes', () => {
}),
};

const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({}),
};

getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth',
Expand All @@ -146,6 +151,7 @@ describe('MCP Routes', () => {
'test-server',
'https://test-server.com',
'test-user-id',
{},
{ clientId: 'test-client-id' },
);
});
Expand Down Expand Up @@ -314,6 +320,7 @@ describe('MCP Routes', () => {
};
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

Expand All @@ -336,6 +343,7 @@ describe('MCP Routes', () => {
'test-flow-id',
'test-auth-code',
mockFlowManager,
{},
);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
Expand Down Expand Up @@ -392,6 +400,11 @@ describe('MCP Routes', () => {
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);

const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
Expand Down Expand Up @@ -427,6 +440,7 @@ describe('MCP Routes', () => {

const mockMcpManager = {
getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

Expand Down Expand Up @@ -1234,6 +1248,7 @@ describe('MCP Routes', () => {
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

Expand Down Expand Up @@ -1281,6 +1296,7 @@ describe('MCP Routes', () => {
.fn()
.mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]),
}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

Expand Down
14 changes: 13 additions & 1 deletion api/server/routes/mcp.js
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
serverName,
serverUrl,
userId,
getOAuthHeaders(serverName),
oauthConfig,
);

Expand Down Expand Up @@ -132,7 +133,12 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
});

logger.debug('[MCP OAuth] Completing OAuth flow');
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
const tokens = await MCPOAuthHandler.completeOAuthFlow(
flowId,
code,
flowManager,
getOAuthHeaders(serverName),
);
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');

/** Persist tokens immediately so reconnection uses fresh credentials */
Expand Down Expand Up @@ -538,4 +544,10 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
}
});

function getOAuthHeaders(serverName) {
const mcpManager = getMCPManager();
const serverConfig = mcpManager.getRawConfig(serverName);
return serverConfig?.oauth_headers ?? {};
}

module.exports = router;
3 changes: 3 additions & 0 deletions packages/api/src/mcp/MCPConnectionFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ export class MCPConnectionFactory {
serverName: metadata.serverName,
clientInfo: metadata.clientInfo,
},
this.serverConfig.oauth_headers ?? {},
this.serverConfig.oauth,
);
};
Expand All @@ -161,6 +162,7 @@ export class MCPConnectionFactory {
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth_headers ?? {},
config?.oauth,
);

Expand Down Expand Up @@ -358,6 +360,7 @@ export class MCPConnectionFactory {
this.serverName,
serverUrl,
this.userId!,
this.serverConfig.oauth_headers ?? {},
this.serverConfig.oauth,
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ describe('MCPConnectionFactory', () => {
'test-server',
'https://api.example.com',
'user123',
{},
undefined,
);
expect(oauthOptions.oauthStart).toHaveBeenCalledWith('https://auth.example.com');
Expand Down
Loading