Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion messages/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@
"public": "Featured",
"publicDescription": "Featured for all users to use",
"featured": "Featured",
"featuredDescription": "Featured for all users to use"
"featuredDescription": "Featured for all users to use",
"perUserAuthentication": "Per-User Authentication",
"perUserAuthenticationDescription": "Each user must authenticate separately with their own credentials"
}
}
1 change: 1 addition & 0 deletions src/app/(chat)/mcp/modify/[id]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export default async function Page({
initialConfig={mcpClient.config}
name={mcpClient.name}
id={mcpClient.id}
perUserAuth={mcpClient.perUserAuth}
/>
) : (
<Alert variant="destructive">MCP client not found</Alert>
Expand Down
6 changes: 4 additions & 2 deletions src/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,16 @@ export async function POST(request: Request) {

const stream = createUIMessageStream({
execute: async ({ writer: dataStream }) => {
const mcpClients = await mcpClientsManager.getClients();
const mcpTools = await mcpClientsManager.tools();
const mcpClients = await mcpClientsManager.getClients(session.user.id);
const mcpTools = await mcpClientsManager.tools(session.user.id);
logger.info(
`mcp-server count: ${mcpClients.length}, mcp-tools count :${Object.keys(mcpTools).length}`,
);
const MCP_TOOLS = await safe()
.map(errorIf(() => !isToolCallAllowed && "Not allowed"))
.map(() =>
loadMcpTools({
userId: session.user.id,
mentions,
allowedMcpServers,
}),
Expand Down Expand Up @@ -245,6 +246,7 @@ export async function POST(request: Request) {
const output = await manualToolExecuteByLastMessage(
part,
{ ...MCP_TOOLS, ...WORKFLOW_TOOLS, ...APP_DEFAULT_TOOLS },
session.user.id,
request.signal,
);
part.output = output;
Expand Down
5 changes: 4 additions & 1 deletion src/app/api/chat/shared.chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ export function mergeSystemPrompt(
export function manualToolExecuteByLastMessage(
part: ToolUIPart,
tools: Record<string, VercelAIMcpTool | VercelAIWorkflowTool | Tool>,
userId?: string,
abortSignal?: AbortSignal,
) {
const { input } = part;
Expand All @@ -141,6 +142,7 @@ export function manualToolExecuteByLastMessage(
tool._mcpServerId,
tool._originToolName,
input,
userId,
);
}
return tool.execute!(input, {
Expand Down Expand Up @@ -394,10 +396,11 @@ export const workflowToVercelAITools = (
};

export const loadMcpTools = (opt?: {
userId?: string;
mentions?: ChatMention[];
allowedMcpServers?: Record<string, AllowedMCPServer>;
}) =>
safe(() => mcpClientsManager.tools())
safe(() => mcpClientsManager.tools(opt?.userId))
.map((tools) => {
if (opt?.mentions?.length) {
return filterMCPToolsByMentions(tools, opt.mentions);
Expand Down
124 changes: 100 additions & 24 deletions src/app/api/mcp/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,61 @@ export async function selectMcpClientsAction() {
const accessibleServers = await mcpRepository.selectAllForUser(
currentUser.id,
);
const accessibleIds = new Set(accessibleServers.map((s) => s.id));

// Warm up clients for the current user
await Promise.allSettled(
accessibleServers.map((server) =>
mcpClientsManager.getClient(server.id, currentUser.id),
),
);

// Get all active clients and filter to only accessible ones
const list = await mcpClientsManager.getClients();
return list
.filter(({ id }) => accessibleIds.has(id))
.map(({ client, id }) => {
const server = accessibleServers.find((s) => s.id === id);
return {
...client.getInfo(),
id,
userId: server?.userId,
visibility: server?.visibility,
isOwner: server?.userId === currentUser.id,
canManage: server
? server.userId === currentUser.id || currentUser.role === "admin"
: false,
};
});
const list = await mcpClientsManager.getClients(currentUser.id);
const activeClientsMap = new Map(list.map((c) => [c.id, c.client]));

// Check authorization status for per-user auth servers
const authStatuses = await Promise.all(
accessibleServers.map(async (server) => {
if (!server.perUserAuth) return { id: server.id, isAuthorized: true };
const session = await mcpOAuthRepository.getAuthenticatedSession(
server.id,
currentUser.id,
);
return { id: server.id, isAuthorized: !!session?.tokens };
}),
);
const authStatusMap = new Map(
authStatuses.map((s) => [s.id, s.isAuthorized]),
);

return accessibleServers.map((server) => {
const client = activeClientsMap.get(server.id);
const info = client?.getInfo();

return {
id: server.id,
name: server.name,
config: server.userId === currentUser.id ? server.config : undefined,
status: info?.status ?? ("disconnected" as const),
enabled: info?.enabled ?? true,
userId: server.userId,
visibility: server.visibility,
perUserAuth: server.perUserAuth ?? false,
isAuthorized: authStatusMap.get(server.id),
toolInfo:
info?.toolInfo && info.toolInfo.length > 0
? info.toolInfo
: (server.toolInfo ?? []),
isOwner: server.userId === currentUser.id,
canManage:
server.userId === currentUser.id || currentUser.role === "admin",
};
});
}

export async function selectMcpClientAction(id: string) {
const client = await mcpClientsManager.getClient(id);
const currentUser = await getCurrentUser();
const client = await mcpClientsManager.getClient(id, currentUser?.id);
if (!client) {
throw new Error("Client not found");
}
Expand Down Expand Up @@ -105,6 +137,7 @@ export async function saveMcpClientAction(
...server,
userId: currentUser.id,
visibility: server.visibility || "private",
toolInfo: server.toolInfo ?? undefined,
};

return mcpClientsManager.persistClient(serverWithUser);
Expand Down Expand Up @@ -134,23 +167,29 @@ export async function removeMcpClientAction(id: string) {
}

export async function refreshMcpClientAction(id: string) {
await mcpClientsManager.refreshClient(id);
const currentUser = await getCurrentUser();
await mcpClientsManager.refreshClient(id, currentUser?.id);
}

export async function authorizeMcpClientAction(id: string) {
const currentUser = await getCurrentUser();
await refreshMcpClientAction(id);
const client = await mcpClientsManager.getClient(id);
const client = await mcpClientsManager.getClient(id, currentUser?.id);
if (client?.client.status != "authorizing") {
throw new Error("Not Authorizing");
}
return client.client.getAuthorizationUrl()?.toString();
}

export async function checkTokenMcpClientAction(id: string) {
const session = await mcpOAuthRepository.getAuthenticatedSession(id);
const currentUser = await getCurrentUser();
const session = await mcpOAuthRepository.getAuthenticatedSession(
id,
currentUser?.id,
);

// for wait connect to mcp server
await mcpClientsManager.getClient(id).catch(() => null);
await mcpClientsManager.getClient(id, currentUser?.id).catch(() => null);

return !!session?.tokens;
}
Expand All @@ -160,15 +199,22 @@ export async function callMcpToolAction(
toolName: string,
input: unknown,
) {
return mcpClientsManager.toolCall(id, toolName, input);
const currentUser = await getCurrentUser();
return mcpClientsManager.toolCall(id, toolName, input, currentUser?.id);
}

export async function callMcpToolByServerNameAction(
serverName: string,
toolName: string,
input: unknown,
) {
return mcpClientsManager.toolCallByServerName(serverName, toolName, input);
const currentUser = await getCurrentUser();
return mcpClientsManager.toolCallByServerName(
serverName,
toolName,
input,
currentUser?.id,
);
}

export async function shareMcpServerAction(
Expand All @@ -186,3 +232,33 @@ export async function shareMcpServerAction(

return { success: true };
}

export async function updatePerUserAuthAction(
id: string,
perUserAuth: boolean,
) {
// Get the MCP server to check ownership
const mcpServer = await mcpRepository.selectById(id);
if (!mcpServer) {
throw new Error("MCP server not found");
}

// Check if user has permission to manage this specific MCP server
const canManage = await canManageMCPServer(
mcpServer.userId,
mcpServer.visibility,
);
if (!canManage) {
throw new Error(
"You don't have permission to update this MCP connection settings",
);
}

// Update the perUserAuth of the MCP server
await mcpRepository.updatePerUserAuth(id, perUserAuth);

// Refresh the client to apply changes
await mcpClientsManager.refreshClient(id);

return { success: true };
}
59 changes: 43 additions & 16 deletions src/app/api/mcp/list/route.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { MCPServerInfo } from "app-types/mcp";
import { mcpClientsManager } from "lib/ai/mcp/mcp-manager";
import { mcpRepository } from "lib/db/repository";
import { mcpOAuthRepository, mcpRepository } from "lib/db/repository";
import { getCurrentUser } from "lib/auth/permissions";

export async function GET() {
Expand All @@ -10,47 +10,74 @@ export async function GET() {
return Response.json({ error: "Unauthorized" }, { status: 401 });
}

const [servers, memoryClients] = await Promise.all([
mcpRepository.selectAllForUser(currentUser.id),
mcpClientsManager.getClients(),
]);
const servers = await mcpRepository.selectAllForUser(currentUser.id);
const memoryClientsBefore = await mcpClientsManager.getClients(
currentUser.id,
);

const memoryMap = new Map(
memoryClients.map(({ id, client }) => [id, client] as const),
memoryClientsBefore.map(({ id, client }) => [id, client] as const),
);

const addTargets = servers.filter((server) => !memoryMap.has(server.id));

const serverIds = new Set(servers.map((s) => s.id));
const removeTargets = memoryClients.filter(({ id }) => !serverIds.has(id));
const removeTargets = memoryClientsBefore.filter(
({ id }) => !serverIds.has(id),
);

if (addTargets.length > 0) {
// no need to wait for this
Promise.allSettled(
addTargets.map((server) => mcpClientsManager.refreshClient(server.id)),
await Promise.allSettled(
addTargets.map((server) =>
mcpClientsManager.refreshClient(server.id, currentUser.id),
),
);
}
if (removeTargets.length > 0) {
// no need to wait for this
Promise.allSettled(
await Promise.allSettled(
removeTargets.map((client) =>
mcpClientsManager.disconnectClient(client.id),
mcpClientsManager.disconnectClient(client.clientId),
),
);
}

// Fetch again to get updated statuses
const memoryClients = await mcpClientsManager.getClients(currentUser.id);
const updatedMemoryMap = new Map(
memoryClients.map(({ id, client }) => [id, client] as const),
);

// Check authorization status for per-user auth servers
const authStatuses = await Promise.all(
servers.map(async (server) => {
if (!server.perUserAuth) return { id: server.id, isAuthorized: true };
const session = await mcpOAuthRepository.getAuthenticatedSession(
server.id,
currentUser.id,
);
return { id: server.id, isAuthorized: !!session?.tokens };
}),
);
const authStatusMap = new Map(
authStatuses.map((s) => [s.id, s.isAuthorized]),
);

const result = servers.map((server) => {
const mem = memoryMap.get(server.id);
const mem = updatedMemoryMap.get(server.id);
const info = mem?.getInfo();
const isOwner = server.userId === currentUser.id;
const mcpInfo: MCPServerInfo = {
...server,
// Hide config from non-owners to prevent credential exposure
config: isOwner ? server.config : undefined,
enabled: info?.enabled ?? true,
status: info?.status ?? "connected",
status: info?.status ?? "disconnected",
error: info?.error,
toolInfo: info?.toolInfo ?? [],
isAuthorized: authStatusMap.get(server.id),
toolInfo:
info?.toolInfo && info.toolInfo.length > 0
? info.toolInfo
: (server.toolInfo ?? []),
};
return mcpInfo;
});
Expand Down
10 changes: 8 additions & 2 deletions src/app/api/mcp/oauth/callback/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,17 @@ export async function GET(request: NextRequest) {
});
}

const client = await mcpClientsManager.getClient(session.mcpServerId);
const client = await mcpClientsManager.getClient(
session.mcpServerId,
session.userId || undefined,
);

try {
await client?.client.finishAuth(callbackData.code, callbackData.state);
await mcpClientsManager.refreshClient(session.mcpServerId);
await mcpClientsManager.refreshClient(
session.mcpServerId,
session.userId || undefined,
);

return createOAuthResponsePage({
type: "success",
Expand Down
2 changes: 1 addition & 1 deletion src/components/agent/edit-agent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ export default function EditAgent({
});

(mcpList as (MCPServerInfo & { id: string })[])?.forEach((mcp) => {
mcp.toolInfo.forEach((tool) => {
mcp.toolInfo?.forEach((tool) => {
if (toolNames.includes(tool.name)) {
allMentions.push({
type: "mcpTool",
Expand Down
2 changes: 1 addition & 1 deletion src/components/chat-bot-voice.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ export function ChatBotVoice() {
.flatMap((v) => {
const tools = allowedMcpServers[v.id].tools;
return tools.map((tool) => {
const toolInfo = v.toolInfo.find((t) => t.name === tool);
const toolInfo = v.toolInfo?.find((t) => t.name === tool);
const mention: ChatMention = {
type: "mcpTool",
serverName: v.name,
Expand Down
Loading
Loading